Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9c3c2480
Commit
9c3c2480
authored
Feb 07, 2019
by
thomwolf
Browse files
split saved model in config & weights
parent
2df41663
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
27 deletions
+27
-27
pytorch_pretrained_bert/modeling_openai.py
pytorch_pretrained_bert/modeling_openai.py
+27
-27
No files found.
pytorch_pretrained_bert/modeling_openai.py
View file @
9c3c2480
...
...
@@ -37,7 +37,9 @@ from .modeling import BertLayerNorm as LayerNorm
logger
=
logging
.
getLogger
(
__name__
)
PRETRAINED_MODEL_ARCHIVE_MAP
=
{
"openai-gpt"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt.tar.gz"
}
PRETRAINED_MODEL_ARCHIVE_MAP
=
{
"openai-gpt"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"
}
PRETRAINED_CONFIG_ARCHIVE_MAP
=
{
"openai-gpt"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-openai_gpt_config.json"
}
CONFIG_NAME
=
"openai_gpt_config.json"
WEIGHTS_NAME
=
"pytorch_model.bin"
...
...
@@ -440,49 +442,42 @@ class OpenAIGPTPreTrainedModel(nn.Module):
"""
if
pretrained_model_name
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
archive_file
=
PRETRAINED_MODEL_ARCHIVE_MAP
[
pretrained_model_name
]
config_file
=
PRETRAINED_CONFIG_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
else
:
archive_file
=
pretrained_model_name
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
CONFIG_NAME
)
# redirect to the cache, if necessary
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
)
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url."
.
format
(
pretrained_model_name
,
", "
.
join
(
PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
()),
archive_file
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url."
.
format
(
pretrained_model_name
,
", "
.
join
(
PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
()),
pretrained_model_name_or_path
,
archive_file
,
config_file
)
)
return
None
if
resolved_archive_file
==
archive_file
:
logger
.
info
(
"loading archive file {}"
.
format
(
archive_file
))
else
:
logger
.
info
(
"loading archive file {} from cache at {}"
.
format
(
archive_file
,
resolved_archive_file
))
tempdir
=
None
if
os
.
path
.
isdir
(
resolved_archive_file
):
serialization_dir
=
resolved_archive_file
if
resolved_archive_file
==
archive_file
and
resolved_config_file
==
config_file
:
logger
.
info
(
"loading weights file {}"
.
format
(
archive_file
))
logger
.
info
(
"loading configuration file {}"
.
format
(
config_file
))
else
:
# Extract archive to temp dir
tempdir
=
tempfile
.
mkdtemp
()
logger
.
info
(
"extracting archive file {} to temp dir {}"
.
format
(
resolved_archive_file
,
tempdir
))
with
tarfile
.
open
(
resolved_archive_file
,
"r:gz"
)
as
archive
:
archive
.
extractall
(
tempdir
)
serialization_dir
=
tempdir
logger
.
info
(
"loading weights file {} from cache at {}"
.
format
(
archive_file
,
resolved_archive_file
))
logger
.
info
(
"loading configuration file {} from cache at {}"
.
format
(
config_file
,
resolved_config_file
))
# Load config
config_file
=
os
.
path
.
join
(
serialization_dir
,
CONFIG_NAME
)
config
=
OpenAIGPTConfig
.
from_json_file
(
config_file
)
config
=
OpenAIGPTConfig
.
from_json_file
(
resolved_config_file
)
logger
.
info
(
"Model config {}"
.
format
(
config
))
# Instantiate model.
model
=
cls
(
config
,
*
inputs
,
**
kwargs
)
if
state_dict
is
None
and
not
from_tf
:
weights_path
=
os
.
path
.
join
(
serialization_dir
,
WEIGHTS_NAME
)
state_dict
=
torch
.
load
(
weights_path
,
map_location
=
'cpu'
if
not
torch
.
cuda
.
is_available
()
else
None
)
if
tempdir
:
# Clean up temp dir
shutil
.
rmtree
(
tempdir
)
state_dict
=
torch
.
load
(
resolved_archive_file
,
map_location
=
'cpu'
if
not
torch
.
cuda
.
is_available
()
else
None
)
if
from_tf
:
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
return
load_tf_weights_in_openai_gpt
(
model
,
serialization_dir
)
return
load_tf_weights_in_openai_gpt
(
model
,
resolved_archive_file
)
old_keys
=
[]
new_keys
=
[]
...
...
@@ -535,6 +530,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
raise
RuntimeError
(
"Error(s) in loading state_dict for {}:
\n\t
{}"
.
format
(
model
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
))
)
# Add additional embeddings for special tokens if needed
# This step also make sure we are still sharing the output and input embeddings after loading weights
model
.
set_num_special_tokens
(
num_special_tokens
if
num_special_tokens
is
not
None
else
config
.
n_special
)
...
...
@@ -711,7 +707,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
self
.
apply
(
self
.
init_weights
)
def
set_num_special_tokens
(
self
,
num_special_tokens
):
" Update input and output embeddings with new embedding matrice "
""" Update input and output embeddings with new embedding matrice
Make sure we are sharing the embeddings
"""
self
.
transformer
.
set_num_special_tokens
(
num_special_tokens
)
self
.
lm_head
.
set_embeddings_weights
(
self
.
transformer
.
tokens_embed
.
weight
)
...
...
@@ -792,7 +790,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
self
.
apply
(
self
.
init_weights
)
def
set_num_special_tokens
(
self
,
num_special_tokens
):
" Update input and output embeddings with new embedding matrice "
""" Update input and output embeddings with new embedding matrice
Make sure we are sharing the embeddings
"""
self
.
transformer
.
set_num_special_tokens
(
num_special_tokens
)
self
.
lm_head
.
set_embeddings_weights
(
self
.
transformer
.
tokens_embed
.
weight
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment