Commit d7395789 authored by danai-antoniou's avatar danai-antoniou
Browse files

Merge branch 'master' of...

Merge branch 'master' of https://github.com/danai-antoniou/pytorch-transformers into add-duplicate-tokens-error
parents 2e6797cc 391db836
...@@ -3,36 +3,37 @@ def main(): ...@@ -3,36 +3,37 @@ def main():
import sys import sys
if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet", "xlm"]: if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet", "xlm"]:
print( print(
"Should be used as one of: \n" "This command line utility let you convert original (author released) model checkpoint to pytorch.\n"
">> pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n" "It should be used as one of: \n"
">> pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG], \n" ">> transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n"
">> pytorch_transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG] or \n" ">> transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG], \n"
">> pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG] or \n" ">> transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG] or \n"
">> pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME] or \n" ">> transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG] or \n"
">> pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT") ">> transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME] or \n"
">> transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT")
else: else:
if sys.argv[1] == "bert": if sys.argv[1] == "bert":
try: try:
from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch from .convert_bert_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
except ImportError: except ImportError:
print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " print("transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see " "In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.") "https://www.tensorflow.org/install/ for installation instructions.")
raise raise
if len(sys.argv) != 5: if len(sys.argv) != 5:
# pylint: disable=line-too-long # pylint: disable=line-too-long
print("Should be used as `pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") print("Should be used as `transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
else: else:
PYTORCH_DUMP_OUTPUT = sys.argv.pop() PYTORCH_DUMP_OUTPUT = sys.argv.pop()
TF_CONFIG = sys.argv.pop() TF_CONFIG = sys.argv.pop()
TF_CHECKPOINT = sys.argv.pop() TF_CHECKPOINT = sys.argv.pop()
convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
elif sys.argv[1] == "gpt": elif sys.argv[1] == "gpt":
from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch from .convert_openai_original_tf_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
if len(sys.argv) < 4 or len(sys.argv) > 5: if len(sys.argv) < 4 or len(sys.argv) > 5:
# pylint: disable=line-too-long # pylint: disable=line-too-long
print("Should be used as `pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`") print("Should be used as `transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`")
else: else:
OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2] OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2]
PYTORCH_DUMP_OUTPUT = sys.argv[3] PYTORCH_DUMP_OUTPUT = sys.argv[3]
...@@ -45,15 +46,15 @@ def main(): ...@@ -45,15 +46,15 @@ def main():
PYTORCH_DUMP_OUTPUT) PYTORCH_DUMP_OUTPUT)
elif sys.argv[1] == "transfo_xl": elif sys.argv[1] == "transfo_xl":
try: try:
from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch from .convert_transfo_xl_original_tf_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
except ImportError: except ImportError:
print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " print("transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see " "In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.") "https://www.tensorflow.org/install/ for installation instructions.")
raise raise
if len(sys.argv) < 4 or len(sys.argv) > 5: if len(sys.argv) < 4 or len(sys.argv) > 5:
# pylint: disable=line-too-long # pylint: disable=line-too-long
print("Should be used as `pytorch_transformers transfo_xl TF_CHECKPOINT/TF_DATASET_FILE PYTORCH_DUMP_OUTPUT [TF_CONFIG]`") print("Should be used as `transformers transfo_xl TF_CHECKPOINT/TF_DATASET_FILE PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
else: else:
if 'ckpt' in sys.argv[2].lower(): if 'ckpt' in sys.argv[2].lower():
TF_CHECKPOINT = sys.argv[2] TF_CHECKPOINT = sys.argv[2]
...@@ -69,16 +70,16 @@ def main(): ...@@ -69,16 +70,16 @@ def main():
convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE) convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE)
elif sys.argv[1] == "gpt2": elif sys.argv[1] == "gpt2":
try: try:
from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch from .convert_gpt2_original_tf_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
except ImportError: except ImportError:
print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " print("transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see " "In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.") "https://www.tensorflow.org/install/ for installation instructions.")
raise raise
if len(sys.argv) < 4 or len(sys.argv) > 5: if len(sys.argv) < 4 or len(sys.argv) > 5:
# pylint: disable=line-too-long # pylint: disable=line-too-long
print("Should be used as `pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [TF_CONFIG]`") print("Should be used as `transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
else: else:
TF_CHECKPOINT = sys.argv[2] TF_CHECKPOINT = sys.argv[2]
PYTORCH_DUMP_OUTPUT = sys.argv[3] PYTORCH_DUMP_OUTPUT = sys.argv[3]
...@@ -89,16 +90,16 @@ def main(): ...@@ -89,16 +90,16 @@ def main():
convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
elif sys.argv[1] == "xlnet": elif sys.argv[1] == "xlnet":
try: try:
from .convert_xlnet_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch from .convert_xlnet_original_tf_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch
except ImportError: except ImportError:
print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " print("transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see " "In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.") "https://www.tensorflow.org/install/ for installation instructions.")
raise raise
if len(sys.argv) < 5 or len(sys.argv) > 6: if len(sys.argv) < 5 or len(sys.argv) > 6:
# pylint: disable=line-too-long # pylint: disable=line-too-long
print("Should be used as `pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`") print("Should be used as `transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`")
else: else:
TF_CHECKPOINT = sys.argv[2] TF_CHECKPOINT = sys.argv[2]
TF_CONFIG = sys.argv[3] TF_CONFIG = sys.argv[3]
...@@ -113,11 +114,11 @@ def main(): ...@@ -113,11 +114,11 @@ def main():
PYTORCH_DUMP_OUTPUT, PYTORCH_DUMP_OUTPUT,
FINETUNING_TASK) FINETUNING_TASK)
elif sys.argv[1] == "xlm": elif sys.argv[1] == "xlm":
from .convert_xlm_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch from .convert_xlm_original_pytorch_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch
if len(sys.argv) != 4: if len(sys.argv) != 4:
# pylint: disable=line-too-long # pylint: disable=line-too-long
print("Should be used as `pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT`") print("Should be used as `transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT`")
else: else:
XLM_CHECKPOINT_PATH = sys.argv[2] XLM_CHECKPOINT_PATH = sys.argv[2]
PYTORCH_DUMP_OUTPUT = sys.argv[3] PYTORCH_DUMP_OUTPUT = sys.argv[3]
......
...@@ -31,7 +31,7 @@ logger = logging.getLogger(__name__) ...@@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
class AutoConfig(object): class AutoConfig(object):
r""":class:`~pytorch_transformers.AutoConfig` is a generic configuration class r""":class:`~transformers.AutoConfig` is a generic configuration class
that will be instantiated as one of the configuration classes of the library that will be instantiated as one of the configuration classes of the library
when created with the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` when created with the `AutoConfig.from_pretrained(pretrained_model_name_or_path)`
class method. class method.
...@@ -76,7 +76,7 @@ class AutoConfig(object): ...@@ -76,7 +76,7 @@ class AutoConfig(object):
pretrained_model_name_or_path: either: pretrained_model_name_or_path: either:
- a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``. - a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``.
- a path to a `directory` containing a configuration file saved using the :func:`~pytorch_transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``. - a path to a `directory` containing a configuration file saved using the :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
- a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``. - a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``.
cache_dir: (`optional`) string: cache_dir: (`optional`) string:
......
...@@ -45,7 +45,7 @@ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { ...@@ -45,7 +45,7 @@ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
class BertConfig(PretrainedConfig): class BertConfig(PretrainedConfig):
r""" r"""
:class:`~pytorch_transformers.BertConfig` is the configuration class to store the configuration of a :class:`~transformers.BertConfig` is the configuration class to store the configuration of a
`BertModel`. `BertModel`.
...@@ -58,7 +58,7 @@ class BertConfig(PretrainedConfig): ...@@ -58,7 +58,7 @@ class BertConfig(PretrainedConfig):
intermediate_size: The size of the "intermediate" (i.e., feed-forward) intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder. layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the hidden_act: The non-linear activation function (function or string) in the
encoder and pooler. If string, "gelu", "relu" and "swish" are supported. encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported.
hidden_dropout_prob: The dropout probabilitiy for all fully connected hidden_dropout_prob: The dropout probabilitiy for all fully connected
layers in the embeddings, encoder, and pooler. layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention attention_probs_dropout_prob: The dropout ratio for the attention
......
...@@ -37,7 +37,7 @@ class DistilBertConfig(PretrainedConfig): ...@@ -37,7 +37,7 @@ class DistilBertConfig(PretrainedConfig):
def __init__(self, def __init__(self,
vocab_size_or_config_json_file=30522, vocab_size_or_config_json_file=30522,
max_position_embeddings=512, max_position_embeddings=512,
sinusoidal_pos_embds=True, sinusoidal_pos_embds=False,
n_layers=6, n_layers=6,
n_heads=12, n_heads=12,
dim=768, dim=768,
......
...@@ -36,7 +36,6 @@ class OpenAIGPTConfig(PretrainedConfig): ...@@ -36,7 +36,6 @@ class OpenAIGPTConfig(PretrainedConfig):
Args: Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `OpenAIGPTModel` or a configuration json file. vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `OpenAIGPTModel` or a configuration json file.
n_special: The number of special tokens to learn during fine-tuning ('[SEP]', '[CLF]', ...)
n_positions: Number of positional embeddings. n_positions: Number of positional embeddings.
n_ctx: Size of the causal mask (usually same as n_positions). n_ctx: Size of the causal mask (usually same as n_positions).
n_embd: Dimensionality of the embeddings and hidden states. n_embd: Dimensionality of the embeddings and hidden states.
......
...@@ -95,10 +95,43 @@ class TransfoXLConfig(PretrainedConfig): ...@@ -95,10 +95,43 @@ class TransfoXLConfig(PretrainedConfig):
init_range=0.01, init_range=0.01,
proj_init_std=0.01, proj_init_std=0.01,
init_std=0.02, init_std=0.02,
layer_norm_epsilon=1e-5,
**kwargs): **kwargs):
"""Constructs TransfoXLConfig. """Constructs TransfoXLConfig.
""" """
super(TransfoXLConfig, self).__init__(**kwargs) super(TransfoXLConfig, self).__init__(**kwargs)
self.n_token = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, int) else -1
self.cutoffs = []
self.cutoffs.extend(cutoffs)
self.tie_weight = tie_weight
if proj_share_all_but_first:
self.tie_projs = [False] + [True] * len(self.cutoffs)
else:
self.tie_projs = [False] + [False] * len(self.cutoffs)
self.d_model = d_model
self.d_embed = d_embed
self.d_head = d_head
self.d_inner = d_inner
self.div_val = div_val
self.pre_lnorm = pre_lnorm
self.n_layer = n_layer
self.n_head = n_head
self.tgt_len = tgt_len
self.ext_len = ext_len
self.mem_len = mem_len
self.same_length = same_length
self.attn_type = attn_type
self.clamp_len = clamp_len
self.sample_softmax = sample_softmax
self.adaptive = adaptive
self.dropout = dropout
self.dropatt = dropatt
self.untie_r = untie_r
self.init = init
self.init_range = init_range
self.proj_init_std = proj_init_std
self.init_std = init_std
self.layer_norm_epsilon = layer_norm_epsilon
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)): and isinstance(vocab_size_or_config_json_file, unicode)):
...@@ -106,39 +139,7 @@ class TransfoXLConfig(PretrainedConfig): ...@@ -106,39 +139,7 @@ class TransfoXLConfig(PretrainedConfig):
json_config = json.loads(reader.read()) json_config = json.loads(reader.read())
for key, value in json_config.items(): for key, value in json_config.items():
self.__dict__[key] = value self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int): elif not isinstance(vocab_size_or_config_json_file, int):
self.n_token = vocab_size_or_config_json_file
self.cutoffs = []
self.cutoffs.extend(cutoffs)
self.tie_weight = tie_weight
if proj_share_all_but_first:
self.tie_projs = [False] + [True] * len(self.cutoffs)
else:
self.tie_projs = [False] + [False] * len(self.cutoffs)
self.d_model = d_model
self.d_embed = d_embed
self.d_head = d_head
self.d_inner = d_inner
self.div_val = div_val
self.pre_lnorm = pre_lnorm
self.n_layer = n_layer
self.n_head = n_head
self.tgt_len = tgt_len
self.ext_len = ext_len
self.mem_len = mem_len
self.same_length = same_length
self.attn_type = attn_type
self.clamp_len = clamp_len
self.sample_softmax = sample_softmax
self.adaptive = adaptive
self.dropout = dropout
self.dropatt = dropatt
self.untie_r = untie_r
self.init = init
self.init_range = init_range
self.proj_init_std = proj_init_std
self.init_std = init_std
else:
raise ValueError("First argument must be either a vocabulary size (int)" raise ValueError("First argument must be either a vocabulary size (int)"
" or the path to a pretrained model config file (str)") " or the path to a pretrained model config file (str)")
......
...@@ -54,11 +54,12 @@ class PretrainedConfig(object): ...@@ -54,11 +54,12 @@ class PretrainedConfig(object):
self.output_attentions = kwargs.pop('output_attentions', False) self.output_attentions = kwargs.pop('output_attentions', False)
self.output_hidden_states = kwargs.pop('output_hidden_states', False) self.output_hidden_states = kwargs.pop('output_hidden_states', False)
self.torchscript = kwargs.pop('torchscript', False) self.torchscript = kwargs.pop('torchscript', False)
self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
self.pruned_heads = kwargs.pop('pruned_heads', {}) self.pruned_heads = kwargs.pop('pruned_heads', {})
def save_pretrained(self, save_directory): def save_pretrained(self, save_directory):
""" Save a configuration object to the directory `save_directory`, so that it """ Save a configuration object to the directory `save_directory`, so that it
can be re-loaded using the :func:`~pytorch_transformers.PretrainedConfig.from_pretrained` class method. can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
""" """
assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved" assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
...@@ -66,16 +67,17 @@ class PretrainedConfig(object): ...@@ -66,16 +67,17 @@ class PretrainedConfig(object):
output_config_file = os.path.join(save_directory, CONFIG_NAME) output_config_file = os.path.join(save_directory, CONFIG_NAME)
self.to_json_file(output_config_file) self.to_json_file(output_config_file)
logger.info("Configuration saved in {}".format(output_config_file))
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
r""" Instantiate a :class:`~pytorch_transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration. r""" Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration.
Parameters: Parameters:
pretrained_model_name_or_path: either: pretrained_model_name_or_path: either:
- a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``. - a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``.
- a path to a `directory` containing a configuration file saved using the :func:`~pytorch_transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``. - a path to a `directory` containing a configuration file saved using the :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
- a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``. - a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``.
cache_dir: (`optional`) string: cache_dir: (`optional`) string:
...@@ -174,7 +176,7 @@ class PretrainedConfig(object): ...@@ -174,7 +176,7 @@ class PretrainedConfig(object):
"""Constructs a `Config` from a Python dictionary of parameters.""" """Constructs a `Config` from a Python dictionary of parameters."""
config = cls(vocab_size_or_config_json_file=-1) config = cls(vocab_size_or_config_json_file=-1)
for key, value in json_object.items(): for key, value in json_object.items():
config.__dict__[key] = value setattr(config, key, value)
return config return config
@classmethod @classmethod
......
...@@ -56,8 +56,6 @@ class XLMConfig(PretrainedConfig): ...@@ -56,8 +56,6 @@ class XLMConfig(PretrainedConfig):
dropout: The dropout probabilitiy for all fully connected dropout: The dropout probabilitiy for all fully connected
layers in the embeddings, encoder, and pooler. layers in the embeddings, encoder, and pooler.
dropatt: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048). (e.g., 512 or 1024 or 2048).
...@@ -66,7 +64,6 @@ class XLMConfig(PretrainedConfig): ...@@ -66,7 +64,6 @@ class XLMConfig(PretrainedConfig):
layer_norm_eps: The epsilon used by LayerNorm. layer_norm_eps: The epsilon used by LayerNorm.
dropout: float, dropout rate. dropout: float, dropout rate.
dropatt: float, dropout rate on attention probabilities.
init: str, the initialization scheme, either "normal" or "uniform". init: str, the initialization scheme, either "normal" or "uniform".
init_range: float, initialize the parameters with a uniform distribution init_range: float, initialize the parameters with a uniform distribution
in [-init_range, init_range]. Only effective when init="uniform". in [-init_range, init_range]. Only effective when init="uniform".
......
...@@ -49,14 +49,11 @@ class XLNetConfig(PretrainedConfig): ...@@ -49,14 +49,11 @@ class XLNetConfig(PretrainedConfig):
dropout: The dropout probabilitiy for all fully connected dropout: The dropout probabilitiy for all fully connected
layers in the embeddings, encoder, and pooler. layers in the embeddings, encoder, and pooler.
dropatt: The dropout ratio for the attention
probabilities.
initializer_range: The sttdev of the truncated_normal_initializer for initializer_range: The sttdev of the truncated_normal_initializer for
initializing all weight matrices. initializing all weight matrices.
layer_norm_eps: The epsilon used by LayerNorm. layer_norm_eps: The epsilon used by LayerNorm.
dropout: float, dropout rate. dropout: float, dropout rate.
dropatt: float, dropout rate on attention probabilities.
init: str, the initialization scheme, either "normal" or "uniform". init: str, the initialization scheme, either "normal" or "uniform".
init_range: float, initialize the parameters with a uniform distribution init_range: float, initialize the parameters with a uniform distribution
in [-init_range, init_range]. Only effective when init="uniform". in [-init_range, init_range]. Only effective when init="uniform".
...@@ -80,6 +77,7 @@ class XLNetConfig(PretrainedConfig): ...@@ -80,6 +77,7 @@ class XLNetConfig(PretrainedConfig):
n_layer=24, n_layer=24,
n_head=16, n_head=16,
d_inner=4096, d_inner=4096,
max_position_embeddings=512,
ff_activation="gelu", ff_activation="gelu",
untie_r=True, untie_r=True,
attn_type="bi", attn_type="bi",
...@@ -112,7 +110,7 @@ class XLNetConfig(PretrainedConfig): ...@@ -112,7 +110,7 @@ class XLNetConfig(PretrainedConfig):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read()) json_config = json.loads(reader.read())
for key, value in json_config.items(): for key, value in json_config.items():
self.__dict__[key] = value setattr(config, key, value)
elif isinstance(vocab_size_or_config_json_file, int): elif isinstance(vocab_size_or_config_json_file, int):
self.n_token = vocab_size_or_config_json_file self.n_token = vocab_size_or_config_json_file
self.d_model = d_model self.d_model = d_model
......
...@@ -21,7 +21,7 @@ from __future__ import print_function ...@@ -21,7 +21,7 @@ from __future__ import print_function
import argparse import argparse
import torch import torch
from pytorch_transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
import logging import logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
......
...@@ -20,7 +20,7 @@ import argparse ...@@ -20,7 +20,7 @@ import argparse
import torch import torch
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from pytorch_transformers import BertModel from transformers import BertModel
def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str): def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str):
......
...@@ -21,7 +21,7 @@ from io import open ...@@ -21,7 +21,7 @@ from io import open
import torch import torch
from pytorch_transformers import (CONFIG_NAME, WEIGHTS_NAME, from transformers import (CONFIG_NAME, WEIGHTS_NAME,
GPT2Config, GPT2Config,
GPT2Model, GPT2Model,
load_tf_weights_in_gpt2) load_tf_weights_in_gpt2)
......
...@@ -21,7 +21,7 @@ from io import open ...@@ -21,7 +21,7 @@ from io import open
import torch import torch
from pytorch_transformers import (CONFIG_NAME, WEIGHTS_NAME, from transformers import (CONFIG_NAME, WEIGHTS_NAME,
OpenAIGPTConfig, OpenAIGPTConfig,
OpenAIGPTModel, OpenAIGPTModel,
load_tf_weights_in_openai_gpt) load_tf_weights_in_openai_gpt)
......
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Convert pytorch checkpoints to TensorFlow """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import argparse
import tensorflow as tf
from transformers import is_torch_available, cached_path
from transformers import (BertConfig, TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification, load_bert_pt_weights_in_tf2, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP)
if is_torch_available():
import torch
import numpy as np
from transformers import (BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
else:
(BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,) = (
None, None, None, None,
None, None,
None, None,
None, None,
None, None,
None, None,
None, None, None,
None, None, None,)
import logging
logging.basicConfig(level=logging.INFO)
MODEL_CLASSES = {
'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'bert-large-uncased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, load_bert_pt_weights_in_tf2, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'bert-large-cased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, load_bert_pt_weights_in_tf2, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'bert-base-cased-finetuned-mrpc': (BertConfig, TFBertForSequenceClassification, load_bert_pt_weights_in_tf2, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP),
'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP),
'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'roberta': (RobertaConfig, TFRobertaForMaskedLM, load_roberta_pt_weights_in_tf2, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
}
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True):
if model_type not in MODEL_CLASSES:
raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))
config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
# Initialise TF model
if config_file in aws_config_map:
config_file = cached_path(aws_config_map[config_file], force_download=not use_cached_models)
config = config_class.from_json_file(config_file)
config.output_hidden_states = True
config.output_attentions = True
print("Building TensorFlow model from configuration: {}".format(str(config)))
tf_model = model_class(config)
# Load weights from tf checkpoint
if pytorch_checkpoint_path in aws_model_maps:
pytorch_checkpoint_path = cached_path(aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models)
tf_model = loading_fct(tf_model, pytorch_checkpoint_path)
if compare_with_pt_model:
inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
tf_inputs = tf.constant(inputs_list)
tfo = tf_model(tf_inputs, training=False) # build the network
pt_model = pt_model_class.from_pretrained(None,
config=config,
state_dict=torch.load(pytorch_checkpoint_path,
map_location='cpu'))
pt_inputs = torch.tensor(inputs_list)
with torch.no_grad():
pto = pt_model(pt_inputs)
np_pt = pto[0].detach().numpy()
np_tf = tfo[0].numpy()
diff = np.amax(np.abs(np_pt - np_tf))
print("Max absolute difference between models outputs {}".format(diff))
assert diff <= 2e-2, "Error, model absolute difference is >2e-2"
# Save pytorch-model
print("Save TensorFlow model to {}".format(tf_dump_path))
tf_model.save_weights(tf_dump_path, save_format='h5')
def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortcut_names_or_path=None, config_shortcut_names_or_path=None,
compare_with_pt_model=False, use_cached_models=False, only_convert_finetuned_models=False):
assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory"
if args_model_type is None:
model_types = list(MODEL_CLASSES.keys())
else:
model_types = [args_model_type]
for j, model_type in enumerate(model_types, start=1):
print("=" * 100)
print(" Converting model type {}/{}: {}".format(j, len(model_types), model_type))
print("=" * 100)
if model_type not in MODEL_CLASSES:
raise ValueError("Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys())))
config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
if model_shortcut_names_or_path is None:
model_shortcut_names_or_path = list(aws_model_maps.keys())
if config_shortcut_names_or_path is None:
config_shortcut_names_or_path = model_shortcut_names_or_path
for i, (model_shortcut_name, config_shortcut_name) in enumerate(
zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1):
print("-" * 100)
if '-squad' in model_shortcut_name or '-mrpc' in model_shortcut_name or '-mnli' in model_shortcut_name:
if not only_convert_finetuned_models:
print(" Skipping finetuned checkpoint {}".format(model_shortcut_name))
continue
model_type = model_shortcut_name
elif only_convert_finetuned_models:
print(" Skipping not finetuned checkpoint {}".format(model_shortcut_name))
continue
print(" Converting checkpoint {}/{}: {} - model_type {}".format(i, len(aws_config_map), model_shortcut_name, model_type))
print("-" * 100)
if config_shortcut_name in aws_config_map:
config_file = cached_path(aws_config_map[config_shortcut_name], force_download=not use_cached_models)
else:
config_file = cached_path(config_shortcut_name, force_download=not use_cached_models)
if model_shortcut_name in aws_model_maps:
model_file = cached_path(aws_model_maps[model_shortcut_name], force_download=not use_cached_models)
else:
model_file = cached_path(model_shortcut_name, force_download=not use_cached_models)
convert_pt_checkpoint_to_tf(model_type,
model_file,
config_file,
os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'),
compare_with_pt_model=compare_with_pt_model)
os.remove(config_file)
os.remove(model_file)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--tf_dump_path",
default = None,
type = str,
required = True,
help = "Path to the output Tensorflow dump file.")
parser.add_argument("--model_type",
default = None,
type = str,
help = "Model type selected in the list of {}. If not given, will download and convert all the models from AWS.".format(list(MODEL_CLASSES.keys())))
parser.add_argument("--pytorch_checkpoint_path",
default = None,
type = str,
help = "Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
"If not given, will download and convert all the checkpoints from AWS.")
parser.add_argument("--config_file",
default = None,
type = str,
help = "The config json file corresponding to the pre-trained model. \n"
"This specifies the model architecture. If not given and "
"--pytorch_checkpoint_path is not given or is a shortcut name"
"use the configuration associated to the shortcut name on the AWS")
parser.add_argument("--compare_with_pt_model",
action='store_true',
help = "Compare Tensorflow and PyTorch model predictions.")
parser.add_argument("--use_cached_models",
action='store_true',
help = "Use cached models if possible instead of updating to latest checkpoint versions.")
parser.add_argument("--only_convert_finetuned_models",
action='store_true',
help = "Only convert finetuned models.")
args = parser.parse_args()
# if args.pytorch_checkpoint_path is not None:
# convert_pt_checkpoint_to_tf(args.model_type.lower(),
# args.pytorch_checkpoint_path,
# args.config_file if args.config_file is not None else args.pytorch_checkpoint_path,
# args.tf_dump_path,
# compare_with_pt_model=args.compare_with_pt_model,
# use_cached_models=args.use_cached_models)
# else:
convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None,
args.tf_dump_path,
model_shortcut_names_or_path=[args.pytorch_checkpoint_path] if args.pytorch_checkpoint_path is not None else None,
compare_with_pt_model=args.compare_with_pt_model,
use_cached_models=args.use_cached_models,
only_convert_finetuned_models=args.only_convert_finetuned_models)
...@@ -23,12 +23,12 @@ import torch ...@@ -23,12 +23,12 @@ import torch
from fairseq.models.roberta import RobertaModel as FairseqRobertaModel from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
from fairseq.modules import TransformerSentenceEncoderLayer from fairseq.modules import TransformerSentenceEncoderLayer
from pytorch_transformers import (BertConfig, BertEncoder, from transformers import (BertConfig, BertEncoder,
BertIntermediate, BertLayer, BertIntermediate, BertLayer,
BertModel, BertOutput, BertModel, BertOutput,
BertSelfAttention, BertSelfAttention,
BertSelfOutput) BertSelfOutput)
from pytorch_transformers import (RobertaEmbeddings, from transformers import (RobertaEmbeddings,
RobertaForMaskedLM, RobertaForMaskedLM,
RobertaForSequenceClassification, RobertaForSequenceClassification,
RobertaModel) RobertaModel)
......
...@@ -23,12 +23,12 @@ from io import open ...@@ -23,12 +23,12 @@ from io import open
import torch import torch
import pytorch_transformers.tokenization_transfo_xl as data_utils import transformers.tokenization_transfo_xl as data_utils
from pytorch_transformers import CONFIG_NAME, WEIGHTS_NAME from transformers import CONFIG_NAME, WEIGHTS_NAME
from pytorch_transformers import (TransfoXLConfig, TransfoXLLMHeadModel, from transformers import (TransfoXLConfig, TransfoXLLMHeadModel,
load_tf_weights_in_transfo_xl) load_tf_weights_in_transfo_xl)
from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES) from transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES)
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
import cPickle as pickle import cPickle as pickle
......
...@@ -23,8 +23,8 @@ from io import open ...@@ -23,8 +23,8 @@ from io import open
import torch import torch
import numpy import numpy
from pytorch_transformers import CONFIG_NAME, WEIGHTS_NAME from transformers import CONFIG_NAME, WEIGHTS_NAME
from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES from transformers.tokenization_xlm import VOCAB_FILES_NAMES
import logging import logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
...@@ -33,7 +33,15 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p ...@@ -33,7 +33,15 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
# Load checkpoint # Load checkpoint
chkpt = torch.load(xlm_checkpoint_path, map_location='cpu') chkpt = torch.load(xlm_checkpoint_path, map_location='cpu')
model = chkpt['model'] state_dict = chkpt['model']
# We have the base model one level deeper than the original XLM repository
two_levels_state_dict = {}
for k, v in state_dict.items():
if 'pred_layer' in k:
two_levels_state_dict[k] = v
else:
two_levels_state_dict['transformer.' + k] = v
config = chkpt['params'] config = chkpt['params']
config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray))) config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray)))
...@@ -47,7 +55,7 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p ...@@ -47,7 +55,7 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file'] pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file']
print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
torch.save(model, pytorch_weights_dump_path) torch.save(two_levels_state_dict, pytorch_weights_dump_path)
print("Save configuration file to {}".format(pytorch_config_dump_path)) print("Save configuration file to {}".format(pytorch_config_dump_path))
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
......
...@@ -22,7 +22,7 @@ import os ...@@ -22,7 +22,7 @@ import os
import argparse import argparse
import torch import torch
from pytorch_transformers import (CONFIG_NAME, WEIGHTS_NAME, from transformers import (CONFIG_NAME, WEIGHTS_NAME,
XLNetConfig, XLNetConfig,
XLNetLMHeadModel, XLNetForQuestionAnswering, XLNetLMHeadModel, XLNetForQuestionAnswering,
XLNetForSequenceClassification, XLNetForSequenceClassification,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment