Unverified Commit 8438bab3 authored by Pierric Cistac's avatar Pierric Cistac Committed by GitHub
Browse files

Fix roberta model ordering for TFAutoModel (#5414)

parent 6b735a72
......@@ -141,7 +141,6 @@ logger = logging.getLogger(__name__)
TF_MODEL_MAPPING = OrderedDict(
[
(AlbertConfig, TFAlbertModel),
(BertConfig, TFBertModel),
(CamembertConfig, TFCamembertModel),
(CTRLConfig, TFCTRLModel),
(DistilBertConfig, TFDistilBertModel),
......@@ -151,6 +150,7 @@ TF_MODEL_MAPPING = OrderedDict(
(MobileBertConfig, TFMobileBertModel),
(OpenAIGPTConfig, TFOpenAIGPTModel),
(RobertaConfig, TFRobertaModel),
(BertConfig, TFBertModel),
(T5Config, TFT5Model),
(TransfoXLConfig, TFTransfoXLModel),
(XLMConfig, TFXLMModel),
......@@ -162,7 +162,6 @@ TF_MODEL_MAPPING = OrderedDict(
TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
[
(AlbertConfig, TFAlbertForPreTraining),
(BertConfig, TFBertForPreTraining),
(CamembertConfig, TFCamembertForMaskedLM),
(CTRLConfig, TFCTRLLMHeadModel),
(DistilBertConfig, TFDistilBertForMaskedLM),
......@@ -172,6 +171,7 @@ TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
(MobileBertConfig, TFMobileBertForPreTraining),
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
(RobertaConfig, TFRobertaForMaskedLM),
(BertConfig, TFBertForPreTraining),
(T5Config, TFT5ForConditionalGeneration),
(TransfoXLConfig, TFTransfoXLLMHeadModel),
(XLMConfig, TFXLMWithLMHeadModel),
......@@ -183,7 +183,6 @@ TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
[
(AlbertConfig, TFAlbertForMaskedLM),
(BertConfig, TFBertForMaskedLM),
(CamembertConfig, TFCamembertForMaskedLM),
(CTRLConfig, TFCTRLLMHeadModel),
(DistilBertConfig, TFDistilBertForMaskedLM),
......@@ -193,6 +192,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
(MobileBertConfig, TFMobileBertForMaskedLM),
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
(RobertaConfig, TFRobertaForMaskedLM),
(BertConfig, TFBertForMaskedLM),
(T5Config, TFT5ForConditionalGeneration),
(TransfoXLConfig, TFTransfoXLLMHeadModel),
(XLMConfig, TFXLMWithLMHeadModel),
......@@ -204,12 +204,12 @@ TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
[
(AlbertConfig, TFAlbertForMultipleChoice),
(BertConfig, TFBertForMultipleChoice),
(CamembertConfig, TFCamembertForMultipleChoice),
(DistilBertConfig, TFDistilBertForMultipleChoice),
(FlaubertConfig, TFFlaubertForMultipleChoice),
(MobileBertConfig, TFMobileBertForMultipleChoice),
(RobertaConfig, TFRobertaForMultipleChoice),
(BertConfig, TFBertForMultipleChoice),
(XLMConfig, TFXLMForMultipleChoice),
(XLMRobertaConfig, TFXLMRobertaForMultipleChoice),
(XLNetConfig, TFXLNetForMultipleChoice),
......@@ -219,13 +219,13 @@ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
[
(AlbertConfig, TFAlbertForQuestionAnswering),
(BertConfig, TFBertForQuestionAnswering),
(CamembertConfig, TFCamembertForQuestionAnswering),
(DistilBertConfig, TFDistilBertForQuestionAnswering),
(ElectraConfig, TFElectraForQuestionAnswering),
(FlaubertConfig, TFFlaubertForQuestionAnsweringSimple),
(MobileBertConfig, TFMobileBertForQuestionAnswering),
(RobertaConfig, TFRobertaForQuestionAnswering),
(BertConfig, TFBertForQuestionAnswering),
(XLMConfig, TFXLMForQuestionAnsweringSimple),
(XLMRobertaConfig, TFXLMRobertaForQuestionAnswering),
(XLNetConfig, TFXLNetForQuestionAnsweringSimple),
......@@ -235,12 +235,12 @@ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
[
(AlbertConfig, TFAlbertForSequenceClassification),
(BertConfig, TFBertForSequenceClassification),
(CamembertConfig, TFCamembertForSequenceClassification),
(DistilBertConfig, TFDistilBertForSequenceClassification),
(FlaubertConfig, TFFlaubertForSequenceClassification),
(MobileBertConfig, TFMobileBertForSequenceClassification),
(RobertaConfig, TFRobertaForSequenceClassification),
(BertConfig, TFBertForSequenceClassification),
(XLMConfig, TFXLMForSequenceClassification),
(XLMRobertaConfig, TFXLMRobertaForSequenceClassification),
(XLNetConfig, TFXLNetForSequenceClassification),
......@@ -250,13 +250,13 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
[
(AlbertConfig, TFAlbertForTokenClassification),
(BertConfig, TFBertForTokenClassification),
(CamembertConfig, TFCamembertForTokenClassification),
(DistilBertConfig, TFDistilBertForTokenClassification),
(ElectraConfig, TFElectraForTokenClassification),
(FlaubertConfig, TFFlaubertForTokenClassification),
(MobileBertConfig, TFMobileBertForTokenClassification),
(RobertaConfig, TFRobertaForTokenClassification),
(BertConfig, TFBertForTokenClassification),
(XLMConfig, TFXLMForTokenClassification),
(XLMRobertaConfig, TFXLMRobertaForTokenClassification),
(XLNetConfig, TFXLNetForTokenClassification),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment