Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ee261439
Commit
ee261439
authored
Sep 24, 2019
by
thomwolf
Browse files
add save_pretrained
parent
29bb3e4e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
1 deletion
+9
-1
pytorch_transformers/modeling_tf_utils.py
pytorch_transformers/modeling_tf_utils.py
+9
-1
No files found.
pytorch_transformers/modeling_tf_utils.py
View file @
ee261439
...
...
@@ -125,7 +125,15 @@ class TFPreTrainedModel(tf.keras.Model):
""" Save a model and its configuration file to a directory, so that it
can be re-loaded using the `:func:`~pytorch_transformers.PreTrainedModel.from_pretrained`` class method.
"""
raise
NotImplementedError
assert
os
.
path
.
isdir
(
save_directory
),
"Saving path should be a directory where the model and configuration can be saved"
# Save configuration file
self
.
config
.
save_pretrained
(
save_directory
)
# If we save using the predefined names, we can load using `from_pretrained`
output_model_file
=
os
.
path
.
join
(
save_directory
,
TF2_WEIGHTS_NAME
)
self
.
save_weights
(
output_model_file
)
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment