"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e566adc09c443af843e83b239c3a18b8e7bd422d"
Commit 41a13a63 authored by Stefan Schweter's avatar Stefan Schweter
Browse files

auto: add XLMRoBERTa to auto configuration

parent 01b68be3
...@@ -30,6 +30,7 @@ from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CO ...@@ -30,6 +30,7 @@ from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CO
from .configuration_albert import AlbertConfig, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP from .configuration_albert import AlbertConfig, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_camembert import CamembertConfig, CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP from .configuration_camembert import CamembertConfig, CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_t5 import T5Config, T5_PRETRAINED_CONFIG_ARCHIVE_MAP from .configuration_t5 import T5Config, T5_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_xlm_roberta import XLMRobertaConfig, XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -48,6 +49,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict((key, value) ...@@ -48,6 +49,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict((key, value)
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
] ]
for key, value, in pretrained_map.items()) for key, value, in pretrained_map.items())
...@@ -66,6 +68,7 @@ class AutoConfig(object): ...@@ -66,6 +68,7 @@ class AutoConfig(object):
- contains `distilbert`: DistilBertConfig (DistilBERT model) - contains `distilbert`: DistilBertConfig (DistilBERT model)
- contains `albert`: AlbertConfig (ALBERT model) - contains `albert`: AlbertConfig (ALBERT model)
- contains `camembert`: CamembertConfig (CamemBERT model) - contains `camembert`: CamembertConfig (CamemBERT model)
- contains `xlm-roberta`: XLMRobertaConfig (XLM-RoBERTa model)
- contains `roberta`: RobertaConfig (RoBERTa model) - contains `roberta`: RobertaConfig (RoBERTa model)
- contains `bert`: BertConfig (Bert model) - contains `bert`: BertConfig (Bert model)
- contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model) - contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
...@@ -91,6 +94,7 @@ class AutoConfig(object): ...@@ -91,6 +94,7 @@ class AutoConfig(object):
- contains `distilbert`: DistilBertConfig (DistilBERT model) - contains `distilbert`: DistilBertConfig (DistilBERT model)
- contains `albert`: AlbertConfig (ALBERT model) - contains `albert`: AlbertConfig (ALBERT model)
- contains `camembert`: CamembertConfig (CamemBERT model) - contains `camembert`: CamembertConfig (CamemBERT model)
- contains `xlm-roberta`: XLMRobertaConfig (XLM-RoBERTa model)
- contains `roberta`: RobertaConfig (RoBERTa model) - contains `roberta`: RobertaConfig (RoBERTa model)
- contains `bert`: BertConfig (Bert model) - contains `bert`: BertConfig (Bert model)
- contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model) - contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
...@@ -152,6 +156,8 @@ class AutoConfig(object): ...@@ -152,6 +156,8 @@ class AutoConfig(object):
return AlbertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) return AlbertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'camembert' in pretrained_model_name_or_path: elif 'camembert' in pretrained_model_name_or_path:
return CamembertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) return CamembertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'xlm-roberta' in pretrained_model_name_or_path:
return XLMRobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'roberta' in pretrained_model_name_or_path: elif 'roberta' in pretrained_model_name_or_path:
return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'bert' in pretrained_model_name_or_path: elif 'bert' in pretrained_model_name_or_path:
...@@ -170,4 +176,4 @@ class AutoConfig(object): ...@@ -170,4 +176,4 @@ class AutoConfig(object):
return CTRLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) return CTRLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contains one of " raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta', 'distilbert', 'camembert', 'ctrl', 'albert'".format(pretrained_model_name_or_path)) "'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'ctrl', 'albert'".format(pretrained_model_name_or_path))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment