"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "11cb6e0f7eb48bf973595eb42e827b89831704ab"
Commit 778a263f authored by LysandreJik's avatar LysandreJik
Browse files

GilBert added to AutoModels

parent 74d78bee
......@@ -30,6 +30,7 @@ from .modeling_transfo_xl import TransfoXLConfig, TransfoXLModel
from .modeling_xlnet import XLNetConfig, XLNetModel
from .modeling_xlm import XLMConfig, XLMModel
from .modeling_roberta import RobertaConfig, RobertaModel
from .modeling_dilbert import DilBertConfig, DilBertModel
from .modeling_utils import PreTrainedModel, SequenceSummary
......@@ -110,7 +111,9 @@ class AutoConfig(object):
assert unused_kwargs == {'foo': False}
"""
if 'roberta' in pretrained_model_name_or_path:
if 'dilbert' in pretrained_model_name_or_path:
return DilBertconfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'bert' in pretrained_model_name_or_path:
return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
......@@ -225,7 +228,9 @@ class AutoModel(object):
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
if 'roberta' in pretrained_model_name_or_path:
if 'dilbert' in pretrained_model_name_or_path:
return DilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'bert' in pretrained_model_name_or_path:
return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment