Commit 74429563 authored by thomwolf's avatar thomwolf
Browse files

save config file

parent 292140b9
...@@ -48,6 +48,17 @@ class PretrainedConfig(object): ...@@ -48,6 +48,17 @@ class PretrainedConfig(object):
self.output_hidden_states = kwargs.pop('output_hidden_states', False) self.output_hidden_states = kwargs.pop('output_hidden_states', False)
self.torchscript = kwargs.pop('torchscript', False) self.torchscript = kwargs.pop('torchscript', False)
def save_pretrained(self, save_directory):
""" Save a configuration file to a directory, so that it
can be re-loaded using the `from_pretrained(save_directory)` class method.
"""
assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
# If we save using the predefined names, we can load using `from_pretrained`
output_config_file = os.path.join(save_directory, CONFIG_NAME)
self.to_json_file(output_config_file)
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs):
""" """
...@@ -248,12 +259,13 @@ class PreTrainedModel(nn.Module): ...@@ -248,12 +259,13 @@ class PreTrainedModel(nn.Module):
# Only save the model it-self if we are using distributed training # Only save the model it-self if we are using distributed training
model_to_save = self.module if hasattr(self, 'module') else self model_to_save = self.module if hasattr(self, 'module') else self
# Save configuration file
model_to_save.config.save_pretrained(save_directory)
# If we save using the predefined names, we can load using `from_pretrained` # If we save using the predefined names, we can load using `from_pretrained`
output_model_file = os.path.join(save_directory, WEIGHTS_NAME) output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
output_config_file = os.path.join(save_directory, CONFIG_NAME)
torch.save(model_to_save.state_dict(), output_model_file) torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment