Commit ee261439 authored by thomwolf's avatar thomwolf
Browse files

add save_pretrained

parent 29bb3e4e
...@@ -125,7 +125,15 @@ class TFPreTrainedModel(tf.keras.Model): ...@@ -125,7 +125,15 @@ class TFPreTrainedModel(tf.keras.Model):
""" Save a model and its configuration file to a directory, so that it """ Save a model and its configuration file to a directory, so that it
can be re-loaded using the `:func:`~pytorch_transformers.PreTrainedModel.from_pretrained`` class method. can be re-loaded using the `:func:`~pytorch_transformers.PreTrainedModel.from_pretrained`` class method.
""" """
raise NotImplementedError assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
# Save configuration file
self.config.save_pretrained(save_directory)
# If we save using the predefined names, we can load using `from_pretrained`
output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME)
self.save_weights(output_model_file)
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment