Unverified Commit 9870093f authored by Sylvain Gugger, committed by GitHub

[WIP] Disentangle auto modules from other modeling files (#13023)

* Initial work

* All auto models

* All tf auto models

* All flax auto models

* Tokenizers

* Add feature extractors

* Fix typos

* Fix other typo

* Use the right config

* Remove old mapping names and update logic in AutoTokenizer

* Update check_table

* Fix copies and check_repo script

* Fix last test

* Add back name

* clean up

* Update template

* Update template

* Forgot a )

* Use alternative to fixup

* Fix TF model template

* Address review comments

* Address review comments

* Style
parent 2e408236
@@ -59,7 +59,7 @@ jobs:
       - name: Run style changes
         run: |
           git fetch origin master:master
-          make fixup
+          make style && make quality
       - name: Failure short reports
         if: ${{ always() }}
......
@@ -30,7 +30,6 @@ deps_table_check_updated:
 
 # autogenerating code
 autogenerate_code: deps_table_update
-	python utils/class_mapping_update.py
 
 # Check that source code meets quality standards
......
@@ -213,6 +213,7 @@ _import_structure = {
     "models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
     "models.marian": ["MarianConfig"],
     "models.mbart": ["MBartConfig"],
+    "models.mbart50": [],
     "models.megatron_bert": ["MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronBertConfig"],
     "models.mmbt": ["MMBTConfig"],
     "models.mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig", "MobileBertTokenizer"],
@@ -315,7 +316,7 @@ if is_sentencepiece_available():
     _import_structure["models.m2m_100"].append("M2M100Tokenizer")
     _import_structure["models.marian"].append("MarianTokenizer")
     _import_structure["models.mbart"].append("MBartTokenizer")
-    _import_structure["models.mbart"].append("MBart50Tokenizer")
+    _import_structure["models.mbart50"].append("MBart50Tokenizer")
     _import_structure["models.mt5"].append("MT5Tokenizer")
     _import_structure["models.pegasus"].append("PegasusTokenizer")
     _import_structure["models.reformer"].append("ReformerTokenizer")
@@ -358,7 +359,7 @@ if is_tokenizers_available():
     _import_structure["models.longformer"].append("LongformerTokenizerFast")
     _import_structure["models.lxmert"].append("LxmertTokenizerFast")
     _import_structure["models.mbart"].append("MBartTokenizerFast")
-    _import_structure["models.mbart"].append("MBart50TokenizerFast")
+    _import_structure["models.mbart50"].append("MBart50TokenizerFast")
     _import_structure["models.mobilebert"].append("MobileBertTokenizerFast")
     _import_structure["models.mpnet"].append("MPNetTokenizerFast")
     _import_structure["models.mt5"].append("MT5TokenizerFast")
@@ -2021,7 +2022,8 @@ if TYPE_CHECKING:
         from .models.led import LEDTokenizerFast
         from .models.longformer import LongformerTokenizerFast
         from .models.lxmert import LxmertTokenizerFast
-        from .models.mbart import MBart50TokenizerFast, MBartTokenizerFast
+        from .models.mbart import MBartTokenizerFast
+        from .models.mbart50 import MBart50TokenizerFast
         from .models.mobilebert import MobileBertTokenizerFast
         from .models.mpnet import MPNetTokenizerFast
         from .models.mt5 import MT5TokenizerFast
......
@@ -41,9 +41,7 @@ from .file_utils import (
     is_tokenizers_available,
     is_torch_available,
 )
-from .training_args import ParallelMode
-from .utils import logging
-from .utils.modeling_auto_mapping import (
+from .models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
     MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
     MODEL_FOR_MASKED_LM_MAPPING_NAMES,
@@ -54,6 +52,8 @@ from .utils.modeling_auto_mapping import (
     MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
     MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
 )
+from .training_args import ParallelMode
+from .utils import logging
 
 
 TASK_MAPPING = {
......
@@ -37,6 +37,7 @@ from . import (
     cpm,
     ctrl,
     deberta,
+    deberta_v2,
     deit,
     detr,
     dialogpt,
@@ -50,6 +51,8 @@ from . import (
     gpt2,
     gpt_neo,
     herbert,
+    hubert,
+    ibert,
     layoutlm,
     led,
     longformer,
@@ -58,6 +61,7 @@ from . import (
     m2m_100,
     marian,
     mbart,
+    mbart50,
     megatron_bert,
     mmbt,
     mobilebert,
@@ -82,6 +86,7 @@ from . import (
     vit,
     wav2vec2,
     xlm,
+    xlm_prophetnet,
     xlm_roberta,
     xlnet,
 )
@@ -13,11 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Factory function to build auto-model classes."""
+import importlib
+from collections import OrderedDict
+
 from ...configuration_utils import PretrainedConfig
 from ...file_utils import copy_func
 from ...utils import logging
-from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
+from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
 
 
 logger = logging.get_logger(__name__)
@@ -415,7 +417,7 @@ def auto_class_update(cls, checkpoint_for_example="bert-base-cased", head_doc=""
     from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name)
     from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
     from_config.__doc__ = from_config_docstring
-    from_config = replace_list_option_in_docstrings(model_mapping, use_model_types=False)(from_config)
+    from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config)
     cls.from_config = classmethod(from_config)
 
     if name.startswith("TF"):
@@ -431,7 +433,7 @@ def auto_class_update(cls, checkpoint_for_example="bert-base-cased", head_doc=""
     shortcut = checkpoint_for_example.split("/")[-1].split("-")[0]
     from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut)
     from_pretrained.__doc__ = from_pretrained_docstring
-    from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained)
+    from_pretrained = replace_list_option_in_docstrings(model_mapping._model_mapping)(from_pretrained)
     cls.from_pretrained = classmethod(from_pretrained)
     return cls
@@ -445,3 +447,79 @@ def get_values(model_mapping):
             result.append(model)
 
     return result
+
+
+def getattribute_from_module(module, attr):
+    if attr is None:
+        return None
+    if isinstance(attr, tuple):
+        return tuple(getattribute_from_module(module, a) for a in attr)
+    if hasattr(module, attr):
+        return getattr(module, attr)
+    # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the
+    # object at the top level.
+    transformers_module = importlib.import_module("transformers")
+    if module is not transformers_module:
+        return getattribute_from_module(transformers_module, attr)
+    raise ValueError(f"Could not find {attr} in {transformers_module}!")
+
+
+class _LazyAutoMapping(OrderedDict):
+    """
+    A mapping from config classes to objects (models or tokenizers, for instance) that loads its keys and values
+    lazily, on first access.
+
+    Args:
+        - config_mapping: The map from model type to config class
+        - model_mapping: The map from model type to model (or tokenizer) class
+    """
+
+    def __init__(self, config_mapping, model_mapping):
+        self._config_mapping = config_mapping
+        self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
+        self._model_mapping = model_mapping
+        self._modules = {}
+
+    def __getitem__(self, key):
+        model_type = self._reverse_config_mapping[key.__name__]
+        if model_type not in self._model_mapping:
+            raise KeyError(key)
+        model_name = self._model_mapping[model_type]
+        return self._load_attr_from_module(model_type, model_name)
+
+    def _load_attr_from_module(self, model_type, attr):
+        module_name = model_type_to_module_name(model_type)
+        if module_name not in self._modules:
+            self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
+        return getattribute_from_module(self._modules[module_name], attr)
+
+    def keys(self):
+        return [
+            self._load_attr_from_module(key, name)
+            for key, name in self._config_mapping.items()
+            if key in self._model_mapping.keys()
+        ]
+
+    def values(self):
+        return [
+            self._load_attr_from_module(key, name)
+            for key, name in self._model_mapping.items()
+            if key in self._config_mapping.keys()
+        ]
+
+    def items(self):
+        return [
+            (
+                self._load_attr_from_module(key, self._config_mapping[key]),
+                self._load_attr_from_module(key, self._model_mapping[key]),
+            )
+            for key in self._model_mapping.keys()
+            if key in self._config_mapping.keys()
+        ]
+
+    def __iter__(self):
+        # Iterating a mapping yields its keys; self.keys() resolves the config classes lazily.
+        return iter(self.keys())
+
+    def __contains__(self, item):
+        if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
+            return False
+        model_type = self._reverse_config_mapping[item.__name__]
+        return model_type in self._model_mapping
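This is the heart of the disentangling: the auto mappings are now built from two string tables (model type to config class name, model type to object class name), so nothing gets imported until a mapping entry is actually hit. A minimal usage sketch, assuming an install of this revision with torch available:

import sys

from transformers import BertConfig
from transformers.models.auto.modeling_auto import MODEL_MAPPING

# MODEL_MAPPING is now a _LazyAutoMapping; the lookup below is what triggers
# the import of transformers.models.bert (assumes torch is installed).
model_cls = MODEL_MAPPING[BertConfig]
print(model_cls.__name__)  # BertModel
print("transformers.models.bert.modeling_bert" in sys.modules)  # True only after the lookup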
@@ -13,36 +13,43 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ AutoFeatureExtractor class. """
+import importlib
 import os
 from collections import OrderedDict
 
-from transformers import BeitFeatureExtractor, DeiTFeatureExtractor, Speech2TextFeatureExtractor, ViTFeatureExtractor
-
-from ... import BeitConfig, DeiTConfig, PretrainedConfig, Speech2TextConfig, ViTConfig, Wav2Vec2Config
-from ...feature_extraction_utils import FeatureExtractionMixin
-
 # Build the list of all feature extractors
+from ...configuration_utils import PretrainedConfig
+from ...feature_extraction_utils import FeatureExtractionMixin
 from ...file_utils import FEATURE_EXTRACTOR_NAME
-from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
-from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
+from .auto_factory import _LazyAutoMapping
+from .configuration_auto import (
+    CONFIG_MAPPING_NAMES,
+    AutoConfig,
+    config_class_to_model_type,
+    replace_list_option_in_docstrings,
+)
 
 
-FEATURE_EXTRACTOR_MAPPING = OrderedDict(
+FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
     [
-        (BeitConfig, BeitFeatureExtractor),
-        (DeiTConfig, DeiTFeatureExtractor),
-        (Speech2TextConfig, Speech2TextFeatureExtractor),
-        (ViTConfig, ViTFeatureExtractor),
-        (Wav2Vec2Config, Wav2Vec2FeatureExtractor),
+        ("beit", "BeitFeatureExtractor"),
+        ("deit", "DeiTFeatureExtractor"),
+        ("speech_to_text", "Speech2TextFeatureExtractor"),
+        ("vit", "ViTFeatureExtractor"),
+        ("wav2vec2", "Wav2Vec2FeatureExtractor"),
     ]
 )
 
+FEATURE_EXTRACTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FEATURE_EXTRACTOR_MAPPING_NAMES)
+
 
 def feature_extractor_class_from_name(class_name: str):
-    for c in FEATURE_EXTRACTOR_MAPPING.values():
-        if c is not None and c.__name__ == class_name:
-            return c
+    for module_name, extractors in FEATURE_EXTRACTOR_MAPPING_NAMES.items():
+        if class_name in extractors:
+            break
+
+    module = importlib.import_module(f".{module_name}", "transformers.models")
+    return getattr(module, class_name)
 
 
 class AutoFeatureExtractor:
@@ -60,7 +67,7 @@ class AutoFeatureExtractor:
     )
 
     @classmethod
-    @replace_list_option_in_docstrings(FEATURE_EXTRACTOR_MAPPING)
+    @replace_list_option_in_docstrings(FEATURE_EXTRACTOR_MAPPING_NAMES)
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
         r"""
         Instantiate one of the feature extractor classes of the library from a pretrained model vocabulary.
@@ -142,7 +149,8 @@ class AutoFeatureExtractor:
         kwargs["_from_auto"] = True
         config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
 
-        if type(config) in FEATURE_EXTRACTOR_MAPPING.keys():
+        model_type = config_class_to_model_type(type(config).__name__)
+        if model_type is not None:
             return FEATURE_EXTRACTOR_MAPPING[type(config)].from_dict(config_dict, **kwargs)
         elif "feature_extractor_type" in config_dict:
             feature_extractor_class = feature_extractor_class_from_name(config_dict["feature_extractor_type"])
......
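The same string-table pattern now backs the feature extractors: FEATURE_EXTRACTOR_MAPPING is a _LazyAutoMapping over CONFIG_MAPPING_NAMES and FEATURE_EXTRACTOR_MAPPING_NAMES, and feature_extractor_class_from_name resolves a class name through importlib instead of scanning already-imported classes. A short sketch, assuming this revision is installed with the vision extra:

from transformers.models.auto.feature_extraction_auto import feature_extractor_class_from_name

# Walks FEATURE_EXTRACTOR_MAPPING_NAMES to find the module name ("vit"),
# imports transformers.models.vit only now, then pulls the class off it.
cls = feature_extractor_class_from_name("ViTFeatureExtractor")
print(cls.__name__)  # ViTFeatureExtractor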
@@ -18,207 +18,163 @@
 from collections import OrderedDict
 
 from ...utils import logging
-from ..bart.modeling_flax_bart import (
-    FlaxBartForConditionalGeneration,
-    FlaxBartForQuestionAnswering,
-    FlaxBartForSequenceClassification,
-    FlaxBartModel,
-)
-from ..bert.modeling_flax_bert import (
-    FlaxBertForMaskedLM,
-    FlaxBertForMultipleChoice,
-    FlaxBertForNextSentencePrediction,
-    FlaxBertForPreTraining,
-    FlaxBertForQuestionAnswering,
-    FlaxBertForSequenceClassification,
-    FlaxBertForTokenClassification,
-    FlaxBertModel,
-)
-from ..big_bird.modeling_flax_big_bird import (
-    FlaxBigBirdForMaskedLM,
-    FlaxBigBirdForMultipleChoice,
-    FlaxBigBirdForPreTraining,
-    FlaxBigBirdForQuestionAnswering,
-    FlaxBigBirdForSequenceClassification,
-    FlaxBigBirdForTokenClassification,
-    FlaxBigBirdModel,
-)
-from ..clip.modeling_flax_clip import FlaxCLIPModel
-from ..electra.modeling_flax_electra import (
-    FlaxElectraForMaskedLM,
-    FlaxElectraForMultipleChoice,
-    FlaxElectraForPreTraining,
-    FlaxElectraForQuestionAnswering,
-    FlaxElectraForSequenceClassification,
-    FlaxElectraForTokenClassification,
-    FlaxElectraModel,
-)
-from ..gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
-from ..gpt_neo.modeling_flax_gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel
-from ..marian.modeling_flax_marian import FlaxMarianModel, FlaxMarianMTModel
-from ..mbart.modeling_flax_mbart import (
-    FlaxMBartForConditionalGeneration,
-    FlaxMBartForQuestionAnswering,
-    FlaxMBartForSequenceClassification,
-    FlaxMBartModel,
-)
-from ..mt5.modeling_flax_mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
-from ..roberta.modeling_flax_roberta import (
-    FlaxRobertaForMaskedLM,
-    FlaxRobertaForMultipleChoice,
-    FlaxRobertaForQuestionAnswering,
-    FlaxRobertaForSequenceClassification,
-    FlaxRobertaForTokenClassification,
-    FlaxRobertaModel,
-)
-from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
-from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
-from ..wav2vec2.modeling_flax_wav2vec2 import FlaxWav2Vec2ForPreTraining, FlaxWav2Vec2Model
-from .auto_factory import _BaseAutoModelClass, auto_class_update
-from .configuration_auto import (
-    BartConfig,
-    BertConfig,
-    BigBirdConfig,
-    CLIPConfig,
-    ElectraConfig,
-    GPT2Config,
-    GPTNeoConfig,
-    MarianConfig,
-    MBartConfig,
-    MT5Config,
-    RobertaConfig,
-    T5Config,
-    ViTConfig,
-    Wav2Vec2Config,
-)
+from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
+from .configuration_auto import CONFIG_MAPPING_NAMES
 
 
 logger = logging.get_logger(__name__)
 
 
-FLAX_MODEL_MAPPING = OrderedDict(
+FLAX_MODEL_MAPPING_NAMES = OrderedDict(
     [
         # Base model mapping
-        (RobertaConfig, FlaxRobertaModel),
-        (BertConfig, FlaxBertModel),
-        (BigBirdConfig, FlaxBigBirdModel),
-        (BartConfig, FlaxBartModel),
-        (GPT2Config, FlaxGPT2Model),
-        (GPTNeoConfig, FlaxGPTNeoModel),
-        (ElectraConfig, FlaxElectraModel),
-        (CLIPConfig, FlaxCLIPModel),
-        (ViTConfig, FlaxViTModel),
-        (MBartConfig, FlaxMBartModel),
-        (T5Config, FlaxT5Model),
-        (MT5Config, FlaxMT5Model),
-        (Wav2Vec2Config, FlaxWav2Vec2Model),
-        (MarianConfig, FlaxMarianModel),
+        ("roberta", "FlaxRobertaModel"),
+        ("bert", "FlaxBertModel"),
+        ("big_bird", "FlaxBigBirdModel"),
+        ("bart", "FlaxBartModel"),
+        ("gpt2", "FlaxGPT2Model"),
+        ("gpt_neo", "FlaxGPTNeoModel"),
+        ("electra", "FlaxElectraModel"),
+        ("clip", "FlaxCLIPModel"),
+        ("vit", "FlaxViTModel"),
+        ("mbart", "FlaxMBartModel"),
+        ("t5", "FlaxT5Model"),
+        ("mt5", "FlaxMT5Model"),
+        ("wav2vec2", "FlaxWav2Vec2Model"),
+        ("marian", "FlaxMarianModel"),
     ]
 )
 
-FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
+FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
     [
         # Model for pre-training mapping
-        (RobertaConfig, FlaxRobertaForMaskedLM),
-        (BertConfig, FlaxBertForPreTraining),
-        (BigBirdConfig, FlaxBigBirdForPreTraining),
-        (BartConfig, FlaxBartForConditionalGeneration),
-        (ElectraConfig, FlaxElectraForPreTraining),
-        (MBartConfig, FlaxMBartForConditionalGeneration),
-        (T5Config, FlaxT5ForConditionalGeneration),
-        (MT5Config, FlaxMT5ForConditionalGeneration),
-        (Wav2Vec2Config, FlaxWav2Vec2ForPreTraining),
+        ("roberta", "FlaxRobertaForMaskedLM"),
+        ("bert", "FlaxBertForPreTraining"),
+        ("big_bird", "FlaxBigBirdForPreTraining"),
+        ("bart", "FlaxBartForConditionalGeneration"),
+        ("electra", "FlaxElectraForPreTraining"),
+        ("mbart", "FlaxMBartForConditionalGeneration"),
+        ("t5", "FlaxT5ForConditionalGeneration"),
+        ("mt5", "FlaxMT5ForConditionalGeneration"),
+        ("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
     ]
 )
 
-FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
+FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
     [
         # Model for Masked LM mapping
-        (RobertaConfig, FlaxRobertaForMaskedLM),
-        (BertConfig, FlaxBertForMaskedLM),
-        (BigBirdConfig, FlaxBigBirdForMaskedLM),
-        (BartConfig, FlaxBartForConditionalGeneration),
-        (ElectraConfig, FlaxElectraForMaskedLM),
-        (MBartConfig, FlaxMBartForConditionalGeneration),
+        ("roberta", "FlaxRobertaForMaskedLM"),
+        ("bert", "FlaxBertForMaskedLM"),
+        ("big_bird", "FlaxBigBirdForMaskedLM"),
+        ("bart", "FlaxBartForConditionalGeneration"),
+        ("electra", "FlaxElectraForMaskedLM"),
+        ("mbart", "FlaxMBartForConditionalGeneration"),
     ]
 )
 
-FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
+FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
     [
         # Model for Seq2Seq Causal LM mapping
-        (BartConfig, FlaxBartForConditionalGeneration),
-        (T5Config, FlaxT5ForConditionalGeneration),
-        (MT5Config, FlaxMT5ForConditionalGeneration),
-        (MarianConfig, FlaxMarianMTModel),
+        ("bart", "FlaxBartForConditionalGeneration"),
+        ("t5", "FlaxT5ForConditionalGeneration"),
+        ("mt5", "FlaxMT5ForConditionalGeneration"),
+        ("marian", "FlaxMarianMTModel"),
     ]
 )
 
-FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict(
+FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
     [
         # Model for Image classification mapping
-        (ViTConfig, FlaxViTForImageClassification),
+        ("vit", "FlaxViTForImageClassification"),
     ]
 )
 
-FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
+FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
     [
         # Model for Causal LM mapping
-        (GPT2Config, FlaxGPT2LMHeadModel),
-        (GPTNeoConfig, FlaxGPTNeoForCausalLM),
+        ("gpt2", "FlaxGPT2LMHeadModel"),
+        ("gpt_neo", "FlaxGPTNeoForCausalLM"),
     ]
 )
 
-FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
+FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
     [
         # Model for Sequence Classification mapping
-        (RobertaConfig, FlaxRobertaForSequenceClassification),
-        (BertConfig, FlaxBertForSequenceClassification),
-        (BigBirdConfig, FlaxBigBirdForSequenceClassification),
-        (BartConfig, FlaxBartForSequenceClassification),
-        (ElectraConfig, FlaxElectraForSequenceClassification),
-        (MBartConfig, FlaxMBartForSequenceClassification),
+        ("roberta", "FlaxRobertaForSequenceClassification"),
+        ("bert", "FlaxBertForSequenceClassification"),
+        ("big_bird", "FlaxBigBirdForSequenceClassification"),
+        ("bart", "FlaxBartForSequenceClassification"),
+        ("electra", "FlaxElectraForSequenceClassification"),
+        ("mbart", "FlaxMBartForSequenceClassification"),
     ]
 )
 
-FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
+FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
     [
         # Model for Question Answering mapping
-        (RobertaConfig, FlaxRobertaForQuestionAnswering),
-        (BertConfig, FlaxBertForQuestionAnswering),
-        (BigBirdConfig, FlaxBigBirdForQuestionAnswering),
-        (BartConfig, FlaxBartForQuestionAnswering),
-        (ElectraConfig, FlaxElectraForQuestionAnswering),
-        (MBartConfig, FlaxMBartForQuestionAnswering),
+        ("roberta", "FlaxRobertaForQuestionAnswering"),
+        ("bert", "FlaxBertForQuestionAnswering"),
+        ("big_bird", "FlaxBigBirdForQuestionAnswering"),
+        ("bart", "FlaxBartForQuestionAnswering"),
+        ("electra", "FlaxElectraForQuestionAnswering"),
+        ("mbart", "FlaxMBartForQuestionAnswering"),
     ]
 )
 
-FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
+FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
     [
         # Model for Token Classification mapping
-        (RobertaConfig, FlaxRobertaForTokenClassification),
-        (BertConfig, FlaxBertForTokenClassification),
-        (BigBirdConfig, FlaxBigBirdForTokenClassification),
-        (ElectraConfig, FlaxElectraForTokenClassification),
+        ("roberta", "FlaxRobertaForTokenClassification"),
+        ("bert", "FlaxBertForTokenClassification"),
+        ("big_bird", "FlaxBigBirdForTokenClassification"),
+        ("electra", "FlaxElectraForTokenClassification"),
     ]
 )
 
-FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
+FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
     [
         # Model for Multiple Choice mapping
-        (RobertaConfig, FlaxRobertaForMultipleChoice),
-        (BertConfig, FlaxBertForMultipleChoice),
-        (BigBirdConfig, FlaxBigBirdForMultipleChoice),
-        (ElectraConfig, FlaxElectraForMultipleChoice),
+        ("roberta", "FlaxRobertaForMultipleChoice"),
+        ("bert", "FlaxBertForMultipleChoice"),
+        ("big_bird", "FlaxBigBirdForMultipleChoice"),
+        ("electra", "FlaxElectraForMultipleChoice"),
     ]
 )
 
-FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
+FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
     [
-        (BertConfig, FlaxBertForNextSentencePrediction),
+        ("bert", "FlaxBertForNextSentencePrediction"),
     ]
 )
 
+FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
+FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
+FLAX_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
+FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
+    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
+)
+FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
+)
+FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
+FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
+)
+FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
+    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
+)
+FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
+)
+FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
+    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
+)
+FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
+    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
+)
 
 
 class FlaxAutoModel(_BaseAutoModelClass):
     _model_mapping = FLAX_MODEL_MAPPING
......
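The effect of this rewrite is visible at import time: touching the Flax auto classes no longer drags in every modeling_flax_* file. A sketch, assuming jax and flax are installed alongside this revision:

import sys

from transformers import FlaxAutoModel  # noqa: F401

# The mappings only hold strings at this point, so no per-model flax module
# has been imported yet; each loads on its first lookup instead.
print(any(name.endswith("modeling_flax_bert") for name in sys.modules))  # False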
@@ -33,10 +33,8 @@ _import_structure = {
 
 if is_sentencepiece_available():
     _import_structure["tokenization_mbart"] = ["MBartTokenizer"]
-    _import_structure["tokenization_mbart50"] = ["MBart50Tokenizer"]
 
 if is_tokenizers_available():
-    _import_structure["tokenization_mbart50_fast"] = ["MBart50TokenizerFast"]
     _import_structure["tokenization_mbart_fast"] = ["MBartTokenizerFast"]
 
 if is_torch_available():
@@ -72,10 +70,8 @@ if TYPE_CHECKING:
 
     if is_sentencepiece_available():
         from .tokenization_mbart import MBartTokenizer
-        from .tokenization_mbart50 import MBart50Tokenizer
 
     if is_tokenizers_available():
-        from .tokenization_mbart50_fast import MBart50TokenizerFast
         from .tokenization_mbart_fast import MBartTokenizerFast
 
     if is_torch_available():
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_sentencepiece_available, is_tokenizers_available
_import_structure = {}
if is_sentencepiece_available():
_import_structure["tokenization_mbart50"] = ["MBart50Tokenizer"]
if is_tokenizers_available():
_import_structure["tokenization_mbart50_fast"] = ["MBart50TokenizerFast"]
if TYPE_CHECKING:
if is_sentencepiece_available():
from .tokenization_mbart50 import MBart50Tokenizer
if is_tokenizers_available():
from .tokenization_mbart50_fast import MBart50TokenizerFast
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
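The new mbart50 subpackage follows the same _LazyModule convention as the other model folders: its tokenizer modules are imported only when one of their classes is first touched. A usage sketch, assuming sentencepiece is installed:

import sys

from transformers.models import mbart50

# mbart50 is a _LazyModule; this attribute access is what actually imports
# tokenization_mbart50 (and requires sentencepiece).
tok_cls = mbart50.MBart50Tokenizer
print("transformers.models.mbart50.tokenization_mbart50" in sys.modules)  # True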
@@ -1781,25 +1781,22 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         if config_tokenizer_class is None:
             # Third attempt. If we have not yet found the original type of the tokenizer we are loading, we see
             # if we can infer it from the type of the configuration file
-            from .models.auto.configuration_auto import CONFIG_MAPPING  # tests_ignore
-            from .models.auto.tokenization_auto import TOKENIZER_MAPPING  # tests_ignore
+            from .models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES  # tests_ignore
 
             if hasattr(config, "model_type"):
-                config_class = CONFIG_MAPPING.get(config.model_type)
+                model_type = config.model_type
             else:
                 # Fallback: use pattern matching on the string.
-                config_class = None
-                for pattern, config_class_tmp in CONFIG_MAPPING.items():
+                model_type = None
+                for pattern in TOKENIZER_MAPPING_NAMES.keys():
                     if pattern in str(pretrained_model_name_or_path):
-                        config_class = config_class_tmp
+                        model_type = pattern
                         break
 
-            if config_class in TOKENIZER_MAPPING.keys():
-                config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING[config_class]
-                if config_tokenizer_class is not None:
-                    config_tokenizer_class = config_tokenizer_class.__name__
-                else:
-                    config_tokenizer_class = config_tokenizer_class_fast.__name__
+            if model_type is not None:
+                config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING_NAMES[model_type]
+                if config_tokenizer_class is None:
+                    config_tokenizer_class = config_tokenizer_class_fast
 
         if config_tokenizer_class is not None:
             if cls.__name__.replace("Fast", "") != config_tokenizer_class.replace("Fast", ""):
......
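The rewritten fallback works purely on model_type strings: it scans the keys of TOKENIZER_MAPPING_NAMES against the checkpoint name, then prefers the slow tokenizer name, using the fast one only when no slow class exists. A toy version with hypothetical table entries (the real keys and their ordering live in tokenization_auto.py):

from collections import OrderedDict

# Hypothetical subset; order matters, since "mbart" is a substring of "mbart50".
TOKENIZER_MAPPING_NAMES = OrderedDict(
    [
        ("mbart50", ("MBart50Tokenizer", "MBart50TokenizerFast")),
        ("mbart", ("MBartTokenizer", "MBartTokenizerFast")),
    ]
)


def infer_tokenizer_class(name_or_path: str):
    # First key that appears in the checkpoint name wins.
    model_type = next((k for k in TOKENIZER_MAPPING_NAMES if k in name_or_path), None)
    if model_type is None:
        return None
    slow_class, fast_class = TOKENIZER_MAPPING_NAMES[model_type]
    return slow_class if slow_class is not None else fast_class


print(infer_tokenizer_class("facebook/mbart-large-cc25"))  # MBartTokenizer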
@@ -74,6 +74,7 @@ from .file_utils import (
 )
 from .modelcard import TrainingSummary
 from .modeling_utils import PreTrainedModel, unwrap_model
+from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
 from .optimization import Adafactor, AdamW, get_scheduler
 from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer_callback import (
@@ -125,7 +126,6 @@ from .trainer_utils import (
 )
 from .training_args import ParallelMode, TrainingArguments
 from .utils import logging
-from .utils.modeling_auto_mapping import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
 
 _is_torch_generator_available = False
......
@@ -191,7 +191,7 @@ class LxmertTokenizerFast:
         requires_backends(cls, ["tokenizers"])
 
 
-class MBart50TokenizerFast:
+class MBartTokenizerFast:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["tokenizers"])
 
@@ -200,7 +200,7 @@ class MBart50TokenizerFast:
         requires_backends(cls, ["tokenizers"])
 
 
-class MBartTokenizerFast:
+class MBart50TokenizerFast:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tokenizers"])
......
# THIS FILE HAS BEEN AUTOGENERATED. To update:
# 1. modify: models/auto/modeling_auto.py
# 2. run: python utils/class_mapping_update.py
from collections import OrderedDict
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
("RemBertConfig", "RemBertForQuestionAnswering"),
("CanineConfig", "CanineForQuestionAnswering"),
("RoFormerConfig", "RoFormerForQuestionAnswering"),
("BigBirdPegasusConfig", "BigBirdPegasusForQuestionAnswering"),
("BigBirdConfig", "BigBirdForQuestionAnswering"),
("ConvBertConfig", "ConvBertForQuestionAnswering"),
("LEDConfig", "LEDForQuestionAnswering"),
("DistilBertConfig", "DistilBertForQuestionAnswering"),
("AlbertConfig", "AlbertForQuestionAnswering"),
("CamembertConfig", "CamembertForQuestionAnswering"),
("BartConfig", "BartForQuestionAnswering"),
("MBartConfig", "MBartForQuestionAnswering"),
("LongformerConfig", "LongformerForQuestionAnswering"),
("XLMRobertaConfig", "XLMRobertaForQuestionAnswering"),
("RobertaConfig", "RobertaForQuestionAnswering"),
("SqueezeBertConfig", "SqueezeBertForQuestionAnswering"),
("BertConfig", "BertForQuestionAnswering"),
("XLNetConfig", "XLNetForQuestionAnsweringSimple"),
("FlaubertConfig", "FlaubertForQuestionAnsweringSimple"),
("MegatronBertConfig", "MegatronBertForQuestionAnswering"),
("MobileBertConfig", "MobileBertForQuestionAnswering"),
("XLMConfig", "XLMForQuestionAnsweringSimple"),
("ElectraConfig", "ElectraForQuestionAnswering"),
("ReformerConfig", "ReformerForQuestionAnswering"),
("FunnelConfig", "FunnelForQuestionAnswering"),
("LxmertConfig", "LxmertForQuestionAnswering"),
("MPNetConfig", "MPNetForQuestionAnswering"),
("DebertaConfig", "DebertaForQuestionAnswering"),
("DebertaV2Config", "DebertaV2ForQuestionAnswering"),
("IBertConfig", "IBertForQuestionAnswering"),
]
)
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
("RemBertConfig", "RemBertForCausalLM"),
("RoFormerConfig", "RoFormerForCausalLM"),
("BigBirdPegasusConfig", "BigBirdPegasusForCausalLM"),
("GPTNeoConfig", "GPTNeoForCausalLM"),
("BigBirdConfig", "BigBirdForCausalLM"),
("CamembertConfig", "CamembertForCausalLM"),
("XLMRobertaConfig", "XLMRobertaForCausalLM"),
("RobertaConfig", "RobertaForCausalLM"),
("BertConfig", "BertLMHeadModel"),
("OpenAIGPTConfig", "OpenAIGPTLMHeadModel"),
("GPT2Config", "GPT2LMHeadModel"),
("TransfoXLConfig", "TransfoXLLMHeadModel"),
("XLNetConfig", "XLNetLMHeadModel"),
("XLMConfig", "XLMWithLMHeadModel"),
("CTRLConfig", "CTRLLMHeadModel"),
("ReformerConfig", "ReformerModelWithLMHead"),
("BertGenerationConfig", "BertGenerationDecoder"),
("XLMProphetNetConfig", "XLMProphetNetForCausalLM"),
("ProphetNetConfig", "ProphetNetForCausalLM"),
("BartConfig", "BartForCausalLM"),
("MBartConfig", "MBartForCausalLM"),
("PegasusConfig", "PegasusForCausalLM"),
("MarianConfig", "MarianForCausalLM"),
("BlenderbotConfig", "BlenderbotForCausalLM"),
("BlenderbotSmallConfig", "BlenderbotSmallForCausalLM"),
("MegatronBertConfig", "MegatronBertForCausalLM"),
]
)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("ViTConfig", "ViTForImageClassification"),
("DeiTConfig", "('DeiTForImageClassification', 'DeiTForImageClassificationWithTeacher')"),
("BeitConfig", "BeitForImageClassification"),
]
)
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
("RemBertConfig", "RemBertForMaskedLM"),
("RoFormerConfig", "RoFormerForMaskedLM"),
("BigBirdConfig", "BigBirdForMaskedLM"),
("Wav2Vec2Config", "Wav2Vec2ForMaskedLM"),
("ConvBertConfig", "ConvBertForMaskedLM"),
("LayoutLMConfig", "LayoutLMForMaskedLM"),
("DistilBertConfig", "DistilBertForMaskedLM"),
("AlbertConfig", "AlbertForMaskedLM"),
("BartConfig", "BartForConditionalGeneration"),
("MBartConfig", "MBartForConditionalGeneration"),
("CamembertConfig", "CamembertForMaskedLM"),
("XLMRobertaConfig", "XLMRobertaForMaskedLM"),
("LongformerConfig", "LongformerForMaskedLM"),
("RobertaConfig", "RobertaForMaskedLM"),
("SqueezeBertConfig", "SqueezeBertForMaskedLM"),
("BertConfig", "BertForMaskedLM"),
("MegatronBertConfig", "MegatronBertForMaskedLM"),
("MobileBertConfig", "MobileBertForMaskedLM"),
("FlaubertConfig", "FlaubertWithLMHeadModel"),
("XLMConfig", "XLMWithLMHeadModel"),
("ElectraConfig", "ElectraForMaskedLM"),
("ReformerConfig", "ReformerForMaskedLM"),
("FunnelConfig", "FunnelForMaskedLM"),
("MPNetConfig", "MPNetForMaskedLM"),
("TapasConfig", "TapasForMaskedLM"),
("DebertaConfig", "DebertaForMaskedLM"),
("DebertaV2Config", "DebertaV2ForMaskedLM"),
("IBertConfig", "IBertForMaskedLM"),
]
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
("RemBertConfig", "RemBertForMultipleChoice"),
("CanineConfig", "CanineForMultipleChoice"),
("RoFormerConfig", "RoFormerForMultipleChoice"),
("BigBirdConfig", "BigBirdForMultipleChoice"),
("ConvBertConfig", "ConvBertForMultipleChoice"),
("CamembertConfig", "CamembertForMultipleChoice"),
("ElectraConfig", "ElectraForMultipleChoice"),
("XLMRobertaConfig", "XLMRobertaForMultipleChoice"),
("LongformerConfig", "LongformerForMultipleChoice"),
("RobertaConfig", "RobertaForMultipleChoice"),
("SqueezeBertConfig", "SqueezeBertForMultipleChoice"),
("BertConfig", "BertForMultipleChoice"),
("DistilBertConfig", "DistilBertForMultipleChoice"),
("MegatronBertConfig", "MegatronBertForMultipleChoice"),
("MobileBertConfig", "MobileBertForMultipleChoice"),
("XLNetConfig", "XLNetForMultipleChoice"),
("AlbertConfig", "AlbertForMultipleChoice"),
("XLMConfig", "XLMForMultipleChoice"),
("FlaubertConfig", "FlaubertForMultipleChoice"),
("FunnelConfig", "FunnelForMultipleChoice"),
("MPNetConfig", "MPNetForMultipleChoice"),
("IBertConfig", "IBertForMultipleChoice"),
]
)
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
[
("BertConfig", "BertForNextSentencePrediction"),
("MegatronBertConfig", "MegatronBertForNextSentencePrediction"),
("MobileBertConfig", "MobileBertForNextSentencePrediction"),
]
)
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
[
("DetrConfig", "DetrForObjectDetection"),
]
)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
("BigBirdPegasusConfig", "BigBirdPegasusForConditionalGeneration"),
("M2M100Config", "M2M100ForConditionalGeneration"),
("LEDConfig", "LEDForConditionalGeneration"),
("BlenderbotSmallConfig", "BlenderbotSmallForConditionalGeneration"),
("MT5Config", "MT5ForConditionalGeneration"),
("T5Config", "T5ForConditionalGeneration"),
("PegasusConfig", "PegasusForConditionalGeneration"),
("MarianConfig", "MarianMTModel"),
("MBartConfig", "MBartForConditionalGeneration"),
("BlenderbotConfig", "BlenderbotForConditionalGeneration"),
("BartConfig", "BartForConditionalGeneration"),
("FSMTConfig", "FSMTForConditionalGeneration"),
("EncoderDecoderConfig", "EncoderDecoderModel"),
("XLMProphetNetConfig", "XLMProphetNetForConditionalGeneration"),
("ProphetNetConfig", "ProphetNetForConditionalGeneration"),
]
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("RemBertConfig", "RemBertForSequenceClassification"),
("CanineConfig", "CanineForSequenceClassification"),
("RoFormerConfig", "RoFormerForSequenceClassification"),
("BigBirdPegasusConfig", "BigBirdPegasusForSequenceClassification"),
("BigBirdConfig", "BigBirdForSequenceClassification"),
("ConvBertConfig", "ConvBertForSequenceClassification"),
("LEDConfig", "LEDForSequenceClassification"),
("DistilBertConfig", "DistilBertForSequenceClassification"),
("AlbertConfig", "AlbertForSequenceClassification"),
("CamembertConfig", "CamembertForSequenceClassification"),
("XLMRobertaConfig", "XLMRobertaForSequenceClassification"),
("MBartConfig", "MBartForSequenceClassification"),
("BartConfig", "BartForSequenceClassification"),
("LongformerConfig", "LongformerForSequenceClassification"),
("RobertaConfig", "RobertaForSequenceClassification"),
("SqueezeBertConfig", "SqueezeBertForSequenceClassification"),
("LayoutLMConfig", "LayoutLMForSequenceClassification"),
("BertConfig", "BertForSequenceClassification"),
("XLNetConfig", "XLNetForSequenceClassification"),
("MegatronBertConfig", "MegatronBertForSequenceClassification"),
("MobileBertConfig", "MobileBertForSequenceClassification"),
("FlaubertConfig", "FlaubertForSequenceClassification"),
("XLMConfig", "XLMForSequenceClassification"),
("ElectraConfig", "ElectraForSequenceClassification"),
("FunnelConfig", "FunnelForSequenceClassification"),
("DebertaConfig", "DebertaForSequenceClassification"),
("DebertaV2Config", "DebertaV2ForSequenceClassification"),
("GPT2Config", "GPT2ForSequenceClassification"),
("GPTNeoConfig", "GPTNeoForSequenceClassification"),
("OpenAIGPTConfig", "OpenAIGPTForSequenceClassification"),
("ReformerConfig", "ReformerForSequenceClassification"),
("CTRLConfig", "CTRLForSequenceClassification"),
("TransfoXLConfig", "TransfoXLForSequenceClassification"),
("MPNetConfig", "MPNetForSequenceClassification"),
("TapasConfig", "TapasForSequenceClassification"),
("IBertConfig", "IBertForSequenceClassification"),
]
)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
("TapasConfig", "TapasForQuestionAnswering"),
]
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("RemBertConfig", "RemBertForTokenClassification"),
("CanineConfig", "CanineForTokenClassification"),
("RoFormerConfig", "RoFormerForTokenClassification"),
("BigBirdConfig", "BigBirdForTokenClassification"),
("ConvBertConfig", "ConvBertForTokenClassification"),
("LayoutLMConfig", "LayoutLMForTokenClassification"),
("DistilBertConfig", "DistilBertForTokenClassification"),
("CamembertConfig", "CamembertForTokenClassification"),
("FlaubertConfig", "FlaubertForTokenClassification"),
("XLMConfig", "XLMForTokenClassification"),
("XLMRobertaConfig", "XLMRobertaForTokenClassification"),
("LongformerConfig", "LongformerForTokenClassification"),
("RobertaConfig", "RobertaForTokenClassification"),
("SqueezeBertConfig", "SqueezeBertForTokenClassification"),
("BertConfig", "BertForTokenClassification"),
("MegatronBertConfig", "MegatronBertForTokenClassification"),
("MobileBertConfig", "MobileBertForTokenClassification"),
("XLNetConfig", "XLNetForTokenClassification"),
("AlbertConfig", "AlbertForTokenClassification"),
("ElectraConfig", "ElectraForTokenClassification"),
("FunnelConfig", "FunnelForTokenClassification"),
("MPNetConfig", "MPNetForTokenClassification"),
("DebertaConfig", "DebertaForTokenClassification"),
("DebertaV2Config", "DebertaV2ForTokenClassification"),
("IBertConfig", "IBertForTokenClassification"),
]
)
MODEL_MAPPING_NAMES = OrderedDict(
[
("BeitConfig", "BeitModel"),
("RemBertConfig", "RemBertModel"),
("VisualBertConfig", "VisualBertModel"),
("CanineConfig", "CanineModel"),
("RoFormerConfig", "RoFormerModel"),
("CLIPConfig", "CLIPModel"),
("BigBirdPegasusConfig", "BigBirdPegasusModel"),
("DeiTConfig", "DeiTModel"),
("LukeConfig", "LukeModel"),
("DetrConfig", "DetrModel"),
("GPTNeoConfig", "GPTNeoModel"),
("BigBirdConfig", "BigBirdModel"),
("Speech2TextConfig", "Speech2TextModel"),
("ViTConfig", "ViTModel"),
("Wav2Vec2Config", "Wav2Vec2Model"),
("HubertConfig", "HubertModel"),
("M2M100Config", "M2M100Model"),
("ConvBertConfig", "ConvBertModel"),
("LEDConfig", "LEDModel"),
("BlenderbotSmallConfig", "BlenderbotSmallModel"),
("RetriBertConfig", "RetriBertModel"),
("MT5Config", "MT5Model"),
("T5Config", "T5Model"),
("PegasusConfig", "PegasusModel"),
("MarianConfig", "MarianModel"),
("MBartConfig", "MBartModel"),
("BlenderbotConfig", "BlenderbotModel"),
("DistilBertConfig", "DistilBertModel"),
("AlbertConfig", "AlbertModel"),
("CamembertConfig", "CamembertModel"),
("XLMRobertaConfig", "XLMRobertaModel"),
("BartConfig", "BartModel"),
("LongformerConfig", "LongformerModel"),
("RobertaConfig", "RobertaModel"),
("LayoutLMConfig", "LayoutLMModel"),
("SqueezeBertConfig", "SqueezeBertModel"),
("BertConfig", "BertModel"),
("OpenAIGPTConfig", "OpenAIGPTModel"),
("GPT2Config", "GPT2Model"),
("MegatronBertConfig", "MegatronBertModel"),
("MobileBertConfig", "MobileBertModel"),
("TransfoXLConfig", "TransfoXLModel"),
("XLNetConfig", "XLNetModel"),
("FlaubertConfig", "FlaubertModel"),
("FSMTConfig", "FSMTModel"),
("XLMConfig", "XLMModel"),
("CTRLConfig", "CTRLModel"),
("ElectraConfig", "ElectraModel"),
("ReformerConfig", "ReformerModel"),
("FunnelConfig", "('FunnelModel', 'FunnelBaseModel')"),
("LxmertConfig", "LxmertModel"),
("BertGenerationConfig", "BertGenerationEncoder"),
("DebertaConfig", "DebertaModel"),
("DebertaV2Config", "DebertaV2Model"),
("DPRConfig", "DPRQuestionEncoder"),
("XLMProphetNetConfig", "XLMProphetNetModel"),
("ProphetNetConfig", "ProphetNetModel"),
("MPNetConfig", "MPNetModel"),
("TapasConfig", "TapasModel"),
("IBertConfig", "IBertModel"),
]
)
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
[
("RemBertConfig", "RemBertForMaskedLM"),
("RoFormerConfig", "RoFormerForMaskedLM"),
("BigBirdPegasusConfig", "BigBirdPegasusForConditionalGeneration"),
("GPTNeoConfig", "GPTNeoForCausalLM"),
("BigBirdConfig", "BigBirdForMaskedLM"),
("Speech2TextConfig", "Speech2TextForConditionalGeneration"),
("Wav2Vec2Config", "Wav2Vec2ForMaskedLM"),
("M2M100Config", "M2M100ForConditionalGeneration"),
("ConvBertConfig", "ConvBertForMaskedLM"),
("LEDConfig", "LEDForConditionalGeneration"),
("BlenderbotSmallConfig", "BlenderbotSmallForConditionalGeneration"),
("LayoutLMConfig", "LayoutLMForMaskedLM"),
("T5Config", "T5ForConditionalGeneration"),
("DistilBertConfig", "DistilBertForMaskedLM"),
("AlbertConfig", "AlbertForMaskedLM"),
("CamembertConfig", "CamembertForMaskedLM"),
("XLMRobertaConfig", "XLMRobertaForMaskedLM"),
("MarianConfig", "MarianMTModel"),
("FSMTConfig", "FSMTForConditionalGeneration"),
("BartConfig", "BartForConditionalGeneration"),
("LongformerConfig", "LongformerForMaskedLM"),
("RobertaConfig", "RobertaForMaskedLM"),
("SqueezeBertConfig", "SqueezeBertForMaskedLM"),
("BertConfig", "BertForMaskedLM"),
("OpenAIGPTConfig", "OpenAIGPTLMHeadModel"),
("GPT2Config", "GPT2LMHeadModel"),
("MegatronBertConfig", "MegatronBertForCausalLM"),
("MobileBertConfig", "MobileBertForMaskedLM"),
("TransfoXLConfig", "TransfoXLLMHeadModel"),
("XLNetConfig", "XLNetLMHeadModel"),
("FlaubertConfig", "FlaubertWithLMHeadModel"),
("XLMConfig", "XLMWithLMHeadModel"),
("CTRLConfig", "CTRLLMHeadModel"),
("ElectraConfig", "ElectraForMaskedLM"),
("EncoderDecoderConfig", "EncoderDecoderModel"),
("ReformerConfig", "ReformerModelWithLMHead"),
("FunnelConfig", "FunnelForMaskedLM"),
("MPNetConfig", "MPNetForMaskedLM"),
("TapasConfig", "TapasForMaskedLM"),
("DebertaConfig", "DebertaForMaskedLM"),
("DebertaV2Config", "DebertaV2ForMaskedLM"),
("IBertConfig", "IBertForMaskedLM"),
]
)
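With this deletion, the name tables above stop being autogenerated and live directly in models/auto/modeling_auto.py (hence the trainer import change earlier in this commit). One caveat for downstream code, hedged as an assumption about this revision: the new tables are keyed by model type rather than config class name, e.g.:

from transformers.models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES

# Keys are model types after this commit ("bert"), not config class names ("BertConfig").
print(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES["bert"])  # BertForQuestionAnswering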