"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "d9ece8233d584cdc2eeae5165dd3329328fae328"
Commit 778a263f authored by LysandreJik's avatar LysandreJik
Browse files

GilBert added to AutoModels

parent 74d78bee
...@@ -30,6 +30,7 @@ from .modeling_transfo_xl import TransfoXLConfig, TransfoXLModel ...@@ -30,6 +30,7 @@ from .modeling_transfo_xl import TransfoXLConfig, TransfoXLModel
from .modeling_xlnet import XLNetConfig, XLNetModel from .modeling_xlnet import XLNetConfig, XLNetModel
from .modeling_xlm import XLMConfig, XLMModel from .modeling_xlm import XLMConfig, XLMModel
from .modeling_roberta import RobertaConfig, RobertaModel from .modeling_roberta import RobertaConfig, RobertaModel
from .modeling_dilbert import DilBertConfig, DilBertModel
from .modeling_utils import PreTrainedModel, SequenceSummary from .modeling_utils import PreTrainedModel, SequenceSummary
...@@ -110,7 +111,9 @@ class AutoConfig(object): ...@@ -110,7 +111,9 @@ class AutoConfig(object):
assert unused_kwargs == {'foo': False} assert unused_kwargs == {'foo': False}
""" """
if 'roberta' in pretrained_model_name_or_path: if 'dilbert' in pretrained_model_name_or_path:
return DilBertconfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'bert' in pretrained_model_name_or_path: elif 'bert' in pretrained_model_name_or_path:
return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
...@@ -225,7 +228,9 @@ class AutoModel(object): ...@@ -225,7 +228,9 @@ class AutoModel(object):
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
""" """
if 'roberta' in pretrained_model_name_or_path: if 'dilbert' in pretrained_model_name_or_path:
return DilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'bert' in pretrained_model_name_or_path: elif 'bert' in pretrained_model_name_or_path:
return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment