Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9c3c2480
Commit
9c3c2480
authored
Feb 07, 2019
by
thomwolf
Browse files
split saved model in config & weights
parent
2df41663
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
27 deletions
+27
-27
pytorch_pretrained_bert/modeling_openai.py
pytorch_pretrained_bert/modeling_openai.py
+27
-27
No files found.
pytorch_pretrained_bert/modeling_openai.py
View file @
9c3c2480
...
@@ -37,7 +37,9 @@ from .modeling import BertLayerNorm as LayerNorm
...
@@ -37,7 +37,9 @@ from .modeling import BertLayerNorm as LayerNorm
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
PRETRAINED_MODEL_ARCHIVE_MAP
=
{
"openai-gpt"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt.tar.gz"
}
PRETRAINED_MODEL_ARCHIVE_MAP
=
{
"openai-gpt"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"
}
PRETRAINED_CONFIG_ARCHIVE_MAP
=
{
"openai-gpt"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-openai_gpt_config.json"
}
CONFIG_NAME
=
"openai_gpt_config.json"
CONFIG_NAME
=
"openai_gpt_config.json"
WEIGHTS_NAME
=
"pytorch_model.bin"
WEIGHTS_NAME
=
"pytorch_model.bin"
...
@@ -440,49 +442,42 @@ class OpenAIGPTPreTrainedModel(nn.Module):
...
@@ -440,49 +442,42 @@ class OpenAIGPTPreTrainedModel(nn.Module):
"""
"""
if
pretrained_model_name
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
if
pretrained_model_name
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
archive_file
=
PRETRAINED_MODEL_ARCHIVE_MAP
[
pretrained_model_name
]
archive_file
=
PRETRAINED_MODEL_ARCHIVE_MAP
[
pretrained_model_name
]
config_file
=
PRETRAINED_CONFIG_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
else
:
else
:
archive_file
=
pretrained_model_name
archive_file
=
pretrained_model_name
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
CONFIG_NAME
)
# redirect to the cache, if necessary
# redirect to the cache, if necessary
try
:
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
)
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
)
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
except
EnvironmentError
:
logger
.
error
(
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"associated to this path or url."
.
format
(
"at this path or url."
.
format
(
pretrained_model_name
,
", "
.
join
(
PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
()),
archive_file
pretrained_model_name
,
", "
.
join
(
PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
()),
pretrained_model_name_or_path
,
archive_file
,
config_file
)
)
)
)
return
None
return
None
if
resolved_archive_file
==
archive_file
:
if
resolved_archive_file
==
archive_file
and
resolved_config_file
==
config_file
:
logger
.
info
(
"loading archive file {}"
.
format
(
archive_file
))
logger
.
info
(
"loading weights file {}"
.
format
(
archive_file
))
else
:
logger
.
info
(
"loading configuration file {}"
.
format
(
config_file
))
logger
.
info
(
"loading archive file {} from cache at {}"
.
format
(
archive_file
,
resolved_archive_file
))
tempdir
=
None
if
os
.
path
.
isdir
(
resolved_archive_file
):
serialization_dir
=
resolved_archive_file
else
:
else
:
# Extract archive to temp dir
logger
.
info
(
"loading weights file {} from cache at {}"
.
format
(
tempdir
=
tempfile
.
mkdtemp
()
archive_file
,
resolved_archive_file
))
logger
.
info
(
"extracting archive file {} to temp dir {}"
.
format
(
resolved_archive_file
,
tempdir
))
logger
.
info
(
"loading configuration file {} from cache at {}"
.
format
(
with
tarfile
.
open
(
resolved_archive_file
,
"r:gz"
)
as
archive
:
config_file
,
resolved_config_file
))
archive
.
extractall
(
tempdir
)
serialization_dir
=
tempdir
# Load config
# Load config
config_file
=
os
.
path
.
join
(
serialization_dir
,
CONFIG_NAME
)
config
=
OpenAIGPTConfig
.
from_json_file
(
resolved_config_file
)
config
=
OpenAIGPTConfig
.
from_json_file
(
config_file
)
logger
.
info
(
"Model config {}"
.
format
(
config
))
logger
.
info
(
"Model config {}"
.
format
(
config
))
# Instantiate model.
# Instantiate model.
model
=
cls
(
config
,
*
inputs
,
**
kwargs
)
model
=
cls
(
config
,
*
inputs
,
**
kwargs
)
if
state_dict
is
None
and
not
from_tf
:
if
state_dict
is
None
and
not
from_tf
:
weights_path
=
os
.
path
.
join
(
serialization_dir
,
WEIGHTS_NAME
)
state_dict
=
torch
.
load
(
resolved_archive_file
,
map_location
=
'cpu'
if
not
torch
.
cuda
.
is_available
()
else
None
)
state_dict
=
torch
.
load
(
weights_path
,
map_location
=
'cpu'
if
not
torch
.
cuda
.
is_available
()
else
None
)
if
tempdir
:
# Clean up temp dir
shutil
.
rmtree
(
tempdir
)
if
from_tf
:
if
from_tf
:
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
return
load_tf_weights_in_openai_gpt
(
model
,
serialization_dir
)
return
load_tf_weights_in_openai_gpt
(
model
,
resolved_archive_file
)
old_keys
=
[]
old_keys
=
[]
new_keys
=
[]
new_keys
=
[]
...
@@ -535,6 +530,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
...
@@ -535,6 +530,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
raise
RuntimeError
(
raise
RuntimeError
(
"Error(s) in loading state_dict for {}:
\n\t
{}"
.
format
(
model
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
))
"Error(s) in loading state_dict for {}:
\n\t
{}"
.
format
(
model
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
))
)
)
# Add additional embeddings for special tokens if needed
# Add additional embeddings for special tokens if needed
# This step also make sure we are still sharing the output and input embeddings after loading weights
# This step also make sure we are still sharing the output and input embeddings after loading weights
model
.
set_num_special_tokens
(
num_special_tokens
if
num_special_tokens
is
not
None
else
config
.
n_special
)
model
.
set_num_special_tokens
(
num_special_tokens
if
num_special_tokens
is
not
None
else
config
.
n_special
)
...
@@ -711,7 +707,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
...
@@ -711,7 +707,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
self
.
apply
(
self
.
init_weights
)
self
.
apply
(
self
.
init_weights
)
def
set_num_special_tokens
(
self
,
num_special_tokens
):
def
set_num_special_tokens
(
self
,
num_special_tokens
):
" Update input and output embeddings with new embedding matrice "
""" Update input and output embeddings with new embedding matrice
Make sure we are sharing the embeddings
"""
self
.
transformer
.
set_num_special_tokens
(
num_special_tokens
)
self
.
transformer
.
set_num_special_tokens
(
num_special_tokens
)
self
.
lm_head
.
set_embeddings_weights
(
self
.
transformer
.
tokens_embed
.
weight
)
self
.
lm_head
.
set_embeddings_weights
(
self
.
transformer
.
tokens_embed
.
weight
)
...
@@ -792,7 +790,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
...
@@ -792,7 +790,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
self
.
apply
(
self
.
init_weights
)
self
.
apply
(
self
.
init_weights
)
def
set_num_special_tokens
(
self
,
num_special_tokens
):
def
set_num_special_tokens
(
self
,
num_special_tokens
):
" Update input and output embeddings with new embedding matrice "
""" Update input and output embeddings with new embedding matrice
Make sure we are sharing the embeddings
"""
self
.
transformer
.
set_num_special_tokens
(
num_special_tokens
)
self
.
transformer
.
set_num_special_tokens
(
num_special_tokens
)
self
.
lm_head
.
set_embeddings_weights
(
self
.
transformer
.
tokens_embed
.
weight
)
self
.
lm_head
.
set_embeddings_weights
(
self
.
transformer
.
tokens_embed
.
weight
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment