Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9c3c2480
Commit
9c3c2480
authored
Feb 07, 2019
by
thomwolf
Browse files
split saved model in config & weights
parent
2df41663
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
27 deletions
+27
-27
pytorch_pretrained_bert/modeling_openai.py
pytorch_pretrained_bert/modeling_openai.py
+27
-27
No files found.
pytorch_pretrained_bert/modeling_openai.py
View file @
9c3c2480
...
@@ -37,7 +37,9 @@ from .modeling import BertLayerNorm as LayerNorm
...
@@ -37,7 +37,9 @@ from .modeling import BertLayerNorm as LayerNorm
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
PRETRAINED_MODEL_ARCHIVE_MAP
=
{
"openai-gpt"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt.tar.gz"
}
PRETRAINED_MODEL_ARCHIVE_MAP
=
{
"openai-gpt"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"
}
PRETRAINED_CONFIG_ARCHIVE_MAP
=
{
"openai-gpt"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-openai_gpt_config.json"
}
CONFIG_NAME
=
"openai_gpt_config.json"
CONFIG_NAME
=
"openai_gpt_config.json"
WEIGHTS_NAME
=
"pytorch_model.bin"
WEIGHTS_NAME
=
"pytorch_model.bin"
...
@@ -440,49 +442,42 @@ class OpenAIGPTPreTrainedModel(nn.Module):
...
@@ -440,49 +442,42 @@ class OpenAIGPTPreTrainedModel(nn.Module):
"""
"""
if
pretrained_model_name
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
if
pretrained_model_name
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
archive_file
=
PRETRAINED_MODEL_ARCHIVE_MAP
[
pretrained_model_name
]
archive_file
=
PRETRAINED_MODEL_ARCHIVE_MAP
[
pretrained_model_name
]
config_file
=
PRETRAINED_CONFIG_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
else
:
else
:
archive_file
=
pretrained_model_name
archive_file
=
pretrained_model_name
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
CONFIG_NAME
)
# redirect to the cache, if necessary
# redirect to the cache, if necessary
try
:
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
)
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
)
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
except
EnvironmentError
:
logger
.
error
(
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"associated to this path or url."
.
format
(
"at this path or url."
.
format
(
pretrained_model_name
,
", "
.
join
(
PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
()),
archive_file
pretrained_model_name
,
", "
.
join
(
PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
()),
pretrained_model_name_or_path
,
archive_file
,
config_file
)
)
)
)
return
None
return
None
if
resolved_archive_file
==
archive_file
:
if
resolved_archive_file
==
archive_file
and
resolved_config_file
==
config_file
:
logger
.
info
(
"loading archive file {}"
.
format
(
archive_file
))
logger
.
info
(
"loading weights file {}"
.
format
(
archive_file
))
else
:
logger
.
info
(
"loading configuration file {}"
.
format
(
config_file
))
logger
.
info
(
"loading archive file {} from cache at {}"
.
format
(
archive_file
,
resolved_archive_file
))
tempdir
=
None
if
os
.
path
.
isdir
(
resolved_archive_file
):
serialization_dir
=
resolved_archive_file
else
:
else
:
# Extract archive to temp dir
logger
.
info
(
"loading weights file {} from cache at {}"
.
format
(
tempdir
=
tempfile
.
mkdtemp
()
archive_file
,
resolved_archive_file
))
logger
.
info
(
"extracting archive file {} to temp dir {}"
.
format
(
resolved_archive_file
,
tempdir
))
logger
.
info
(
"loading configuration file {} from cache at {}"
.
format
(
with
tarfile
.
open
(
resolved_archive_file
,
"r:gz"
)
as
archive
:
config_file
,
resolved_config_file
))
archive
.
extractall
(
tempdir
)
serialization_dir
=
tempdir
# Load config
# Load config
config_file
=
os
.
path
.
join
(
serialization_dir
,
CONFIG_NAME
)
config
=
OpenAIGPTConfig
.
from_json_file
(
resolved_config_file
)
config
=
OpenAIGPTConfig
.
from_json_file
(
config_file
)
logger
.
info
(
"Model config {}"
.
format
(
config
))
logger
.
info
(
"Model config {}"
.
format
(
config
))
# Instantiate model.
# Instantiate model.
model
=
cls
(
config
,
*
inputs
,
**
kwargs
)
model
=
cls
(
config
,
*
inputs
,
**
kwargs
)
if
state_dict
is
None
and
not
from_tf
:
if
state_dict
is
None
and
not
from_tf
:
weights_path
=
os
.
path
.
join
(
serialization_dir
,
WEIGHTS_NAME
)
state_dict
=
torch
.
load
(
resolved_archive_file
,
map_location
=
'cpu'
if
not
torch
.
cuda
.
is_available
()
else
None
)
state_dict
=
torch
.
load
(
weights_path
,
map_location
=
'cpu'
if
not
torch
.
cuda
.
is_available
()
else
None
)
if
tempdir
:
# Clean up temp dir
shutil
.
rmtree
(
tempdir
)
if
from_tf
:
if
from_tf
:
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
return
load_tf_weights_in_openai_gpt
(
model
,
serialization_dir
)
return
load_tf_weights_in_openai_gpt
(
model
,
resolved_archive_file
)
old_keys
=
[]
old_keys
=
[]
new_keys
=
[]
new_keys
=
[]
...
@@ -535,6 +530,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
...
@@ -535,6 +530,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
raise
RuntimeError
(
raise
RuntimeError
(
"Error(s) in loading state_dict for {}:
\n\t
{}"
.
format
(
model
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
))
"Error(s) in loading state_dict for {}:
\n\t
{}"
.
format
(
model
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
))
)
)
# Add additional embeddings for special tokens if needed
# Add additional embeddings for special tokens if needed
# This step also make sure we are still sharing the output and input embeddings after loading weights
# This step also make sure we are still sharing the output and input embeddings after loading weights
model
.
set_num_special_tokens
(
num_special_tokens
if
num_special_tokens
is
not
None
else
config
.
n_special
)
model
.
set_num_special_tokens
(
num_special_tokens
if
num_special_tokens
is
not
None
else
config
.
n_special
)
...
@@ -711,7 +707,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
...
@@ -711,7 +707,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
self
.
apply
(
self
.
init_weights
)
self
.
apply
(
self
.
init_weights
)
def
set_num_special_tokens
(
self
,
num_special_tokens
):
def
set_num_special_tokens
(
self
,
num_special_tokens
):
" Update input and output embeddings with new embedding matrice "
""" Update input and output embeddings with new embedding matrice
Make sure we are sharing the embeddings
"""
self
.
transformer
.
set_num_special_tokens
(
num_special_tokens
)
self
.
transformer
.
set_num_special_tokens
(
num_special_tokens
)
self
.
lm_head
.
set_embeddings_weights
(
self
.
transformer
.
tokens_embed
.
weight
)
self
.
lm_head
.
set_embeddings_weights
(
self
.
transformer
.
tokens_embed
.
weight
)
...
@@ -792,7 +790,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
...
@@ -792,7 +790,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
self
.
apply
(
self
.
init_weights
)
self
.
apply
(
self
.
init_weights
)
def
set_num_special_tokens
(
self
,
num_special_tokens
):
def
set_num_special_tokens
(
self
,
num_special_tokens
):
" Update input and output embeddings with new embedding matrice "
""" Update input and output embeddings with new embedding matrice
Make sure we are sharing the embeddings
"""
self
.
transformer
.
set_num_special_tokens
(
num_special_tokens
)
self
.
transformer
.
set_num_special_tokens
(
num_special_tokens
)
self
.
lm_head
.
set_embeddings_weights
(
self
.
transformer
.
tokens_embed
.
weight
)
self
.
lm_head
.
set_embeddings_weights
(
self
.
transformer
.
tokens_embed
.
weight
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment