Commit b85c59f9 authored by Julien Chaumond's avatar Julien Chaumond Committed by Lysandre Debut
Browse files

config.architectures

parent f9bc3f57
......@@ -82,6 +82,7 @@ class PretrainedConfig(object):
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
# Fine-tuning task arguments
self.architectures = kwargs.pop("architectures", None)
self.finetuning_task = kwargs.pop("finetuning_task", None)
self.num_labels = kwargs.pop("num_labels", 2)
self.id2label = kwargs.pop("id2label", {i: "LABEL_{}".format(i) for i in range(self.num_labels)})
......
......@@ -284,6 +284,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# Only save the model itself if we are using distributed training
model_to_save = self.module if hasattr(self, "module") else self
# Attach architecture to the config
model_to_save.config.architectures = [model_to_save.__class__.__name__]
# Save configuration file
model_to_save.config.save_pretrained(save_directory)
......
Markdown is supported
Attach a file by dragging & dropping, selecting, or pasting it.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment