Commit fa84ae26 authored by Aymeric Augustin

Reformat source code with black.

This is the result of:

    $ black --line-length 119 examples templates transformers utils hubconf.py setup.py

There are a lot of fairly long lines in the project. As a consequence, I'm
picking the longest widely accepted line length, 119 characters.

This is also Thomas' preference, because it allows for explicit variable
names, which make the code easier to understand.
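
For anyone re-running the formatter, black can also verify the result without rewriting files via its `--check` flag. The `pyproject.toml` stanza below is only a sketch of how the 119-character limit could be pinned for future runs; it is an illustrative assumption, not something this commit adds:

    $ black --check --line-length 119 examples templates transformers utils hubconf.py setup.py

    # pyproject.toml (hypothetical snippet, not part of this commit)
    [tool.black]
    line-length = 119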
parent 63e3827c
@@ -35,7 +35,8 @@ from .configuration_xlm_roberta import XLMRobertaConfig, XLM_ROBERTA_PRETRAINED_

logger = logging.getLogger(__name__)

ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
    (key, value)
    for pretrained_map in [
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -50,8 +51,9 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict((key, value)
        CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
        XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ]
    for key, value, in pretrained_map.items()
)


class AutoConfig(object):
@@ -79,37 +81,42 @@ class AutoConfig(object):
        - contains `ctrl` : CTRLConfig (CTRL model)

        This class cannot be instantiated using `__init__()` (throw an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoConfig is designed to be instantiated "
            "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    def for_model(cls, model_type, *args, **kwargs):
        if "distilbert" in model_type:
            return DistilBertConfig(*args, **kwargs)
        elif "roberta" in model_type:
            return RobertaConfig(*args, **kwargs)
        elif "bert" in model_type:
            return BertConfig(*args, **kwargs)
        elif "openai-gpt" in model_type:
            return OpenAIGPTConfig(*args, **kwargs)
        elif "gpt2" in model_type:
            return GPT2Config(*args, **kwargs)
        elif "transfo-xl" in model_type:
            return TransfoXLConfig(*args, **kwargs)
        elif "xlnet" in model_type:
            return XLNetConfig(*args, **kwargs)
        elif "xlm" in model_type:
            return XLMConfig(*args, **kwargs)
        elif "ctrl" in model_type:
            return CTRLConfig(*args, **kwargs)
        elif "albert" in model_type:
            return AlbertConfig(*args, **kwargs)
        elif "camembert" in model_type:
            return CamembertConfig(*args, **kwargs)
        raise ValueError(
            "Unrecognized model identifier in {}. Should contains one of "
            "'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
            "'xlm', 'roberta', 'ctrl', 'camembert', 'albert'".format(model_type)
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
@@ -176,32 +183,36 @@ class AutoConfig(object):
            assert unused_kwargs == {'foo': False}

        """
        if "t5" in pretrained_model_name_or_path:
            return T5Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "distilbert" in pretrained_model_name_or_path:
            return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "albert" in pretrained_model_name_or_path:
            return AlbertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "camembert" in pretrained_model_name_or_path:
            return CamembertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "xlm-roberta" in pretrained_model_name_or_path:
            return XLMRobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "roberta" in pretrained_model_name_or_path:
            return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "bert" in pretrained_model_name_or_path:
            return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "openai-gpt" in pretrained_model_name_or_path:
            return OpenAIGPTConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "gpt2" in pretrained_model_name_or_path:
            return GPT2Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "transfo-xl" in pretrained_model_name_or_path:
            return TransfoXLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "xlnet" in pretrained_model_name_or_path:
            return XLNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "xlm" in pretrained_model_name_or_path:
            return XLMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        elif "ctrl" in pretrained_model_name_or_path:
            return CTRLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        raise ValueError(
            "Unrecognized model identifier in {}. Should contains one of "
            "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
            "'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'ctrl', 'albert'".format(
                pretrained_model_name_or_path
            )
        )
@@ -27,27 +27,27 @@ from .configuration_utils import PretrainedConfig

logger = logging.getLogger(__name__)

BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
    "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
    "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
    "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
    "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
    "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
    "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
    "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
    "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
    "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
    "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
    "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
    "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
    "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
    "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
    "bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-config.json",
    "bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-config.json",
    "bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-config.json",
    "bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-config.json",
    "bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
    "bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
}
@@ -82,20 +82,22 @@ class BertConfig(PretrainedConfig):
    """

    pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        **kwargs
    ):
        super(BertConfig, self).__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
......
@@ -15,8 +15,7 @@
# limitations under the License.
""" CamemBERT configuration """

from __future__ import absolute_import, division, print_function, unicode_literals

import logging
@@ -25,7 +24,7 @@ from .configuration_roberta import RobertaConfig

logger = logging.getLogger(__name__)

CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "camembert-base": "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-config.json",
}
......
@@ -27,6 +27,7 @@ logger = logging.getLogger(__name__)

CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://storage.googleapis.com/sf-ctrl/pytorch/ctrl-config.json"}


class CTRLConfig(PretrainedConfig):
    """Configuration class to store the configuration of a `CTRLModel`.
@@ -48,6 +49,7 @@ class CTRLConfig(PretrainedConfig):
        initializer_range: The sttdev of the truncated_normal_initializer for
            initializing all weight matrices.
    """

    pretrained_config_archive_map = CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(
@@ -64,7 +66,7 @@ class CTRLConfig(PretrainedConfig):
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-6,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
......
@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" DistilBERT model configuration """

from __future__ import absolute_import, division, print_function, unicode_literals

import sys
import json
@@ -26,32 +25,34 @@ from .configuration_utils import PretrainedConfig

logger = logging.getLogger(__name__)

DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "distilbert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json",
    "distilbert-base-uncased-distilled-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json",
    "distilbert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-config.json",
    "distilbert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-multilingual-cased-config.json",
}


class DistilBertConfig(PretrainedConfig):
    pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(
        self,
        vocab_size=30522,
        max_position_embeddings=512,
        sinusoidal_pos_embds=False,
        n_layers=6,
        n_heads=12,
        dim=768,
        hidden_dim=4 * 768,
        dropout=0.1,
        attention_dropout=0.1,
        activation="gelu",
        initializer_range=0.02,
        tie_weights_=True,
        qa_dropout=0.1,
        seq_classif_dropout=0.2,
        **kwargs
    ):
        super(DistilBertConfig, self).__init__(**kwargs)
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
......
@@ -26,11 +26,14 @@ from .configuration_utils import PretrainedConfig

logger = logging.getLogger(__name__)

GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
    "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json",
    "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json",
    "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-config.json",
    "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-config.json",
}


class GPT2Config(PretrainedConfig):
    """Configuration class to store the configuration of a `GPT2Model`.
@@ -52,6 +55,7 @@ class GPT2Config(PretrainedConfig):
        initializer_range: The sttdev of the truncated_normal_initializer for
            initializing all weight matrices.
    """

    pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(
@@ -67,7 +71,7 @@ class GPT2Config(PretrainedConfig):
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
......
@@ -15,8 +15,7 @@
# limitations under the License.
""" MMBT configuration """

from __future__ import absolute_import, division, print_function, unicode_literals

import logging
@@ -31,6 +30,7 @@ class MMBTConfig(object):
        num_labels: Size of final Linear layer for classification.
        modal_hidden_size: Embedding dimension of the non-text modality encoder.
    """

    def __init__(self, config, num_labels=None, modal_hidden_size=2048):
        self.__dict__ = config.__dict__
        self.modal_hidden_size = modal_hidden_size
......
@@ -30,6 +30,7 @@ OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"
}


class OpenAIGPTConfig(PretrainedConfig):
    """
    Configuration class to store the configuration of a `OpenAIGPTModel`.
@@ -54,6 +55,7 @@ class OpenAIGPTConfig(PretrainedConfig):
            initializing all weight matrices.
        predict_special_tokens: should we predict special tokens (when the model has a LM head)
    """

    pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(
@@ -71,7 +73,7 @@ class OpenAIGPTConfig(PretrainedConfig):
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        predict_special_tokens=True,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
......
@@ -15,8 +15,7 @@
# limitations under the License.
""" RoBERTa configuration """

from __future__ import absolute_import, division, print_function, unicode_literals

import logging
@@ -25,12 +24,12 @@ from .configuration_bert import BertConfig

logger = logging.getLogger(__name__)

ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json",
    "roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json",
    "roberta-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json",
    "distilroberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-config.json",
    "roberta-base-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-openai-detector-config.json",
    "roberta-large-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-openai-detector-config.json",
}
......
@@ -27,11 +27,11 @@ from .configuration_utils import PretrainedConfig

logger = logging.getLogger(__name__)

T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "t5-small": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json",
    "t5-base": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-config.json",
    "t5-large": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-config.json",
    "t5-3b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-3b-config.json",
    "t5-11b": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-11b-config.json",
}
@@ -65,19 +65,21 @@ class T5Config(PretrainedConfig):
    """

    pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(
        self,
        vocab_size=32128,
        n_positions=512,
        d_model=512,
        d_kv=64,
        d_ff=2048,
        num_layers=6,
        num_heads=8,
        relative_attention_num_buckets=32,
        dropout_rate=0.1,
        layer_norm_epsilon=1e-6,
        initializer_factor=1.0,
        **kwargs
    ):
        super(T5Config, self).__init__(**kwargs)
        self.vocab_size = vocab_size
        self.n_positions = n_positions
......
@@ -27,9 +27,10 @@ from .configuration_utils import PretrainedConfig

logger = logging.getLogger(__name__)

TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
}


class TransfoXLConfig(PretrainedConfig):
    """Configuration class to store the configuration of a `TransfoXLModel`.
@@ -65,38 +66,41 @@ class TransfoXLConfig(PretrainedConfig):
        proj_init_std: parameters initialized by N(0, init_std)
        init_std: parameters initialized by N(0, init_std)
    """

    pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(
        self,
        vocab_size=267735,
        cutoffs=[20000, 40000, 200000],
        d_model=1024,
        d_embed=1024,
        n_head=16,
        d_head=64,
        d_inner=4096,
        div_val=4,
        pre_lnorm=False,
        n_layer=18,
        tgt_len=128,
        ext_len=0,
        mem_len=1600,
        clamp_len=1000,
        same_length=True,
        proj_share_all_but_first=True,
        attn_type=0,
        sample_softmax=-1,
        adaptive=True,
        tie_weight=True,
        dropout=0.1,
        dropatt=0.0,
        untie_r=True,
        init="normal",
        init_range=0.01,
        proj_init_std=0.01,
        init_std=0.02,
        layer_norm_epsilon=1e-5,
        **kwargs
    ):
        """Constructs TransfoXLConfig.
        """
        super(TransfoXLConfig, self).__init__(**kwargs)
......
@@ -15,8 +15,7 @@
# limitations under the License.
""" Configuration base class and utilities."""

from __future__ import absolute_import, division, print_function, unicode_literals

import copy
import json
@@ -28,6 +27,7 @@ from .file_utils import CONFIG_NAME, cached_path, is_remote_url, hf_bucket_url

logger = logging.getLogger(__name__)


class PretrainedConfig(object):
    r""" Base class for all configuration classes.
        Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
@@ -50,36 +50,36 @@ class PretrainedConfig(object):
    def __init__(self, **kwargs):
        # Attributes with defaults
        self.output_attentions = kwargs.pop("output_attentions", False)
        self.output_hidden_states = kwargs.pop("output_hidden_states", False)
        self.output_past = kwargs.pop("output_past", True)  # Not used by all models
        self.torchscript = kwargs.pop("torchscript", False)  # Only used by PyTorch models
        self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
        self.pruned_heads = kwargs.pop("pruned_heads", {})

        # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
        self.is_decoder = kwargs.pop("is_decoder", False)

        # Parameters for sequence generation
        self.max_length = kwargs.pop("max_length", 20)
        self.do_sample = kwargs.pop("do_sample", False)
        self.num_beams = kwargs.pop("num_beams", 1)
        self.temperature = kwargs.pop("temperature", 1.0)
        self.top_k = kwargs.pop("top_k", 50)
        self.top_p = kwargs.pop("top_p", 1.0)
        self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
        self.bos_token_id = kwargs.pop("bos_token_id", 0)
        self.pad_token_id = kwargs.pop("pad_token_id", 0)
        self.eos_token_ids = kwargs.pop("eos_token_ids", 0)
        self.length_penalty = kwargs.pop("length_penalty", 1.0)
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)

        # Fine-tuning task arguments
        self.finetuning_task = kwargs.pop("finetuning_task", None)
        self.num_labels = kwargs.pop("num_labels", 2)
        self.id2label = kwargs.pop("id2label", {i: "LABEL_{}".format(i) for i in range(self.num_labels)})
        self.id2label = dict((int(key), value) for key, value in self.id2label.items())
        self.label2id = kwargs.pop("label2id", dict(zip(self.id2label.values(), self.id2label.keys())))
        self.label2id = dict((key, int(value)) for key, value in self.label2id.items())

        # Additional attributes without default values
@@ -94,7 +94,9 @@ class PretrainedConfig(object):
        """ Save a configuration object to the directory `save_directory`, so that it
            can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
        """
        assert os.path.isdir(
            save_directory
        ), "Saving path should be a directory where the model and configuration can be saved"

        # If we save using the predefined names, we can load using `from_pretrained`
        output_config_file = os.path.join(save_directory, CONFIG_NAME)
@@ -153,11 +155,11 @@ class PretrainedConfig(object):
            assert unused_kwargs == {'foo': False}

        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
            config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
@@ -170,37 +172,48 @@ class PretrainedConfig(object):
        try:
            # Load from URL or cache if already cached
            resolved_config_file = cached_path(
                config_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
            )
            # Load config
            config = cls.from_json_file(resolved_config_file)
        except EnvironmentError:
            if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
                msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
                    config_file
                )
            else:
                msg = (
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url to a configuration file named {} or "
                    "a directory containing such a file but couldn't find any such file at this path or url.".format(
                        pretrained_model_name_or_path,
                        ", ".join(cls.pretrained_config_archive_map.keys()),
                        config_file,
                        CONFIG_NAME,
                    )
                )
            raise EnvironmentError(msg)

        except json.JSONDecodeError:
            msg = (
                "Couldn't reach server at '{}' to download configuration file or "
                "configuration file is not a valid JSON file. "
                "Please check network or file content here: {}.".format(config_file, resolved_config_file)
            )
            raise EnvironmentError(msg)

        if resolved_config_file == config_file:
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))

        if hasattr(config, "pruned_heads"):
            config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())

        # Update config with kwargs if needed
@@ -226,7 +239,7 @@ class PretrainedConfig(object):
    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `Config` from a json file of parameters."""
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        dict_obj = json.loads(text)
        return cls(**dict_obj)
@@ -248,5 +261,5 @@ class PretrainedConfig(object):
    def to_json_file(self, json_file_path):
        """ Save this instance to a json file."""
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string())
@@ -25,16 +25,16 @@ from .configuration_utils import PretrainedConfig

logger = logging.getLogger(__name__)

XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "xlm-mlm-en-2048": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
    "xlm-mlm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-config.json",
    "xlm-mlm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-config.json",
    "xlm-mlm-enro-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-config.json",
    "xlm-mlm-tlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-config.json",
    "xlm-mlm-xnli15-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json",
    "xlm-clm-enfr-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json",
    "xlm-clm-ende-1024": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json",
    "xlm-mlm-17-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json",
    "xlm-mlm-100-1280": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-config.json",
}
@@ -78,41 +78,44 @@ class XLMConfig(PretrainedConfig):
            -1 means no clamping.
        same_length: bool, whether to use the same attention length for each token.
    """

    pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(
        self,
        vocab_size=30145,
        emb_dim=2048,
        n_layers=12,
        n_heads=16,
        dropout=0.1,
        attention_dropout=0.1,
        gelu_activation=True,
        sinusoidal_embeddings=False,
        causal=False,
        asm=False,
        n_langs=1,
        use_lang_emb=True,
        max_position_embeddings=512,
        embed_init_std=2048 ** -0.5,
        layer_norm_eps=1e-12,
        init_std=0.02,
        bos_index=0,
        eos_index=1,
        pad_index=2,
        unk_index=3,
        mask_index=5,
        is_encoder=True,
        summary_type="first",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=0.1,
        start_n_top=5,
        end_n_top=5,
        mask_token_id=0,
        lang_id=0,
        **kwargs
    ):
        """Constructs XLMConfig.
        """
        super(XLMConfig, self).__init__(**kwargs)
......
@@ -15,8 +15,7 @@
# limitations under the License.
""" XLM-RoBERTa configuration """

from __future__ import absolute_import, division, print_function, unicode_literals

import logging
@@ -25,12 +24,12 @@ from .configuration_roberta import RobertaConfig

logger = logging.getLogger(__name__)

XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "xlm-roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-config.json",
    "xlm-roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-config.json",
    "xlm-roberta-large-finetuned-conll02-dutch": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-config.json",
    "xlm-roberta-large-finetuned-conll02-spanish": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-config.json",
    "xlm-roberta-large-finetuned-conll03-english": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-config.json",
    "xlm-roberta-large-finetuned-conll03-german": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-config.json",
}
......
@@ -26,8 +26,8 @@ from .configuration_utils import PretrainedConfig

logger = logging.getLogger(__name__)

XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "xlnet-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json",
    "xlnet-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
}
@@ -69,32 +69,35 @@ class XLNetConfig(PretrainedConfig):
        same_length: bool, whether to use the same attention length for each token.
        finetuning_task: name of the glue task on which the model was fine-tuned if any
    """

    pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(
        self,
        vocab_size=32000,
        d_model=1024,
        n_layer=24,
        n_head=16,
        d_inner=4096,
        ff_activation="gelu",
        untie_r=True,
        attn_type="bi",
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        dropout=0.1,
        mem_len=None,
        reuse_len=None,
        bi_data=False,
        clamp_len=-1,
        same_length=False,
        summary_type="last",
        summary_use_proj=True,
        summary_activation="tanh",
        summary_last_dropout=0.1,
        start_n_top=5,
        end_n_top=5,
        **kwargs
    ):
        """Constructs XLNetConfig.
        """
        super(XLNetConfig, self).__init__(**kwargs)
......
@@ -24,6 +24,7 @@ import torch
from transformers import AlbertConfig, AlbertForMaskedLM, load_tf_weights_in_albert

import logging


logging.basicConfig(level=logging.INFO)
@@ -44,24 +45,19 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pyt
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--albert_config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained ALBERT model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path)
\ No newline at end of file
@@ -24,8 +24,10 @@ import torch
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert

import logging


logging.basicConfig(level=logging.INFO)


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
@@ -43,23 +45,19 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--bert_config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained BERT model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
@@ -23,7 +23,7 @@ import tensorflow as tf
from transformers import BertModel


def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):

    """
    :param model:BertModel Pytorch model instance to be converted
@@ -41,22 +41,17 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s
        N BertForQuestionAnswering
    """

    tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")

    var_map = (
        ("layer.", "layer_"),
        ("word_embeddings.weight", "word_embeddings"),
        ("position_embeddings.weight", "position_embeddings"),
        ("token_type_embeddings.weight", "token_type_embeddings"),
        (".", "/"),
        ("LayerNorm/weight", "LayerNorm/gamma"),
        ("LayerNorm/bias", "LayerNorm/beta"),
        ("weight", "kernel"),
    )

    if not os.path.isdir(ckpt_dir):
@@ -64,12 +59,12 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s
    state_dict = model.state_dict()

    def to_tf_var_name(name: str):
        for patt, repl in iter(var_map):
            name = name.replace(patt, repl)
        return "bert/{}".format(name)

    def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
        tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
        tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
        session.run(tf.variables_initializer([tf_var]))
@@ -94,37 +89,22 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s
def main(raw_args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True, help="model name e.g. bert-base-uncased")
    parser.add_argument(
        "--cache_dir", type=str, default=None, required=False, help="Directory containing pytorch model"
    )
    parser.add_argument("--pytorch_model_path", type=str, required=True, help="/path/to/<pytorch-model-name>.bin")
    parser.add_argument("--tf_cache_dir", type=str, required=True, help="Directory in which to save tensorflow model")

    args = parser.parse_args(raw_args)

    model = BertModel.from_pretrained(
        pretrained_model_name_or_path=args.model_name,
        state_dict=torch.load(args.pytorch_model_path),
cache_dir=args.cache_dir cache_dir=args.cache_dir,
)
convert_pytorch_checkpoint_to_tf(
model=model,
ckpt_dir=args.tf_cache_dir,
model_name=args.model_name
) )
convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_cache_dir, model_name=args.model_name)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
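As a worked example of the var_map above (not part of the commit): the replacements are applied in order, so a typical BertModel state_dict key maps to a TensorFlow variable name as follows.

var_map = (
    ("layer.", "layer_"),
    ("word_embeddings.weight", "word_embeddings"),
    ("position_embeddings.weight", "position_embeddings"),
    ("token_type_embeddings.weight", "token_type_embeddings"),
    (".", "/"),
    ("LayerNorm/weight", "LayerNorm/gamma"),
    ("LayerNorm/bias", "LayerNorm/beta"),
    ("weight", "kernel"),
)


def to_tf_var_name(name: str):
    # Order matters: "." -> "/" must run before the LayerNorm rules, and
    # "LayerNorm/weight" -> "LayerNorm/gamma" must run before the generic "weight" -> "kernel".
    for patt, repl in var_map:
        name = name.replace(patt, repl)
    return "bert/{}".format(name)


# Prints "bert/encoder/layer_0/attention/self/query/kernel"
print(to_tf_var_name("encoder.layer.0.attention.self.query.weight"))

Because this key contains "attention.self.query", it also matches tensors_to_transpose above, so the exported array would additionally be transposed before being written to the TensorFlow checkpoint.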
...@@ -21,12 +21,10 @@ from io import open ...@@ -21,12 +21,10 @@ from io import open
import torch import torch
from transformers import (CONFIG_NAME, WEIGHTS_NAME, from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model, load_tf_weights_in_gpt2
GPT2Config,
GPT2Model,
load_tf_weights_in_gpt2)
import logging import logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
...@@ -42,8 +40,8 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p ...@@ -42,8 +40,8 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p
load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path) load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)
# Save pytorch-model # Save pytorch-model
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
torch.save(model.state_dict(), pytorch_weights_dump_path) torch.save(model.state_dict(), pytorch_weights_dump_path)
print("Save configuration file to {}".format(pytorch_config_dump_path)) print("Save configuration file to {}".format(pytorch_config_dump_path))
...@@ -54,22 +52,18 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p ...@@ -54,22 +52,18 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
## Required parameters ## Required parameters
parser.add_argument("--gpt2_checkpoint_path", parser.add_argument(
default = None, "--gpt2_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
type = str, )
required = True, parser.add_argument(
help = "Path to the TensorFlow checkpoint path.") "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
parser.add_argument("--pytorch_dump_folder_path", )
default = None, parser.add_argument(
type = str, "--gpt2_config_file",
required = True, default="",
help = "Path to the output PyTorch model.") type=str,
parser.add_argument("--gpt2_config_file", help="An optional config json file corresponding to the pre-trained OpenAI model. \n"
default = "", "This specifies the model architecture.",
type = str, )
help = "An optional config json file corresponding to the pre-trained OpenAI model. \n"
"This specifies the model architecture.")
args = parser.parse_args() args = parser.parse_args()
convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, args.gpt2_config_file, args.pytorch_dump_folder_path)
args.gpt2_config_file,
args.pytorch_dump_folder_path)
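Since the converter writes WEIGHTS_NAME and CONFIG_NAME into pytorch_dump_folder_path, the output folder should be directly loadable with from_pretrained. A small sketch; the folder path is a placeholder standing in for whatever --pytorch_dump_folder_path was used.

from transformers import GPT2Config, GPT2Model

dump_folder = "./gpt2_converted"  # hypothetical --pytorch_dump_folder_path used for the conversion
config = GPT2Config.from_pretrained(dump_folder)  # reads the saved CONFIG_NAME (config.json)
model = GPT2Model.from_pretrained(dump_folder)  # reads the saved WEIGHTS_NAME (pytorch_model.bin)
model.eval()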
...@@ -21,12 +21,10 @@ from io import open ...@@ -21,12 +21,10 @@ from io import open
import torch import torch
from transformers import (CONFIG_NAME, WEIGHTS_NAME, from transformers import CONFIG_NAME, WEIGHTS_NAME, OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt
OpenAIGPTConfig,
OpenAIGPTModel,
load_tf_weights_in_openai_gpt)
import logging import logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
...@@ -42,8 +40,8 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c ...@@ -42,8 +40,8 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path) load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path)
# Save pytorch-model # Save pytorch-model
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
torch.save(model.state_dict(), pytorch_weights_dump_path) torch.save(model.state_dict(), pytorch_weights_dump_path)
print("Save configuration file to {}".format(pytorch_config_dump_path)) print("Save configuration file to {}".format(pytorch_config_dump_path))
...@@ -54,22 +52,24 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c ...@@ -54,22 +52,24 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
## Required parameters ## Required parameters
parser.add_argument("--openai_checkpoint_folder_path", parser.add_argument(
default = None, "--openai_checkpoint_folder_path",
type = str, default=None,
required = True, type=str,
help = "Path to the TensorFlow checkpoint path.") required=True,
parser.add_argument("--pytorch_dump_folder_path", help="Path to the TensorFlow checkpoint path.",
default = None, )
type = str, parser.add_argument(
required = True, "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
help = "Path to the output PyTorch model.") )
parser.add_argument("--openai_config_file", parser.add_argument(
default = "", "--openai_config_file",
type = str, default="",
help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" type=str,
"This specifies the model architecture.") help="An optional config json file corresponding to the pre-trained OpenAI model. \n"
"This specifies the model architecture.",
)
args = parser.parse_args() args = parser.parse_args()
convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path, convert_openai_checkpoint_to_pytorch(
args.openai_config_file, args.openai_checkpoint_folder_path, args.openai_config_file, args.pytorch_dump_folder_path
args.pytorch_dump_folder_path) )
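A hedged sketch of the model setup that the collapsed hunk above presumably performs, inferred from the imports and from --openai_config_file defaulting to an empty string (which suggests a fall-back to the stock configuration); this is not a verbatim copy of the script.

from transformers import OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt


def build_openai_gpt_model(openai_config_file, openai_checkpoint_folder_path):
    # Assumption: an empty config path means "use the default OpenAIGPTConfig".
    config = OpenAIGPTConfig() if openai_config_file == "" else OpenAIGPTConfig.from_json_file(openai_config_file)
    model = OpenAIGPTModel(config)
    # This call is visible in the hunk above.
    load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path)
    return model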