Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
74429563
Commit
74429563
authored
Jul 12, 2019
by
thomwolf
Browse files
save config file
parent
292140b9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
2 deletions
+14
-2
pytorch_transformers/modeling_utils.py
pytorch_transformers/modeling_utils.py
+14
-2
No files found.
pytorch_transformers/modeling_utils.py
View file @
74429563
...
@@ -48,6 +48,17 @@ class PretrainedConfig(object):
...
@@ -48,6 +48,17 @@ class PretrainedConfig(object):
self
.
output_hidden_states
=
kwargs
.
pop
(
'output_hidden_states'
,
False
)
self
.
output_hidden_states
=
kwargs
.
pop
(
'output_hidden_states'
,
False
)
self
.
torchscript
=
kwargs
.
pop
(
'torchscript'
,
False
)
self
.
torchscript
=
kwargs
.
pop
(
'torchscript'
,
False
)
def save_pretrained(self, save_directory):
    """ Save a configuration file to a directory, so that it
        can be re-loaded using the `from_pretrained(save_directory)` class method.

        Args:
            save_directory: path to an existing directory where the
                configuration JSON file will be written under the
                predefined name ``CONFIG_NAME``.

        Raises:
            AssertionError: if ``save_directory`` is not an existing directory.
    """
    # Explicit check instead of a bare `assert` so the validation still
    # runs when Python is started with `-O` (which strips asserts).
    if not os.path.isdir(save_directory):
        raise AssertionError("Saving path should be a directory where the model and configuration can be saved")
    # If we save using the predefined names, we can load using `from_pretrained`
    output_config_file = os.path.join(save_directory, CONFIG_NAME)
    self.to_json_file(output_config_file)
@
classmethod
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
input
,
**
kwargs
):
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
input
,
**
kwargs
):
"""
"""
...
@@ -248,12 +259,13 @@ class PreTrainedModel(nn.Module):
...
@@ -248,12 +259,13 @@ class PreTrainedModel(nn.Module):
# Only save the model it-self if we are using distributed training
# Only save the model it-self if we are using distributed training
model_to_save
=
self
.
module
if
hasattr
(
self
,
'module'
)
else
self
model_to_save
=
self
.
module
if
hasattr
(
self
,
'module'
)
else
self
# Save configuration file
model_to_save
.
config
.
save_pretrained
(
save_directory
)
# If we save using the predefined names, we can load using `from_pretrained`
# If we save using the predefined names, we can load using `from_pretrained`
output_model_file
=
os
.
path
.
join
(
save_directory
,
WEIGHTS_NAME
)
output_model_file
=
os
.
path
.
join
(
save_directory
,
WEIGHTS_NAME
)
output_config_file
=
os
.
path
.
join
(
save_directory
,
CONFIG_NAME
)
torch
.
save
(
model_to_save
.
state_dict
(),
output_model_file
)
torch
.
save
(
model_to_save
.
state_dict
(),
output_model_file
)
model_to_save
.
config
.
to_json_file
(
output_config_file
)
@
classmethod
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
):
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment