Unverified Commit 07681b6b authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #1064 from huggingface/gpt-2-large

Adding gpt-2 large (774M parameters) model
parents e4515faf fdc487d8
...@@ -35,7 +35,7 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p ...@@ -35,7 +35,7 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p
if gpt2_config_file == "": if gpt2_config_file == "":
config = GPT2Config() config = GPT2Config()
else: else:
config = GPT2Config(gpt2_config_file) config = GPT2Config.from_json_file(gpt2_config_file)
model = GPT2Model(config) model = GPT2Model(config)
# Load weights from numpy # Load weights from numpy
......
...@@ -35,7 +35,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c ...@@ -35,7 +35,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
if openai_config_file == "": if openai_config_file == "":
config = OpenAIGPTConfig() config = OpenAIGPTConfig()
else: else:
config = OpenAIGPTConfig(openai_config_file) config = OpenAIGPTConfig.from_json_file(openai_config_file)
model = OpenAIGPTModel(config) model = OpenAIGPTModel(config)
# Load weights from numpy # Load weights from numpy
......
...@@ -75,7 +75,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, ...@@ -75,7 +75,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
if transfo_xl_config_file == "": if transfo_xl_config_file == "":
config = TransfoXLConfig() config = TransfoXLConfig()
else: else:
config = TransfoXLConfig(transfo_xl_config_file) config = TransfoXLConfig.from_json_file(transfo_xl_config_file)
print("Building PyTorch model from configuration: {}".format(str(config))) print("Building PyTorch model from configuration: {}".format(str(config)))
model = TransfoXLLMHeadModel(config) model = TransfoXLLMHeadModel(config)
......
...@@ -38,9 +38,11 @@ from .modeling_bert import BertLayerNorm as LayerNorm ...@@ -38,9 +38,11 @@ from .modeling_bert import BertLayerNorm as LayerNorm
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin", GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"} "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin",
"gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin"}
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json", GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"} "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json",
"gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json"}
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
""" Load tf checkpoints in a pytorch model """ Load tf checkpoints in a pytorch model
......
...@@ -45,17 +45,20 @@ PRETRAINED_VOCAB_FILES_MAP = { ...@@ -45,17 +45,20 @@ PRETRAINED_VOCAB_FILES_MAP = {
{ {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json", 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json",
}, },
'merges_file': 'merges_file':
{ {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt", 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt",
}, },
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'gpt2': 1024, 'gpt2': 1024,
'gpt2-medium': 1024, 'gpt2-medium': 1024,
'gpt2-large': 1024,
} }
@lru_cache() @lru_cache()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment