Commit 6bb3edc3 authored by Julien Chaumond's avatar Julien Chaumond
Browse files

Serialize model_type if exists

parent c6f682c1
......@@ -35,3 +35,4 @@ ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
class RobertaConfig(BertConfig):
pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "roberta"
......@@ -276,6 +276,8 @@ class PretrainedConfig(object):
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
if hasattr(self.__class__, "model_type"):
output["model_type"] = self.__class__.model_type
return output
def to_json_string(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment