Unverified Commit 9870093f authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

[WIP] Disentangle auto modules from other modeling files (#13023)

* Initial work

* All auto models

* All tf auto models

* All flax auto models

* Tokenizers

* Add feature extractors

* Fix typos

* Fix other typo

* Use the right config

* Remove old mapping names and update logic in AutoTokenizer

* Update check_table

* Fix copies and check_repo script

* Fix last test

* Add back name

* clean up

* Update template

* Update template

* Forgot a )

* Use alternative to fixup

* Fix TF model template

* Address review comments

* Address review comments

* Style
parent 2e408236
......@@ -59,7 +59,7 @@ jobs:
- name: Run style changes
run: |
git fetch origin master:master
make fixup
make style && make quality
- name: Failure short reports
if: ${{ always() }}
......
......@@ -30,7 +30,6 @@ deps_table_check_updated:
# autogenerating code
autogenerate_code: deps_table_update
python utils/class_mapping_update.py
# Check that source code meets quality standards
......
......@@ -213,6 +213,7 @@ _import_structure = {
"models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
"models.marian": ["MarianConfig"],
"models.mbart": ["MBartConfig"],
"models.mbart50": [],
"models.megatron_bert": ["MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronBertConfig"],
"models.mmbt": ["MMBTConfig"],
"models.mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig", "MobileBertTokenizer"],
......@@ -315,7 +316,7 @@ if is_sentencepiece_available():
_import_structure["models.m2m_100"].append("M2M100Tokenizer")
_import_structure["models.marian"].append("MarianTokenizer")
_import_structure["models.mbart"].append("MBartTokenizer")
_import_structure["models.mbart"].append("MBart50Tokenizer")
_import_structure["models.mbart50"].append("MBart50Tokenizer")
_import_structure["models.mt5"].append("MT5Tokenizer")
_import_structure["models.pegasus"].append("PegasusTokenizer")
_import_structure["models.reformer"].append("ReformerTokenizer")
......@@ -358,7 +359,7 @@ if is_tokenizers_available():
_import_structure["models.longformer"].append("LongformerTokenizerFast")
_import_structure["models.lxmert"].append("LxmertTokenizerFast")
_import_structure["models.mbart"].append("MBartTokenizerFast")
_import_structure["models.mbart"].append("MBart50TokenizerFast")
_import_structure["models.mbart50"].append("MBart50TokenizerFast")
_import_structure["models.mobilebert"].append("MobileBertTokenizerFast")
_import_structure["models.mpnet"].append("MPNetTokenizerFast")
_import_structure["models.mt5"].append("MT5TokenizerFast")
......@@ -2021,7 +2022,8 @@ if TYPE_CHECKING:
from .models.led import LEDTokenizerFast
from .models.longformer import LongformerTokenizerFast
from .models.lxmert import LxmertTokenizerFast
from .models.mbart import MBart50TokenizerFast, MBartTokenizerFast
from .models.mbart import MBartTokenizerFast
from .models.mbart50 import MBart50TokenizerFast
from .models.mobilebert import MobileBertTokenizerFast
from .models.mpnet import MPNetTokenizerFast
from .models.mt5 import MT5TokenizerFast
......
......@@ -41,9 +41,7 @@ from .file_utils import (
is_tokenizers_available,
is_torch_available,
)
from .training_args import ParallelMode
from .utils import logging
from .utils.modeling_auto_mapping import (
from .models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
......@@ -54,6 +52,8 @@ from .utils.modeling_auto_mapping import (
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
)
from .training_args import ParallelMode
from .utils import logging
TASK_MAPPING = {
......
......@@ -37,6 +37,7 @@ from . import (
cpm,
ctrl,
deberta,
deberta_v2,
deit,
detr,
dialogpt,
......@@ -50,6 +51,8 @@ from . import (
gpt2,
gpt_neo,
herbert,
hubert,
ibert,
layoutlm,
led,
longformer,
......@@ -58,6 +61,7 @@ from . import (
m2m_100,
marian,
mbart,
mbart50,
megatron_bert,
mmbt,
mobilebert,
......@@ -82,6 +86,7 @@ from . import (
vit,
wav2vec2,
xlm,
xlm_prophetnet,
xlm_roberta,
xlnet,
)
......@@ -13,11 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory function to build auto-model classes."""
import importlib
from collections import OrderedDict
from ...configuration_utils import PretrainedConfig
from ...file_utils import copy_func
from ...utils import logging
from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
logger = logging.get_logger(__name__)
......@@ -415,7 +417,7 @@ def auto_class_update(cls, checkpoint_for_example="bert-base-cased", head_doc=""
from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name)
from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
from_config.__doc__ = from_config_docstring
from_config = replace_list_option_in_docstrings(model_mapping, use_model_types=False)(from_config)
from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config)
cls.from_config = classmethod(from_config)
if name.startswith("TF"):
......@@ -431,7 +433,7 @@ def auto_class_update(cls, checkpoint_for_example="bert-base-cased", head_doc=""
shortcut = checkpoint_for_example.split("/")[-1].split("-")[0]
from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut)
from_pretrained.__doc__ = from_pretrained_docstring
from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained)
from_pretrained = replace_list_option_in_docstrings(model_mapping._model_mapping)(from_pretrained)
cls.from_pretrained = classmethod(from_pretrained)
return cls
......@@ -445,3 +447,79 @@ def get_values(model_mapping):
result.append(model)
return result
def getattribute_from_module(module, attr):
    """
    Fetch `attr` from `module`, falling back to the top-level `transformers` module.

    Args:
        module: The module (or any object supporting attribute access) to look into first.
        attr: The attribute name to fetch. `None` is passed through unchanged, and a tuple of names is resolved
            element by element into a tuple of objects.

    Returns:
        The resolved attribute, `None` if `attr` is `None`, or a tuple of resolved attributes if `attr` is a tuple.

    Raises:
        ValueError: If `attr` can be found neither in `module` nor in the top-level `transformers` module.
    """
    if attr is None:
        return None
    if isinstance(attr, tuple):
        return tuple(getattribute_from_module(module, a) for a in attr)
    if hasattr(module, attr):
        return getattr(module, attr)
    # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the
    # object at the top level.
    transformers_module = importlib.import_module("transformers")
    if module is transformers_module or not hasattr(transformers_module, attr):
        # Stop here: looking the name up on the top-level module again would recurse without ever terminating.
        raise ValueError(f"Could not find {attr} neither in {module} nor in {transformers_module}!")
    return getattr(transformers_module, attr)
class _LazyAutoMapping(OrderedDict):
"""
" A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.
Args:
- config_mapping: The map model type to config class
- model_mapping: The map model type to model (or tokenizer) class
"""
def __init__(self, config_mapping, model_mapping):
self._config_mapping = config_mapping
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
self._model_mapping = model_mapping
self._modules = {}
def __getitem__(self, key):
model_type = self._reverse_config_mapping[key.__name__]
if model_type not in self._model_mapping:
raise KeyError(key)
model_name = self._model_mapping[model_type]
return self._load_attr_from_module(model_type, model_name)
def _load_attr_from_module(self, model_type, attr):
module_name = model_type_to_module_name(model_type)
if module_name not in self._modules:
self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
return getattribute_from_module(self._modules[module_name], attr)
def keys(self):
return [
self._load_attr_from_module(key, name)
for key, name in self._config_mapping.items()
if key in self._model_mapping.keys()
]
def values(self):
return [
self._load_attr_from_module(key, name)
for key, name in self._model_mapping.items()
if key in self._config_mapping.keys()
]
def items(self):
return [
(
self._load_attr_from_module(key, self._config_mapping[key]),
self._load_attr_from_module(key, self._model_mapping[key]),
)
for key in self._model_mapping.keys()
if key in self._config_mapping.keys()
]
def __iter__(self):
return iter(self._mapping.keys())
def __contains__(self, item):
if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
return False
model_type = self._reverse_config_mapping[item.__name__]
return model_type in self._model_mapping
......@@ -13,215 +13,140 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Config class. """
import importlib
import re
import warnings
from collections import OrderedDict
from typing import List, Union
from ...configuration_utils import PretrainedConfig
from ..albert.configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from ..bart.configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
from ..beit.configuration_beit import BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, BeitConfig
from ..bert.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from ..bert_generation.configuration_bert_generation import BertGenerationConfig
from ..big_bird.configuration_big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig
from ..bigbird_pegasus.configuration_bigbird_pegasus import (
BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP,
BigBirdPegasusConfig,
)
from ..blenderbot.configuration_blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig
from ..blenderbot_small.configuration_blenderbot_small import (
BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
BlenderbotSmallConfig,
)
from ..camembert.configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from ..canine.configuration_canine import CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP, CanineConfig
from ..clip.configuration_clip import CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP, CLIPConfig
from ..convbert.configuration_convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig
from ..ctrl.configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
from ..deberta.configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig
from ..deberta_v2.configuration_deberta_v2 import DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaV2Config
from ..deit.configuration_deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig
from ..detr.configuration_detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig
from ..distilbert.configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
from ..dpr.configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig
from ..electra.configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
from ..encoder_decoder.configuration_encoder_decoder import EncoderDecoderConfig
from ..flaubert.configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
from ..fsmt.configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig
from ..funnel.configuration_funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig
from ..gpt2.configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from ..gpt_neo.configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
from ..hubert.configuration_hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig
from ..ibert.configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
from ..layoutlm.configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
from ..led.configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig
from ..longformer.configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
from ..luke.configuration_luke import LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, LukeConfig
from ..lxmert.configuration_lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig
from ..m2m_100.configuration_m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config
from ..marian.configuration_marian import MarianConfig
from ..mbart.configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig
from ..megatron_bert.configuration_megatron_bert import MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronBertConfig
from ..mobilebert.configuration_mobilebert import MobileBertConfig
from ..mpnet.configuration_mpnet import MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP, MPNetConfig
from ..mt5.configuration_mt5 import MT5Config
from ..openai.configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from ..pegasus.configuration_pegasus import PegasusConfig
from ..prophetnet.configuration_prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig
from ..rag.configuration_rag import RagConfig
from ..reformer.configuration_reformer import ReformerConfig
from ..rembert.configuration_rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig
from ..retribert.configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from ..roberta.configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from ..roformer.configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig
from ..speech_to_text.configuration_speech_to_text import (
SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
Speech2TextConfig,
)
from ..squeezebert.configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig
from ..t5.configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from ..tapas.configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig
from ..transfo_xl.configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
from ..visual_bert.configuration_visual_bert import VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, VisualBertConfig
from ..vit.configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
from ..wav2vec2.configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config
from ..xlm.configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
from ..xlm_prophetnet.configuration_xlm_prophetnet import (
XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLMProphetNetConfig,
)
from ..xlm_roberta.configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
from ..xlnet.configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
(key, value)
for pretrained_map in [
# Add archive maps here
BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP,
ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP,
BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP,
DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP,
DETR_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP,
BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP,
MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
VIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP,
CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
LED_PRETRAINED_CONFIG_ARCHIVE_MAP,
BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
BART_PRETRAINED_CONFIG_ARCHIVE_MAP,
BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP,
MBART_PRETRAINED_CONFIG_ARCHIVE_MAP,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP,
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP,
LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
DPR_PRETRAINED_CONFIG_ARCHIVE_MAP,
DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP,
SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP,
IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CONFIG_MAPPING_NAMES = OrderedDict(
[
# Add configs here
("beit", "BeitConfig"),
("rembert", "RemBertConfig"),
("visual_bert", "VisualBertConfig"),
("canine", "CanineConfig"),
("roformer", "RoFormerConfig"),
("clip", "CLIPConfig"),
("bigbird_pegasus", "BigBirdPegasusConfig"),
("deit", "DeiTConfig"),
("luke", "LukeConfig"),
("detr", "DetrConfig"),
("gpt_neo", "GPTNeoConfig"),
("big_bird", "BigBirdConfig"),
("speech_to_text", "Speech2TextConfig"),
("vit", "ViTConfig"),
("wav2vec2", "Wav2Vec2Config"),
("m2m_100", "M2M100Config"),
("convbert", "ConvBertConfig"),
("led", "LEDConfig"),
("blenderbot-small", "BlenderbotSmallConfig"),
("retribert", "RetriBertConfig"),
("ibert", "IBertConfig"),
("mt5", "MT5Config"),
("t5", "T5Config"),
("mobilebert", "MobileBertConfig"),
("distilbert", "DistilBertConfig"),
("albert", "AlbertConfig"),
("bert-generation", "BertGenerationConfig"),
("camembert", "CamembertConfig"),
("xlm-roberta", "XLMRobertaConfig"),
("pegasus", "PegasusConfig"),
("marian", "MarianConfig"),
("mbart", "MBartConfig"),
("megatron-bert", "MegatronBertConfig"),
("mpnet", "MPNetConfig"),
("bart", "BartConfig"),
("blenderbot", "BlenderbotConfig"),
("reformer", "ReformerConfig"),
("longformer", "LongformerConfig"),
("roberta", "RobertaConfig"),
("deberta-v2", "DebertaV2Config"),
("deberta", "DebertaConfig"),
("flaubert", "FlaubertConfig"),
("fsmt", "FSMTConfig"),
("squeezebert", "SqueezeBertConfig"),
("hubert", "HubertConfig"),
("bert", "BertConfig"),
("openai-gpt", "OpenAIGPTConfig"),
("gpt2", "GPT2Config"),
("transfo-xl", "TransfoXLConfig"),
("xlnet", "XLNetConfig"),
("xlm-prophetnet", "XLMProphetNetConfig"),
("prophetnet", "ProphetNetConfig"),
("xlm", "XLMConfig"),
("ctrl", "CTRLConfig"),
("electra", "ElectraConfig"),
("encoder-decoder", "EncoderDecoderConfig"),
("funnel", "FunnelConfig"),
("lxmert", "LxmertConfig"),
("dpr", "DPRConfig"),
("layoutlm", "LayoutLMConfig"),
("rag", "RagConfig"),
("tapas", "TapasConfig"),
]
for key, value, in pretrained_map.items()
)
CONFIG_MAPPING = OrderedDict(
CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
[
# Add configs here
("beit", BeitConfig),
("rembert", RemBertConfig),
("visual_bert", VisualBertConfig),
("canine", CanineConfig),
("roformer", RoFormerConfig),
("clip", CLIPConfig),
("bigbird_pegasus", BigBirdPegasusConfig),
("deit", DeiTConfig),
("luke", LukeConfig),
("detr", DetrConfig),
("gpt_neo", GPTNeoConfig),
("big_bird", BigBirdConfig),
("speech_to_text", Speech2TextConfig),
("vit", ViTConfig),
("wav2vec2", Wav2Vec2Config),
("m2m_100", M2M100Config),
("convbert", ConvBertConfig),
("led", LEDConfig),
("blenderbot-small", BlenderbotSmallConfig),
("retribert", RetriBertConfig),
("ibert", IBertConfig),
("mt5", MT5Config),
("t5", T5Config),
("mobilebert", MobileBertConfig),
("distilbert", DistilBertConfig),
("albert", AlbertConfig),
("bert-generation", BertGenerationConfig),
("camembert", CamembertConfig),
("xlm-roberta", XLMRobertaConfig),
("pegasus", PegasusConfig),
("marian", MarianConfig),
("mbart", MBartConfig),
("megatron-bert", MegatronBertConfig),
("mpnet", MPNetConfig),
("bart", BartConfig),
("blenderbot", BlenderbotConfig),
("reformer", ReformerConfig),
("longformer", LongformerConfig),
("roberta", RobertaConfig),
("deberta-v2", DebertaV2Config),
("deberta", DebertaConfig),
("flaubert", FlaubertConfig),
("fsmt", FSMTConfig),
("squeezebert", SqueezeBertConfig),
("hubert", HubertConfig),
("bert", BertConfig),
("openai-gpt", OpenAIGPTConfig),
("gpt2", GPT2Config),
("transfo-xl", TransfoXLConfig),
("xlnet", XLNetConfig),
("xlm-prophetnet", XLMProphetNetConfig),
("prophetnet", ProphetNetConfig),
("xlm", XLMConfig),
("ctrl", CTRLConfig),
("electra", ElectraConfig),
("encoder-decoder", EncoderDecoderConfig),
("funnel", FunnelConfig),
("lxmert", LxmertConfig),
("dpr", DPRConfig),
("layoutlm", LayoutLMConfig),
("rag", RagConfig),
("tapas", TapasConfig),
# Add archive maps here
("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("canine", "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("clip", "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bigbird_pegasus", "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("detr", "DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gpt_neo", "GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("big_bird", "BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("megatron-bert", "MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("speech_to_text", "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("vit", "VIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("wav2vec2", "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("convbert", "CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("blenderbot-small", "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bert", "BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bart", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("blenderbot", "BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mbart", "MBART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("openai-gpt", "OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("ctrl", "CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlnet", "XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm", "XLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("roberta", "ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("distilbert", "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("albert", "ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("camembert", "CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("t5", "T5_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm-roberta", "XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("flaubert", "FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("fsmt", "FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("electra", "ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("retribert", "RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("funnel", "FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("dpr", "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("deberta", "DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("deberta-v2", "DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("squeezebert", "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm-prophetnet", "XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("prophetnet", "PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mpnet", "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("hubert", "HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
]
)
......@@ -290,14 +215,136 @@ MODEL_NAMES_MAPPING = OrderedDict(
("mpnet", "MPNet"),
("tapas", "TAPAS"),
("hubert", "Hubert"),
("barthez", "BARThez"),
("phobert", "PhoBERT"),
("cpm", "CPM"),
("bertweet", "Bertweet"),
("bert-japanese", "BertJapanese"),
("byt5", "ByT5"),
("mbart50", "mBART-50"),
]
)
SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict([("openai-gpt", "openai")])


def model_type_to_module_name(key):
    """Converts a config key to the corresponding module."""
    # Model types listed in the special table map to a module name that cannot be derived from the key.
    special_name = SPECIAL_MODEL_TYPE_TO_MODULE_NAME.get(key)
    if special_name is not None:
        return special_name

    # Everything else only differs by using underscores instead of dashes.
    return "_".join(key.split("-"))
def config_class_to_model_type(config):
    """Converts a config class name to the corresponding model type"""
    # First model type (in declaration order) whose config class name matches, `None` when there is no match.
    matching_types = (
        model_type for model_type, class_name in CONFIG_MAPPING_NAMES.items() if class_name == config
    )
    return next(matching_types, None)
class _LazyConfigMapping(OrderedDict):
"""
A dictionary that lazily load its values when they are requested.
"""
def __init__(self, mapping):
self._mapping = mapping
self._modules = {}
def __getitem__(self, key):
if key not in self._mapping:
raise KeyError(key)
value = self._mapping[key]
module_name = model_type_to_module_name(key)
if module_name not in self._modules:
self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
return getattr(self._modules[module_name], value)
def keys(self):
return self._mapping.keys()
def values(self):
return [self[k] for k in self._mapping.keys()]
def _get_class_name(model_class):
def items(self):
return [(k, self[k]) for k in self._mapping.keys()]
def __iter__(self):
return iter(self._mapping.keys())
def __contains__(self, item):
return item in self._mapping
CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
class _LazyLoadAllMappings(OrderedDict):
"""
A mapping that will load all pairs of key values at the first access (either by indexing, requestions keys, values,
etc.)
Args:
mapping: The mapping to load.
"""
def __init__(self, mapping):
self._mapping = mapping
self._initialized = False
self._data = {}
def _initialize(self):
if self._initialized:
return
warnings.warn(
"ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. "
"It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.",
FutureWarning,
)
for model_type, map_name in self._mapping.items():
module_name = model_type_to_module_name(model_type)
module = importlib.import_module(f".{module_name}", "transformers.models")
mapping = getattr(module, map_name)
self._data.update(mapping)
self._initialized = True
def __getitem__(self, key):
self._initialize()
return self._data[key]
def keys(self):
self._initialize()
return self._data.keys()
def values(self):
self._initialize()
return self._data.values()
def items(self):
self._initialize()
return self._data.keys()
def __iter__(self):
self._initialize()
return iter(self._data)
def __contains__(self, item):
self._initialize()
return item in self._data
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = _LazyLoadAllMappings(CONFIG_ARCHIVE_MAP_MAPPING_NAMES)
def _get_class_name(model_class: Union[str, List[str]]):
if isinstance(model_class, (list, tuple)):
return " or ".join([f":class:`~transformers.{c.__name__}`" for c in model_class])
return f":class:`~transformers.{model_class.__name__}`"
return " or ".join([f":class:`~transformers.{c}`" for c in model_class if c is not None])
return f":class:`~transformers.{model_class}`"
def _list_model_options(indent, config_to_class=None, use_model_types=True):
......@@ -306,23 +353,26 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True):
if use_model_types:
if config_to_class is None:
model_type_to_name = {
model_type: f":class:`~transformers.{config.__name__}`"
for model_type, config in CONFIG_MAPPING.items()
model_type: f":class:`~transformers.{config}`" for model_type, config in CONFIG_MAPPING_NAMES.items()
}
else:
model_type_to_name = {
model_type: _get_class_name(config_to_class[config])
for model_type, config in CONFIG_MAPPING.items()
if config in config_to_class
model_type: _get_class_name(model_class)
for model_type, model_class in config_to_class.items()
if model_type in MODEL_NAMES_MAPPING
}
lines = [
f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
for model_type in sorted(model_type_to_name.keys())
]
else:
config_to_name = {config.__name__: _get_class_name(clas) for config, clas in config_to_class.items()}
config_to_name = {
CONFIG_MAPPING_NAMES[config]: _get_class_name(clas)
for config, clas in config_to_class.items()
if config in CONFIG_MAPPING_NAMES
}
config_to_model_name = {
config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items()
config: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING_NAMES.items()
}
lines = [
f"{indent}- :class:`~transformers.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
......
......@@ -13,36 +13,43 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" AutoFeatureExtractor class. """
import importlib
import os
from collections import OrderedDict
from transformers import BeitFeatureExtractor, DeiTFeatureExtractor, Speech2TextFeatureExtractor, ViTFeatureExtractor
from ... import BeitConfig, DeiTConfig, PretrainedConfig, Speech2TextConfig, ViTConfig, Wav2Vec2Config
from ...feature_extraction_utils import FeatureExtractionMixin
# Build the list of all feature extractors
from ...configuration_utils import PretrainedConfig
from ...feature_extraction_utils import FeatureExtractionMixin
from ...file_utils import FEATURE_EXTRACTOR_NAME
from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
CONFIG_MAPPING_NAMES,
AutoConfig,
config_class_to_model_type,
replace_list_option_in_docstrings,
)
FEATURE_EXTRACTOR_MAPPING = OrderedDict(
FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
[
(BeitConfig, BeitFeatureExtractor),
(DeiTConfig, DeiTFeatureExtractor),
(Speech2TextConfig, Speech2TextFeatureExtractor),
(ViTConfig, ViTFeatureExtractor),
(Wav2Vec2Config, Wav2Vec2FeatureExtractor),
("beit", "BeitFeatureExtractor"),
("deit", "DeiTFeatureExtractor"),
("speech_to_text", "Speech2TextFeatureExtractor"),
("vit", "ViTFeatureExtractor"),
("wav2vec2", "Wav2Vec2FeatureExtractor"),
]
)
FEATURE_EXTRACTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FEATURE_EXTRACTOR_MAPPING_NAMES)
def feature_extractor_class_from_name(class_name: str):
    """
    Find a feature extractor class from its name.

    Args:
        class_name: Name of the feature extractor class to look for in `FEATURE_EXTRACTOR_MAPPING_NAMES`.

    Returns:
        The attribute named `class_name` of the matching model module, or `None` when no entry of the mapping
        declares that name.
    """
    for module_name, extractors in FEATURE_EXTRACTOR_MAPPING_NAMES.items():
        # Entries are class names: compare whole names so that a substring of a name never matches.
        candidates = (extractors,) if isinstance(extractors, str) else extractors
        if class_name in candidates:
            module = importlib.import_module(f".{module_name}", "transformers.models")
            return getattr(module, class_name)
    # No entry declares this class: do not fall through and import an unrelated module.
    return None
class AutoFeatureExtractor:
......@@ -60,7 +67,7 @@ class AutoFeatureExtractor:
)
@classmethod
@replace_list_option_in_docstrings(FEATURE_EXTRACTOR_MAPPING)
@replace_list_option_in_docstrings(FEATURE_EXTRACTOR_MAPPING_NAMES)
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
r"""
Instantiate one of the feature extractor classes of the library from a pretrained model vocabulary.
......@@ -142,7 +149,8 @@ class AutoFeatureExtractor:
kwargs["_from_auto"] = True
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
if type(config) in FEATURE_EXTRACTOR_MAPPING.keys():
model_type = config_class_to_model_type(type(config).__name__)
if model_type is not None:
return FEATURE_EXTRACTOR_MAPPING[type(config)].from_dict(config_dict, **kwargs)
elif "feature_extractor_type" in config_dict:
feature_extractor_class = feature_extractor_class_from_name(config_dict["feature_extractor_type"])
......
......@@ -14,793 +14,454 @@
# limitations under the License.
""" Auto Model class. """
import warnings
from collections import OrderedDict
from ...utils import logging
# Add modeling imports here
from ..albert.modeling_albert import (
AlbertForMaskedLM,
AlbertForMultipleChoice,
AlbertForPreTraining,
AlbertForQuestionAnswering,
AlbertForSequenceClassification,
AlbertForTokenClassification,
AlbertModel,
)
from ..bart.modeling_bart import (
BartForCausalLM,
BartForConditionalGeneration,
BartForQuestionAnswering,
BartForSequenceClassification,
BartModel,
)
from ..beit.modeling_beit import BeitForImageClassification, BeitModel
from ..bert.modeling_bert import (
BertForMaskedLM,
BertForMultipleChoice,
BertForNextSentencePrediction,
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
BertForTokenClassification,
BertLMHeadModel,
BertModel,
)
from ..bert_generation.modeling_bert_generation import BertGenerationDecoder, BertGenerationEncoder
from ..big_bird.modeling_big_bird import (
BigBirdForCausalLM,
BigBirdForMaskedLM,
BigBirdForMultipleChoice,
BigBirdForPreTraining,
BigBirdForQuestionAnswering,
BigBirdForSequenceClassification,
BigBirdForTokenClassification,
BigBirdModel,
)
from ..bigbird_pegasus.modeling_bigbird_pegasus import (
BigBirdPegasusForCausalLM,
BigBirdPegasusForConditionalGeneration,
BigBirdPegasusForQuestionAnswering,
BigBirdPegasusForSequenceClassification,
BigBirdPegasusModel,
)
from ..blenderbot.modeling_blenderbot import BlenderbotForCausalLM, BlenderbotForConditionalGeneration, BlenderbotModel
from ..blenderbot_small.modeling_blenderbot_small import (
BlenderbotSmallForCausalLM,
BlenderbotSmallForConditionalGeneration,
BlenderbotSmallModel,
)
from ..camembert.modeling_camembert import (
CamembertForCausalLM,
CamembertForMaskedLM,
CamembertForMultipleChoice,
CamembertForQuestionAnswering,
CamembertForSequenceClassification,
CamembertForTokenClassification,
CamembertModel,
)
from ..canine.modeling_canine import (
CanineForMultipleChoice,
CanineForQuestionAnswering,
CanineForSequenceClassification,
CanineForTokenClassification,
CanineModel,
)
from ..clip.modeling_clip import CLIPModel
from ..convbert.modeling_convbert import (
ConvBertForMaskedLM,
ConvBertForMultipleChoice,
ConvBertForQuestionAnswering,
ConvBertForSequenceClassification,
ConvBertForTokenClassification,
ConvBertModel,
)
from ..ctrl.modeling_ctrl import CTRLForSequenceClassification, CTRLLMHeadModel, CTRLModel
from ..deberta.modeling_deberta import (
DebertaForMaskedLM,
DebertaForQuestionAnswering,
DebertaForSequenceClassification,
DebertaForTokenClassification,
DebertaModel,
)
from ..deberta_v2.modeling_deberta_v2 import (
DebertaV2ForMaskedLM,
DebertaV2ForQuestionAnswering,
DebertaV2ForSequenceClassification,
DebertaV2ForTokenClassification,
DebertaV2Model,
)
from ..deit.modeling_deit import DeiTForImageClassification, DeiTForImageClassificationWithTeacher, DeiTModel
from ..detr.modeling_detr import DetrForObjectDetection, DetrModel
from ..distilbert.modeling_distilbert import (
DistilBertForMaskedLM,
DistilBertForMultipleChoice,
DistilBertForQuestionAnswering,
DistilBertForSequenceClassification,
DistilBertForTokenClassification,
DistilBertModel,
)
from ..dpr.modeling_dpr import DPRQuestionEncoder
from ..electra.modeling_electra import (
ElectraForMaskedLM,
ElectraForMultipleChoice,
ElectraForPreTraining,
ElectraForQuestionAnswering,
ElectraForSequenceClassification,
ElectraForTokenClassification,
ElectraModel,
)
from ..encoder_decoder.modeling_encoder_decoder import EncoderDecoderModel
from ..flaubert.modeling_flaubert import (
FlaubertForMultipleChoice,
FlaubertForQuestionAnsweringSimple,
FlaubertForSequenceClassification,
FlaubertForTokenClassification,
FlaubertModel,
FlaubertWithLMHeadModel,
)
from ..fsmt.modeling_fsmt import FSMTForConditionalGeneration, FSMTModel
from ..funnel.modeling_funnel import (
FunnelBaseModel,
FunnelForMaskedLM,
FunnelForMultipleChoice,
FunnelForPreTraining,
FunnelForQuestionAnswering,
FunnelForSequenceClassification,
FunnelForTokenClassification,
FunnelModel,
)
from ..gpt2.modeling_gpt2 import GPT2ForSequenceClassification, GPT2LMHeadModel, GPT2Model
from ..gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM, GPTNeoForSequenceClassification, GPTNeoModel
from ..hubert.modeling_hubert import HubertModel
from ..ibert.modeling_ibert import (
IBertForMaskedLM,
IBertForMultipleChoice,
IBertForQuestionAnswering,
IBertForSequenceClassification,
IBertForTokenClassification,
IBertModel,
)
from ..layoutlm.modeling_layoutlm import (
LayoutLMForMaskedLM,
LayoutLMForSequenceClassification,
LayoutLMForTokenClassification,
LayoutLMModel,
)
from ..led.modeling_led import (
LEDForConditionalGeneration,
LEDForQuestionAnswering,
LEDForSequenceClassification,
LEDModel,
)
from ..longformer.modeling_longformer import (
LongformerForMaskedLM,
LongformerForMultipleChoice,
LongformerForQuestionAnswering,
LongformerForSequenceClassification,
LongformerForTokenClassification,
LongformerModel,
)
from ..luke.modeling_luke import LukeModel
from ..lxmert.modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel
from ..m2m_100.modeling_m2m_100 import M2M100ForConditionalGeneration, M2M100Model
from ..marian.modeling_marian import MarianForCausalLM, MarianModel, MarianMTModel
from ..mbart.modeling_mbart import (
MBartForCausalLM,
MBartForConditionalGeneration,
MBartForQuestionAnswering,
MBartForSequenceClassification,
MBartModel,
)
from ..megatron_bert.modeling_megatron_bert import (
MegatronBertForCausalLM,
MegatronBertForMaskedLM,
MegatronBertForMultipleChoice,
MegatronBertForNextSentencePrediction,
MegatronBertForPreTraining,
MegatronBertForQuestionAnswering,
MegatronBertForSequenceClassification,
MegatronBertForTokenClassification,
MegatronBertModel,
)
from ..mobilebert.modeling_mobilebert import (
MobileBertForMaskedLM,
MobileBertForMultipleChoice,
MobileBertForNextSentencePrediction,
MobileBertForPreTraining,
MobileBertForQuestionAnswering,
MobileBertForSequenceClassification,
MobileBertForTokenClassification,
MobileBertModel,
)
from ..mpnet.modeling_mpnet import (
MPNetForMaskedLM,
MPNetForMultipleChoice,
MPNetForQuestionAnswering,
MPNetForSequenceClassification,
MPNetForTokenClassification,
MPNetModel,
)
from ..mt5.modeling_mt5 import MT5ForConditionalGeneration, MT5Model
from ..openai.modeling_openai import OpenAIGPTForSequenceClassification, OpenAIGPTLMHeadModel, OpenAIGPTModel
from ..pegasus.modeling_pegasus import PegasusForCausalLM, PegasusForConditionalGeneration, PegasusModel
from ..prophetnet.modeling_prophetnet import ProphetNetForCausalLM, ProphetNetForConditionalGeneration, ProphetNetModel
from ..rag.modeling_rag import ( # noqa: F401 - need to import all RagModels to be in globals() function
RagModel,
RagSequenceForGeneration,
RagTokenForGeneration,
)
from ..reformer.modeling_reformer import (
ReformerForMaskedLM,
ReformerForQuestionAnswering,
ReformerForSequenceClassification,
ReformerModel,
ReformerModelWithLMHead,
)
from ..rembert.modeling_rembert import (
RemBertForCausalLM,
RemBertForMaskedLM,
RemBertForMultipleChoice,
RemBertForQuestionAnswering,
RemBertForSequenceClassification,
RemBertForTokenClassification,
RemBertModel,
)
from ..retribert.modeling_retribert import RetriBertModel
from ..roberta.modeling_roberta import (
RobertaForCausalLM,
RobertaForMaskedLM,
RobertaForMultipleChoice,
RobertaForQuestionAnswering,
RobertaForSequenceClassification,
RobertaForTokenClassification,
RobertaModel,
)
from ..roformer.modeling_roformer import (
RoFormerForCausalLM,
RoFormerForMaskedLM,
RoFormerForMultipleChoice,
RoFormerForQuestionAnswering,
RoFormerForSequenceClassification,
RoFormerForTokenClassification,
RoFormerModel,
)
from ..speech_to_text.modeling_speech_to_text import Speech2TextForConditionalGeneration, Speech2TextModel
from ..squeezebert.modeling_squeezebert import (
SqueezeBertForMaskedLM,
SqueezeBertForMultipleChoice,
SqueezeBertForQuestionAnswering,
SqueezeBertForSequenceClassification,
SqueezeBertForTokenClassification,
SqueezeBertModel,
)
from ..t5.modeling_t5 import T5ForConditionalGeneration, T5Model
from ..tapas.modeling_tapas import (
TapasForMaskedLM,
TapasForQuestionAnswering,
TapasForSequenceClassification,
TapasModel,
)
from ..transfo_xl.modeling_transfo_xl import TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
from ..visual_bert.modeling_visual_bert import VisualBertForPreTraining, VisualBertModel
from ..vit.modeling_vit import ViTForImageClassification, ViTModel
from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining, Wav2Vec2Model
from ..xlm.modeling_xlm import (
XLMForMultipleChoice,
XLMForQuestionAnsweringSimple,
XLMForSequenceClassification,
XLMForTokenClassification,
XLMModel,
XLMWithLMHeadModel,
)
from ..xlm_prophetnet.modeling_xlm_prophetnet import (
XLMProphetNetForCausalLM,
XLMProphetNetForConditionalGeneration,
XLMProphetNetModel,
)
from ..xlm_roberta.modeling_xlm_roberta import (
XLMRobertaForCausalLM,
XLMRobertaForMaskedLM,
XLMRobertaForMultipleChoice,
XLMRobertaForQuestionAnswering,
XLMRobertaForSequenceClassification,
XLMRobertaForTokenClassification,
XLMRobertaModel,
)
from ..xlnet.modeling_xlnet import (
XLNetForMultipleChoice,
XLNetForQuestionAnsweringSimple,
XLNetForSequenceClassification,
XLNetForTokenClassification,
XLNetLMHeadModel,
XLNetModel,
)
from .auto_factory import _BaseAutoModelClass, auto_class_update
from .configuration_auto import (
AlbertConfig,
BartConfig,
BeitConfig,
BertConfig,
BertGenerationConfig,
BigBirdConfig,
BigBirdPegasusConfig,
BlenderbotConfig,
BlenderbotSmallConfig,
CamembertConfig,
CanineConfig,
CLIPConfig,
ConvBertConfig,
CTRLConfig,
DebertaConfig,
DebertaV2Config,
DeiTConfig,
DetrConfig,
DistilBertConfig,
DPRConfig,
ElectraConfig,
EncoderDecoderConfig,
FlaubertConfig,
FSMTConfig,
FunnelConfig,
GPT2Config,
GPTNeoConfig,
HubertConfig,
IBertConfig,
LayoutLMConfig,
LEDConfig,
LongformerConfig,
LukeConfig,
LxmertConfig,
M2M100Config,
MarianConfig,
MBartConfig,
MegatronBertConfig,
MobileBertConfig,
MPNetConfig,
MT5Config,
OpenAIGPTConfig,
PegasusConfig,
ProphetNetConfig,
ReformerConfig,
RemBertConfig,
RetriBertConfig,
RobertaConfig,
RoFormerConfig,
Speech2TextConfig,
SqueezeBertConfig,
T5Config,
TapasConfig,
TransfoXLConfig,
VisualBertConfig,
ViTConfig,
Wav2Vec2Config,
XLMConfig,
XLMProphetNetConfig,
XLMRobertaConfig,
XLNetConfig,
)
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES
logger = logging.get_logger(__name__)
# Base-model table: pairs each model with its bare (headless) model class.
# NOTE(review): this span carries two assignment headers and two row styles -- class-keyed rows such as
# ``(BeitConfig, BeitModel)`` and string-keyed rows such as ``("beit", "BeitModel")``. They look like the two sides
# of a diff; confirm which set belongs in the final file (presumably the string-keyed ``*_NAMES`` one).
# A tuple value (see the "funnel" row) lists more than one class for the same model.
MODEL_MAPPING = OrderedDict(
MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
(BeitConfig, BeitModel),
(RemBertConfig, RemBertModel),
(VisualBertConfig, VisualBertModel),
(CanineConfig, CanineModel),
(RoFormerConfig, RoFormerModel),
(CLIPConfig, CLIPModel),
(BigBirdPegasusConfig, BigBirdPegasusModel),
(DeiTConfig, DeiTModel),
(LukeConfig, LukeModel),
(DetrConfig, DetrModel),
(GPTNeoConfig, GPTNeoModel),
(BigBirdConfig, BigBirdModel),
(Speech2TextConfig, Speech2TextModel),
(ViTConfig, ViTModel),
(Wav2Vec2Config, Wav2Vec2Model),
(HubertConfig, HubertModel),
(M2M100Config, M2M100Model),
(ConvBertConfig, ConvBertModel),
(LEDConfig, LEDModel),
(BlenderbotSmallConfig, BlenderbotSmallModel),
(RetriBertConfig, RetriBertModel),
(MT5Config, MT5Model),
(T5Config, T5Model),
(PegasusConfig, PegasusModel),
(MarianConfig, MarianMTModel),
(MBartConfig, MBartModel),
(BlenderbotConfig, BlenderbotModel),
(DistilBertConfig, DistilBertModel),
(AlbertConfig, AlbertModel),
(CamembertConfig, CamembertModel),
(XLMRobertaConfig, XLMRobertaModel),
(BartConfig, BartModel),
(LongformerConfig, LongformerModel),
(RobertaConfig, RobertaModel),
(LayoutLMConfig, LayoutLMModel),
(SqueezeBertConfig, SqueezeBertModel),
(BertConfig, BertModel),
(OpenAIGPTConfig, OpenAIGPTModel),
(GPT2Config, GPT2Model),
(MegatronBertConfig, MegatronBertModel),
(MobileBertConfig, MobileBertModel),
(TransfoXLConfig, TransfoXLModel),
(XLNetConfig, XLNetModel),
(FlaubertConfig, FlaubertModel),
(FSMTConfig, FSMTModel),
(XLMConfig, XLMModel),
(CTRLConfig, CTRLModel),
(ElectraConfig, ElectraModel),
(ReformerConfig, ReformerModel),
(FunnelConfig, (FunnelModel, FunnelBaseModel)),
(LxmertConfig, LxmertModel),
(BertGenerationConfig, BertGenerationEncoder),
(DebertaConfig, DebertaModel),
(DebertaV2Config, DebertaV2Model),
(DPRConfig, DPRQuestionEncoder),
(XLMProphetNetConfig, XLMProphetNetModel),
(ProphetNetConfig, ProphetNetModel),
(MPNetConfig, MPNetModel),
(TapasConfig, TapasModel),
(MarianConfig, MarianModel),
(IBertConfig, IBertModel),
("beit", "BeitModel"),
("rembert", "RemBertModel"),
("visual_bert", "VisualBertModel"),
("canine", "CanineModel"),
("roformer", "RoFormerModel"),
("clip", "CLIPModel"),
("bigbird_pegasus", "BigBirdPegasusModel"),
("deit", "DeiTModel"),
("luke", "LukeModel"),
("detr", "DetrModel"),
("gpt_neo", "GPTNeoModel"),
("big_bird", "BigBirdModel"),
("speech_to_text", "Speech2TextModel"),
("vit", "ViTModel"),
("wav2vec2", "Wav2Vec2Model"),
("hubert", "HubertModel"),
("m2m_100", "M2M100Model"),
("convbert", "ConvBertModel"),
("led", "LEDModel"),
("blenderbot-small", "BlenderbotSmallModel"),
("retribert", "RetriBertModel"),
("mt5", "MT5Model"),
("t5", "T5Model"),
("pegasus", "PegasusModel"),
("marian", "MarianModel"),
("mbart", "MBartModel"),
("blenderbot", "BlenderbotModel"),
("distilbert", "DistilBertModel"),
("albert", "AlbertModel"),
("camembert", "CamembertModel"),
("xlm-roberta", "XLMRobertaModel"),
("bart", "BartModel"),
("longformer", "LongformerModel"),
("roberta", "RobertaModel"),
("layoutlm", "LayoutLMModel"),
("squeezebert", "SqueezeBertModel"),
("bert", "BertModel"),
("openai-gpt", "OpenAIGPTModel"),
("gpt2", "GPT2Model"),
("megatron-bert", "MegatronBertModel"),
("mobilebert", "MobileBertModel"),
("transfo-xl", "TransfoXLModel"),
("xlnet", "XLNetModel"),
("flaubert", "FlaubertModel"),
("fsmt", "FSMTModel"),
("xlm", "XLMModel"),
("ctrl", "CTRLModel"),
("electra", "ElectraModel"),
("reformer", "ReformerModel"),
("funnel", ("FunnelModel", "FunnelBaseModel")),
("lxmert", "LxmertModel"),
("bert-generation", "BertGenerationEncoder"),
("deberta", "DebertaModel"),
("deberta-v2", "DebertaV2Model"),
("dpr", "DPRQuestionEncoder"),
("xlm-prophetnet", "XLMProphetNetModel"),
("prophetnet", "ProphetNetModel"),
("mpnet", "MPNetModel"),
("tapas", "TapasModel"),
("ibert", "IBertModel"),
]
)
# Pre-training table: pairs each model with the class used for its pre-training objective.
# NOTE(review): two assignment headers and two row styles (class-keyed and string-keyed) are both present here,
# presumably the two sides of a diff -- confirm which set belongs in the final file.
MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
[
# Model for pre-training mapping
(VisualBertConfig, VisualBertForPreTraining),
(LayoutLMConfig, LayoutLMForMaskedLM),
(RetriBertConfig, RetriBertModel),
(T5Config, T5ForConditionalGeneration),
(DistilBertConfig, DistilBertForMaskedLM),
(AlbertConfig, AlbertForPreTraining),
(CamembertConfig, CamembertForMaskedLM),
(XLMRobertaConfig, XLMRobertaForMaskedLM),
(BartConfig, BartForConditionalGeneration),
(FSMTConfig, FSMTForConditionalGeneration),
(LongformerConfig, LongformerForMaskedLM),
(RobertaConfig, RobertaForMaskedLM),
(SqueezeBertConfig, SqueezeBertForMaskedLM),
(BertConfig, BertForPreTraining),
(BigBirdConfig, BigBirdForPreTraining),
(OpenAIGPTConfig, OpenAIGPTLMHeadModel),
(GPT2Config, GPT2LMHeadModel),
(MegatronBertConfig, MegatronBertForPreTraining),
(MobileBertConfig, MobileBertForPreTraining),
(TransfoXLConfig, TransfoXLLMHeadModel),
(XLNetConfig, XLNetLMHeadModel),
(FlaubertConfig, FlaubertWithLMHeadModel),
(XLMConfig, XLMWithLMHeadModel),
(CTRLConfig, CTRLLMHeadModel),
(ElectraConfig, ElectraForPreTraining),
(LxmertConfig, LxmertForPreTraining),
(FunnelConfig, FunnelForPreTraining),
(MPNetConfig, MPNetForMaskedLM),
(TapasConfig, TapasForMaskedLM),
(IBertConfig, IBertForMaskedLM),
(DebertaConfig, DebertaForMaskedLM),
(DebertaV2Config, DebertaV2ForMaskedLM),
(Wav2Vec2Config, Wav2Vec2ForPreTraining),
("visual_bert", "VisualBertForPreTraining"),
("layoutlm", "LayoutLMForMaskedLM"),
("retribert", "RetriBertModel"),
("t5", "T5ForConditionalGeneration"),
("distilbert", "DistilBertForMaskedLM"),
("albert", "AlbertForPreTraining"),
("camembert", "CamembertForMaskedLM"),
("xlm-roberta", "XLMRobertaForMaskedLM"),
("bart", "BartForConditionalGeneration"),
("fsmt", "FSMTForConditionalGeneration"),
("longformer", "LongformerForMaskedLM"),
("roberta", "RobertaForMaskedLM"),
("squeezebert", "SqueezeBertForMaskedLM"),
("bert", "BertForPreTraining"),
("big_bird", "BigBirdForPreTraining"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
("megatron-bert", "MegatronBertForPreTraining"),
("mobilebert", "MobileBertForPreTraining"),
("transfo-xl", "TransfoXLLMHeadModel"),
("xlnet", "XLNetLMHeadModel"),
("flaubert", "FlaubertWithLMHeadModel"),
("xlm", "XLMWithLMHeadModel"),
("ctrl", "CTRLLMHeadModel"),
("electra", "ElectraForPreTraining"),
("lxmert", "LxmertForPreTraining"),
("funnel", "FunnelForPreTraining"),
("mpnet", "MPNetForMaskedLM"),
("tapas", "TapasForMaskedLM"),
("ibert", "IBertForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
("deberta-v2", "DebertaV2ForMaskedLM"),
("wav2vec2", "Wav2Vec2ForPreTraining"),
]
)
# LM-head table: pairs each model with its language-modeling-head class.
# NOTE(review): two assignment headers and two row styles (class-keyed and string-keyed) are both present here,
# presumably the two sides of a diff -- confirm which set belongs in the final file.
# The class-keyed rows list ``MegatronBertConfig`` twice (MaskedLM, then CausalLM); the string-keyed rows list
# "megatron-bert" once, with the CausalLM class.
MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
[
# Model with LM heads mapping
(RemBertConfig, RemBertForMaskedLM),
(RoFormerConfig, RoFormerForMaskedLM),
(BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration),
(GPTNeoConfig, GPTNeoForCausalLM),
(BigBirdConfig, BigBirdForMaskedLM),
(Speech2TextConfig, Speech2TextForConditionalGeneration),
(Wav2Vec2Config, Wav2Vec2ForMaskedLM),
(M2M100Config, M2M100ForConditionalGeneration),
(ConvBertConfig, ConvBertForMaskedLM),
(LEDConfig, LEDForConditionalGeneration),
(BlenderbotSmallConfig, BlenderbotSmallForConditionalGeneration),
(LayoutLMConfig, LayoutLMForMaskedLM),
(T5Config, T5ForConditionalGeneration),
(DistilBertConfig, DistilBertForMaskedLM),
(AlbertConfig, AlbertForMaskedLM),
(CamembertConfig, CamembertForMaskedLM),
(XLMRobertaConfig, XLMRobertaForMaskedLM),
(MarianConfig, MarianMTModel),
(FSMTConfig, FSMTForConditionalGeneration),
(BartConfig, BartForConditionalGeneration),
(LongformerConfig, LongformerForMaskedLM),
(RobertaConfig, RobertaForMaskedLM),
(SqueezeBertConfig, SqueezeBertForMaskedLM),
(BertConfig, BertForMaskedLM),
(OpenAIGPTConfig, OpenAIGPTLMHeadModel),
(GPT2Config, GPT2LMHeadModel),
(MegatronBertConfig, MegatronBertForMaskedLM),
(MobileBertConfig, MobileBertForMaskedLM),
(TransfoXLConfig, TransfoXLLMHeadModel),
(XLNetConfig, XLNetLMHeadModel),
(FlaubertConfig, FlaubertWithLMHeadModel),
(XLMConfig, XLMWithLMHeadModel),
(CTRLConfig, CTRLLMHeadModel),
(ElectraConfig, ElectraForMaskedLM),
(EncoderDecoderConfig, EncoderDecoderModel),
(ReformerConfig, ReformerModelWithLMHead),
(FunnelConfig, FunnelForMaskedLM),
(MPNetConfig, MPNetForMaskedLM),
(TapasConfig, TapasForMaskedLM),
(DebertaConfig, DebertaForMaskedLM),
(DebertaV2Config, DebertaV2ForMaskedLM),
(IBertConfig, IBertForMaskedLM),
(MegatronBertConfig, MegatronBertForCausalLM),
("rembert", "RemBertForMaskedLM"),
("roformer", "RoFormerForMaskedLM"),
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
("gpt_neo", "GPTNeoForCausalLM"),
("big_bird", "BigBirdForMaskedLM"),
("speech_to_text", "Speech2TextForConditionalGeneration"),
("wav2vec2", "Wav2Vec2ForMaskedLM"),
("m2m_100", "M2M100ForConditionalGeneration"),
("convbert", "ConvBertForMaskedLM"),
("led", "LEDForConditionalGeneration"),
("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
("layoutlm", "LayoutLMForMaskedLM"),
("t5", "T5ForConditionalGeneration"),
("distilbert", "DistilBertForMaskedLM"),
("albert", "AlbertForMaskedLM"),
("camembert", "CamembertForMaskedLM"),
("xlm-roberta", "XLMRobertaForMaskedLM"),
("marian", "MarianMTModel"),
("fsmt", "FSMTForConditionalGeneration"),
("bart", "BartForConditionalGeneration"),
("longformer", "LongformerForMaskedLM"),
("roberta", "RobertaForMaskedLM"),
("squeezebert", "SqueezeBertForMaskedLM"),
("bert", "BertForMaskedLM"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
("megatron-bert", "MegatronBertForCausalLM"),
("mobilebert", "MobileBertForMaskedLM"),
("transfo-xl", "TransfoXLLMHeadModel"),
("xlnet", "XLNetLMHeadModel"),
("flaubert", "FlaubertWithLMHeadModel"),
("xlm", "XLMWithLMHeadModel"),
("ctrl", "CTRLLMHeadModel"),
("electra", "ElectraForMaskedLM"),
("encoder-decoder", "EncoderDecoderModel"),
("reformer", "ReformerModelWithLMHead"),
("funnel", "FunnelForMaskedLM"),
("mpnet", "MPNetForMaskedLM"),
("tapas", "TapasForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
("deberta-v2", "DebertaV2ForMaskedLM"),
("ibert", "IBertForMaskedLM"),
]
)
# Causal-LM table: pairs each model with its causal language-modeling class.
# NOTE(review): two assignment headers and two row styles (class-keyed and string-keyed) are both present here,
# presumably the two sides of a diff -- confirm which set belongs in the final file.
MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
(RemBertConfig, RemBertForCausalLM),
(RoFormerConfig, RoFormerForCausalLM),
(BigBirdPegasusConfig, BigBirdPegasusForCausalLM),
(GPTNeoConfig, GPTNeoForCausalLM),
(BigBirdConfig, BigBirdForCausalLM),
(CamembertConfig, CamembertForCausalLM),
(XLMRobertaConfig, XLMRobertaForCausalLM),
(RobertaConfig, RobertaForCausalLM),
(BertConfig, BertLMHeadModel),
(OpenAIGPTConfig, OpenAIGPTLMHeadModel),
(GPT2Config, GPT2LMHeadModel),
(TransfoXLConfig, TransfoXLLMHeadModel),
(XLNetConfig, XLNetLMHeadModel),
(
XLMConfig,
XLMWithLMHeadModel,
), # XLM can be MLM and CLM => model should be split similar to BERT; leave here for now
(CTRLConfig, CTRLLMHeadModel),
(ReformerConfig, ReformerModelWithLMHead),
(BertGenerationConfig, BertGenerationDecoder),
(XLMProphetNetConfig, XLMProphetNetForCausalLM),
(ProphetNetConfig, ProphetNetForCausalLM),
(BartConfig, BartForCausalLM),
(MBartConfig, MBartForCausalLM),
(PegasusConfig, PegasusForCausalLM),
(MarianConfig, MarianForCausalLM),
(BlenderbotConfig, BlenderbotForCausalLM),
(BlenderbotSmallConfig, BlenderbotSmallForCausalLM),
(MegatronBertConfig, MegatronBertForCausalLM),
("rembert", "RemBertForCausalLM"),
("roformer", "RoFormerForCausalLM"),
("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
("gpt_neo", "GPTNeoForCausalLM"),
("big_bird", "BigBirdForCausalLM"),
("camembert", "CamembertForCausalLM"),
("xlm-roberta", "XLMRobertaForCausalLM"),
("roberta", "RobertaForCausalLM"),
("bert", "BertLMHeadModel"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
("transfo-xl", "TransfoXLLMHeadModel"),
("xlnet", "XLNetLMHeadModel"),
("xlm", "XLMWithLMHeadModel"),
("ctrl", "CTRLLMHeadModel"),
("reformer", "ReformerModelWithLMHead"),
("bert-generation", "BertGenerationDecoder"),
("xlm-prophetnet", "XLMProphetNetForCausalLM"),
("prophetnet", "ProphetNetForCausalLM"),
("bart", "BartForCausalLM"),
("mbart", "MBartForCausalLM"),
("pegasus", "PegasusForCausalLM"),
("marian", "MarianForCausalLM"),
("blenderbot", "BlenderbotForCausalLM"),
("blenderbot-small", "BlenderbotSmallForCausalLM"),
("megatron-bert", "MegatronBertForCausalLM"),
]
)
# Image-classification table: pairs each vision model with its image-classification class(es).
# NOTE(review): two assignment headers and two row styles (class-keyed and string-keyed) are both present here,
# presumably the two sides of a diff -- confirm which set belongs in the final file.
# The DeiT rows carry a tuple value, i.e. two classes for the same model.
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict(
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Image Classification mapping
(ViTConfig, ViTForImageClassification),
(DeiTConfig, (DeiTForImageClassification, DeiTForImageClassificationWithTeacher)),
(BeitConfig, BeitForImageClassification),
("vit", "ViTForImageClassification"),
("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")),
("beit", "BeitForImageClassification"),
]
)
# Masked-LM table: pairs each model with its masked language-modeling class.
# NOTE(review): two assignment headers and two row styles (class-keyed and string-keyed) are both present here,
# presumably the two sides of a diff -- confirm which set belongs in the final file.
MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
(RemBertConfig, RemBertForMaskedLM),
(RoFormerConfig, RoFormerForMaskedLM),
(BigBirdConfig, BigBirdForMaskedLM),
(Wav2Vec2Config, Wav2Vec2ForMaskedLM),
(ConvBertConfig, ConvBertForMaskedLM),
(LayoutLMConfig, LayoutLMForMaskedLM),
(DistilBertConfig, DistilBertForMaskedLM),
(AlbertConfig, AlbertForMaskedLM),
(BartConfig, BartForConditionalGeneration),
(MBartConfig, MBartForConditionalGeneration),
(CamembertConfig, CamembertForMaskedLM),
(XLMRobertaConfig, XLMRobertaForMaskedLM),
(LongformerConfig, LongformerForMaskedLM),
(RobertaConfig, RobertaForMaskedLM),
(SqueezeBertConfig, SqueezeBertForMaskedLM),
(BertConfig, BertForMaskedLM),
(MegatronBertConfig, MegatronBertForMaskedLM),
(MobileBertConfig, MobileBertForMaskedLM),
(FlaubertConfig, FlaubertWithLMHeadModel),
(XLMConfig, XLMWithLMHeadModel),
(ElectraConfig, ElectraForMaskedLM),
(ReformerConfig, ReformerForMaskedLM),
(FunnelConfig, FunnelForMaskedLM),
(MPNetConfig, MPNetForMaskedLM),
(TapasConfig, TapasForMaskedLM),
(DebertaConfig, DebertaForMaskedLM),
(DebertaV2Config, DebertaV2ForMaskedLM),
(IBertConfig, IBertForMaskedLM),
("rembert", "RemBertForMaskedLM"),
("roformer", "RoFormerForMaskedLM"),
("big_bird", "BigBirdForMaskedLM"),
("wav2vec2", "Wav2Vec2ForMaskedLM"),
("convbert", "ConvBertForMaskedLM"),
("layoutlm", "LayoutLMForMaskedLM"),
("distilbert", "DistilBertForMaskedLM"),
("albert", "AlbertForMaskedLM"),
("bart", "BartForConditionalGeneration"),
("mbart", "MBartForConditionalGeneration"),
("camembert", "CamembertForMaskedLM"),
("xlm-roberta", "XLMRobertaForMaskedLM"),
("longformer", "LongformerForMaskedLM"),
("roberta", "RobertaForMaskedLM"),
("squeezebert", "SqueezeBertForMaskedLM"),
("bert", "BertForMaskedLM"),
("megatron-bert", "MegatronBertForMaskedLM"),
("mobilebert", "MobileBertForMaskedLM"),
("flaubert", "FlaubertWithLMHeadModel"),
("xlm", "XLMWithLMHeadModel"),
("electra", "ElectraForMaskedLM"),
("reformer", "ReformerForMaskedLM"),
("funnel", "FunnelForMaskedLM"),
("mpnet", "MPNetForMaskedLM"),
("tapas", "TapasForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
("deberta-v2", "DebertaV2ForMaskedLM"),
("ibert", "IBertForMaskedLM"),
]
)
# Object-detection table: currently a single entry, DETR.
# NOTE(review): two assignment headers and two row styles (class-keyed and string-keyed) are both present here,
# presumably the two sides of a diff -- confirm which set belongs in the final file.
MODEL_FOR_OBJECT_DETECTION_MAPPING = OrderedDict(
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
[
# Model for Object Detection mapping
(DetrConfig, DetrForObjectDetection),
("detr", "DetrForObjectDetection"),
]
)
# Seq2seq LM table: pairs each encoder-decoder model with its conditional-generation class.
# NOTE(review): two assignment headers and two row styles (class-keyed and string-keyed) are both present here,
# presumably the two sides of a diff -- confirm which set belongs in the final file.
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
(BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration),
(M2M100Config, M2M100ForConditionalGeneration),
(LEDConfig, LEDForConditionalGeneration),
(BlenderbotSmallConfig, BlenderbotSmallForConditionalGeneration),
(MT5Config, MT5ForConditionalGeneration),
(T5Config, T5ForConditionalGeneration),
(PegasusConfig, PegasusForConditionalGeneration),
(MarianConfig, MarianMTModel),
(MBartConfig, MBartForConditionalGeneration),
(BlenderbotConfig, BlenderbotForConditionalGeneration),
(BartConfig, BartForConditionalGeneration),
(FSMTConfig, FSMTForConditionalGeneration),
(EncoderDecoderConfig, EncoderDecoderModel),
(XLMProphetNetConfig, XLMProphetNetForConditionalGeneration),
(ProphetNetConfig, ProphetNetForConditionalGeneration),
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
("m2m_100", "M2M100ForConditionalGeneration"),
("led", "LEDForConditionalGeneration"),
("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
("mt5", "MT5ForConditionalGeneration"),
("t5", "T5ForConditionalGeneration"),
("pegasus", "PegasusForConditionalGeneration"),
("marian", "MarianMTModel"),
("mbart", "MBartForConditionalGeneration"),
("blenderbot", "BlenderbotForConditionalGeneration"),
("bart", "BartForConditionalGeneration"),
("fsmt", "FSMTForConditionalGeneration"),
("encoder-decoder", "EncoderDecoderModel"),
("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"),
("prophetnet", "ProphetNetForConditionalGeneration"),
]
)
# Sequence-classification table: pairs each model with its sequence-classification class.
# NOTE(review): two assignment headers and two row styles (class-keyed and string-keyed) are both present here,
# presumably the two sides of a diff -- confirm which set belongs in the final file.
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
(RemBertConfig, RemBertForSequenceClassification),
(CanineConfig, CanineForSequenceClassification),
(RoFormerConfig, RoFormerForSequenceClassification),
(BigBirdPegasusConfig, BigBirdPegasusForSequenceClassification),
(BigBirdConfig, BigBirdForSequenceClassification),
(ConvBertConfig, ConvBertForSequenceClassification),
(LEDConfig, LEDForSequenceClassification),
(DistilBertConfig, DistilBertForSequenceClassification),
(AlbertConfig, AlbertForSequenceClassification),
(CamembertConfig, CamembertForSequenceClassification),
(XLMRobertaConfig, XLMRobertaForSequenceClassification),
(MBartConfig, MBartForSequenceClassification),
(BartConfig, BartForSequenceClassification),
(LongformerConfig, LongformerForSequenceClassification),
(RobertaConfig, RobertaForSequenceClassification),
(SqueezeBertConfig, SqueezeBertForSequenceClassification),
(LayoutLMConfig, LayoutLMForSequenceClassification),
(BertConfig, BertForSequenceClassification),
(XLNetConfig, XLNetForSequenceClassification),
(MegatronBertConfig, MegatronBertForSequenceClassification),
(MobileBertConfig, MobileBertForSequenceClassification),
(FlaubertConfig, FlaubertForSequenceClassification),
(XLMConfig, XLMForSequenceClassification),
(ElectraConfig, ElectraForSequenceClassification),
(FunnelConfig, FunnelForSequenceClassification),
(DebertaConfig, DebertaForSequenceClassification),
(DebertaV2Config, DebertaV2ForSequenceClassification),
(GPT2Config, GPT2ForSequenceClassification),
(GPTNeoConfig, GPTNeoForSequenceClassification),
(OpenAIGPTConfig, OpenAIGPTForSequenceClassification),
(ReformerConfig, ReformerForSequenceClassification),
(CTRLConfig, CTRLForSequenceClassification),
(TransfoXLConfig, TransfoXLForSequenceClassification),
(MPNetConfig, MPNetForSequenceClassification),
(TapasConfig, TapasForSequenceClassification),
(IBertConfig, IBertForSequenceClassification),
("rembert", "RemBertForSequenceClassification"),
("canine", "CanineForSequenceClassification"),
("roformer", "RoFormerForSequenceClassification"),
("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
("big_bird", "BigBirdForSequenceClassification"),
("convbert", "ConvBertForSequenceClassification"),
("led", "LEDForSequenceClassification"),
("distilbert", "DistilBertForSequenceClassification"),
("albert", "AlbertForSequenceClassification"),
("camembert", "CamembertForSequenceClassification"),
("xlm-roberta", "XLMRobertaForSequenceClassification"),
("mbart", "MBartForSequenceClassification"),
("bart", "BartForSequenceClassification"),
("longformer", "LongformerForSequenceClassification"),
("roberta", "RobertaForSequenceClassification"),
("squeezebert", "SqueezeBertForSequenceClassification"),
("layoutlm", "LayoutLMForSequenceClassification"),
("bert", "BertForSequenceClassification"),
("xlnet", "XLNetForSequenceClassification"),
("megatron-bert", "MegatronBertForSequenceClassification"),
("mobilebert", "MobileBertForSequenceClassification"),
("flaubert", "FlaubertForSequenceClassification"),
("xlm", "XLMForSequenceClassification"),
("electra", "ElectraForSequenceClassification"),
("funnel", "FunnelForSequenceClassification"),
("deberta", "DebertaForSequenceClassification"),
("deberta-v2", "DebertaV2ForSequenceClassification"),
("gpt2", "GPT2ForSequenceClassification"),
("gpt_neo", "GPTNeoForSequenceClassification"),
("openai-gpt", "OpenAIGPTForSequenceClassification"),
("reformer", "ReformerForSequenceClassification"),
("ctrl", "CTRLForSequenceClassification"),
("transfo-xl", "TransfoXLForSequenceClassification"),
("mpnet", "MPNetForSequenceClassification"),
("tapas", "TapasForSequenceClassification"),
("ibert", "IBertForSequenceClassification"),
]
)
# Question-answering table: pairs each model with its question-answering class.
# NOTE(review): two assignment headers and two row styles (class-keyed and string-keyed) are both present here,
# presumably the two sides of a diff -- confirm which set belongs in the final file.
MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
(RemBertConfig, RemBertForQuestionAnswering),
(CanineConfig, CanineForQuestionAnswering),
(RoFormerConfig, RoFormerForQuestionAnswering),
(BigBirdPegasusConfig, BigBirdPegasusForQuestionAnswering),
(BigBirdConfig, BigBirdForQuestionAnswering),
(ConvBertConfig, ConvBertForQuestionAnswering),
(LEDConfig, LEDForQuestionAnswering),
(DistilBertConfig, DistilBertForQuestionAnswering),
(AlbertConfig, AlbertForQuestionAnswering),
(CamembertConfig, CamembertForQuestionAnswering),
(BartConfig, BartForQuestionAnswering),
(MBartConfig, MBartForQuestionAnswering),
(LongformerConfig, LongformerForQuestionAnswering),
(XLMRobertaConfig, XLMRobertaForQuestionAnswering),
(RobertaConfig, RobertaForQuestionAnswering),
(SqueezeBertConfig, SqueezeBertForQuestionAnswering),
(BertConfig, BertForQuestionAnswering),
(XLNetConfig, XLNetForQuestionAnsweringSimple),
(FlaubertConfig, FlaubertForQuestionAnsweringSimple),
(MegatronBertConfig, MegatronBertForQuestionAnswering),
(MobileBertConfig, MobileBertForQuestionAnswering),
(XLMConfig, XLMForQuestionAnsweringSimple),
(ElectraConfig, ElectraForQuestionAnswering),
(ReformerConfig, ReformerForQuestionAnswering),
(FunnelConfig, FunnelForQuestionAnswering),
(LxmertConfig, LxmertForQuestionAnswering),
(MPNetConfig, MPNetForQuestionAnswering),
(DebertaConfig, DebertaForQuestionAnswering),
(DebertaV2Config, DebertaV2ForQuestionAnswering),
(IBertConfig, IBertForQuestionAnswering),
("rembert", "RemBertForQuestionAnswering"),
("canine", "CanineForQuestionAnswering"),
("roformer", "RoFormerForQuestionAnswering"),
("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"),
("big_bird", "BigBirdForQuestionAnswering"),
("convbert", "ConvBertForQuestionAnswering"),
("led", "LEDForQuestionAnswering"),
("distilbert", "DistilBertForQuestionAnswering"),
("albert", "AlbertForQuestionAnswering"),
("camembert", "CamembertForQuestionAnswering"),
("bart", "BartForQuestionAnswering"),
("mbart", "MBartForQuestionAnswering"),
("longformer", "LongformerForQuestionAnswering"),
("xlm-roberta", "XLMRobertaForQuestionAnswering"),
("roberta", "RobertaForQuestionAnswering"),
("squeezebert", "SqueezeBertForQuestionAnswering"),
("bert", "BertForQuestionAnswering"),
("xlnet", "XLNetForQuestionAnsweringSimple"),
("flaubert", "FlaubertForQuestionAnsweringSimple"),
("megatron-bert", "MegatronBertForQuestionAnswering"),
("mobilebert", "MobileBertForQuestionAnswering"),
("xlm", "XLMForQuestionAnsweringSimple"),
("electra", "ElectraForQuestionAnswering"),
("reformer", "ReformerForQuestionAnswering"),
("funnel", "FunnelForQuestionAnswering"),
("lxmert", "LxmertForQuestionAnswering"),
("mpnet", "MPNetForQuestionAnswering"),
("deberta", "DebertaForQuestionAnswering"),
("deberta-v2", "DebertaV2ForQuestionAnswering"),
("ibert", "IBertForQuestionAnswering"),
]
)
# Table question-answering heads, keyed by model-type string with class names
# as string values. Paired with CONFIG_MAPPING_NAMES through _LazyAutoMapping
# further down to build MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Table Question Answering mapping
        ("tapas", "TapasForQuestionAnswering"),
    ]
)
# Token-classification heads, keyed by model-type string with class names as
# string values. Each model type appears exactly once. Paired with
# CONFIG_MAPPING_NAMES through _LazyAutoMapping further down to build
# MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
        ("rembert", "RemBertForTokenClassification"),
        ("canine", "CanineForTokenClassification"),
        ("roformer", "RoFormerForTokenClassification"),
        ("big_bird", "BigBirdForTokenClassification"),
        ("convbert", "ConvBertForTokenClassification"),
        ("layoutlm", "LayoutLMForTokenClassification"),
        ("distilbert", "DistilBertForTokenClassification"),
        ("camembert", "CamembertForTokenClassification"),
        ("flaubert", "FlaubertForTokenClassification"),
        ("xlm", "XLMForTokenClassification"),
        ("xlm-roberta", "XLMRobertaForTokenClassification"),
        ("longformer", "LongformerForTokenClassification"),
        ("roberta", "RobertaForTokenClassification"),
        ("squeezebert", "SqueezeBertForTokenClassification"),
        ("bert", "BertForTokenClassification"),
        ("megatron-bert", "MegatronBertForTokenClassification"),
        ("mobilebert", "MobileBertForTokenClassification"),
        ("xlnet", "XLNetForTokenClassification"),
        ("albert", "AlbertForTokenClassification"),
        ("electra", "ElectraForTokenClassification"),
        ("funnel", "FunnelForTokenClassification"),
        ("mpnet", "MPNetForTokenClassification"),
        ("deberta", "DebertaForTokenClassification"),
        ("deberta-v2", "DebertaV2ForTokenClassification"),
        ("ibert", "IBertForTokenClassification"),
    ]
)
# Multiple-choice heads, keyed by model-type string with class names as string
# values. Paired with CONFIG_MAPPING_NAMES through _LazyAutoMapping further
# down to build MODEL_FOR_MULTIPLE_CHOICE_MAPPING.
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Multiple Choice mapping
        ("rembert", "RemBertForMultipleChoice"),
        ("canine", "CanineForMultipleChoice"),
        ("roformer", "RoFormerForMultipleChoice"),
        ("big_bird", "BigBirdForMultipleChoice"),
        ("convbert", "ConvBertForMultipleChoice"),
        ("camembert", "CamembertForMultipleChoice"),
        ("electra", "ElectraForMultipleChoice"),
        ("xlm-roberta", "XLMRobertaForMultipleChoice"),
        ("longformer", "LongformerForMultipleChoice"),
        ("roberta", "RobertaForMultipleChoice"),
        ("squeezebert", "SqueezeBertForMultipleChoice"),
        ("bert", "BertForMultipleChoice"),
        ("distilbert", "DistilBertForMultipleChoice"),
        ("megatron-bert", "MegatronBertForMultipleChoice"),
        ("mobilebert", "MobileBertForMultipleChoice"),
        ("xlnet", "XLNetForMultipleChoice"),
        ("albert", "AlbertForMultipleChoice"),
        ("xlm", "XLMForMultipleChoice"),
        ("flaubert", "FlaubertForMultipleChoice"),
        ("funnel", "FunnelForMultipleChoice"),
        ("mpnet", "MPNetForMultipleChoice"),
        ("ibert", "IBertForMultipleChoice"),
    ]
)
# Next-sentence-prediction heads, keyed by model-type string with class names
# as string values. Paired with CONFIG_MAPPING_NAMES through _LazyAutoMapping
# further down to build MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        ("bert", "BertForNextSentencePrediction"),
        ("megatron-bert", "MegatronBertForNextSentencePrediction"),
        ("mobilebert", "MobileBertForNextSentencePrediction"),
    ]
)
# Public task mappings. Every one is a _LazyAutoMapping built from the shared
# CONFIG_MAPPING_NAMES table and the matching *_MAPPING_NAMES table of class
# names defined earlier in this module.
MODEL_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_PRETRAINING_MAPPING_NAMES,
)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_WITH_LM_HEAD_MAPPING_NAMES,
)
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_MASKED_LM_MAPPING_NAMES,
)
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
)
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
)
# Auto class for base models: all behaviour is inherited from
# _BaseAutoModelClass (defined in auto_factory, not visible here); this subclass
# only selects which mapping is consulted.
class AutoModel(_BaseAutoModelClass):
    _model_mapping = MODEL_MAPPING
......
......@@ -18,207 +18,163 @@
from collections import OrderedDict
from ...utils import logging
from ..bart.modeling_flax_bart import (
FlaxBartForConditionalGeneration,
FlaxBartForQuestionAnswering,
FlaxBartForSequenceClassification,
FlaxBartModel,
)
from ..bert.modeling_flax_bert import (
FlaxBertForMaskedLM,
FlaxBertForMultipleChoice,
FlaxBertForNextSentencePrediction,
FlaxBertForPreTraining,
FlaxBertForQuestionAnswering,
FlaxBertForSequenceClassification,
FlaxBertForTokenClassification,
FlaxBertModel,
)
from ..big_bird.modeling_flax_big_bird import (
FlaxBigBirdForMaskedLM,
FlaxBigBirdForMultipleChoice,
FlaxBigBirdForPreTraining,
FlaxBigBirdForQuestionAnswering,
FlaxBigBirdForSequenceClassification,
FlaxBigBirdForTokenClassification,
FlaxBigBirdModel,
)
from ..clip.modeling_flax_clip import FlaxCLIPModel
from ..electra.modeling_flax_electra import (
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,
FlaxElectraForPreTraining,
FlaxElectraForQuestionAnswering,
FlaxElectraForSequenceClassification,
FlaxElectraForTokenClassification,
FlaxElectraModel,
)
from ..gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
from ..gpt_neo.modeling_flax_gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel
from ..marian.modeling_flax_marian import FlaxMarianModel, FlaxMarianMTModel
from ..mbart.modeling_flax_mbart import (
FlaxMBartForConditionalGeneration,
FlaxMBartForQuestionAnswering,
FlaxMBartForSequenceClassification,
FlaxMBartModel,
)
from ..mt5.modeling_flax_mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
from ..roberta.modeling_flax_roberta import (
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
FlaxRobertaForSequenceClassification,
FlaxRobertaForTokenClassification,
FlaxRobertaModel,
)
from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
from ..wav2vec2.modeling_flax_wav2vec2 import FlaxWav2Vec2ForPreTraining, FlaxWav2Vec2Model
from .auto_factory import _BaseAutoModelClass, auto_class_update
from .configuration_auto import (
BartConfig,
BertConfig,
BigBirdConfig,
CLIPConfig,
ElectraConfig,
GPT2Config,
GPTNeoConfig,
MarianConfig,
MBartConfig,
MT5Config,
RobertaConfig,
T5Config,
ViTConfig,
Wav2Vec2Config,
)
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES
logger = logging.get_logger(__name__)
# Flax base models, keyed by model-type string. Values are class names kept as
# plain strings, so defining this table needs no Flax modeling import. Paired
# with CONFIG_MAPPING_NAMES through _LazyAutoMapping further down to build
# FLAX_MODEL_MAPPING.
FLAX_MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("roberta", "FlaxRobertaModel"),
        ("bert", "FlaxBertModel"),
        ("big_bird", "FlaxBigBirdModel"),
        ("bart", "FlaxBartModel"),
        ("gpt2", "FlaxGPT2Model"),
        ("gpt_neo", "FlaxGPTNeoModel"),
        ("electra", "FlaxElectraModel"),
        ("clip", "FlaxCLIPModel"),
        ("vit", "FlaxViTModel"),
        ("mbart", "FlaxMBartModel"),
        ("t5", "FlaxT5Model"),
        ("mt5", "FlaxMT5Model"),
        ("wav2vec2", "FlaxWav2Vec2Model"),
        ("marian", "FlaxMarianModel"),
    ]
)
# Flax pre-training heads, keyed by model-type string with class names as
# string values. Paired with CONFIG_MAPPING_NAMES through _LazyAutoMapping
# further down to build FLAX_MODEL_FOR_PRETRAINING_MAPPING.
FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        # Model for pre-training mapping
        ("roberta", "FlaxRobertaForMaskedLM"),
        ("bert", "FlaxBertForPreTraining"),
        ("big_bird", "FlaxBigBirdForPreTraining"),
        ("bart", "FlaxBartForConditionalGeneration"),
        ("electra", "FlaxElectraForPreTraining"),
        ("mbart", "FlaxMBartForConditionalGeneration"),
        ("t5", "FlaxT5ForConditionalGeneration"),
        ("mt5", "FlaxMT5ForConditionalGeneration"),
        ("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
    ]
)
# Flax masked-LM heads, keyed by model-type string with class names as string
# values. Paired with CONFIG_MAPPING_NAMES through _LazyAutoMapping further
# down to build FLAX_MODEL_FOR_MASKED_LM_MAPPING.
FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked LM mapping
        ("roberta", "FlaxRobertaForMaskedLM"),
        ("bert", "FlaxBertForMaskedLM"),
        ("big_bird", "FlaxBigBirdForMaskedLM"),
        ("bart", "FlaxBartForConditionalGeneration"),
        ("electra", "FlaxElectraForMaskedLM"),
        ("mbart", "FlaxMBartForConditionalGeneration"),
    ]
)
# Flax sequence-to-sequence LM heads, keyed by model-type string with class
# names as string values. Paired with CONFIG_MAPPING_NAMES through
# _LazyAutoMapping further down to build FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        ("bart", "FlaxBartForConditionalGeneration"),
        ("t5", "FlaxT5ForConditionalGeneration"),
        ("mt5", "FlaxMT5ForConditionalGeneration"),
        ("marian", "FlaxMarianMTModel"),
    ]
)
# Flax image-classification heads, keyed by model-type string with class names
# as string values. Paired with CONFIG_MAPPING_NAMES through _LazyAutoMapping
# further down to build FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image-classification
        ("vit", "FlaxViTForImageClassification"),
    ]
)
# Flax causal-LM heads, keyed by model-type string with class names as string
# values. Paired with CONFIG_MAPPING_NAMES through _LazyAutoMapping further
# down to build FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
        ("gpt2", "FlaxGPT2LMHeadModel"),
        ("gpt_neo", "FlaxGPTNeoForCausalLM"),
    ]
)
# Flax sequence-classification heads, keyed by model-type string with class
# names as string values. Paired with CONFIG_MAPPING_NAMES through
# _LazyAutoMapping further down to build FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Sequence Classification mapping
        ("roberta", "FlaxRobertaForSequenceClassification"),
        ("bert", "FlaxBertForSequenceClassification"),
        ("big_bird", "FlaxBigBirdForSequenceClassification"),
        ("bart", "FlaxBartForSequenceClassification"),
        ("electra", "FlaxElectraForSequenceClassification"),
        ("mbart", "FlaxMBartForSequenceClassification"),
    ]
)
# Flax question-answering heads, keyed by model-type string with class names
# as string values. Paired with CONFIG_MAPPING_NAMES through _LazyAutoMapping
# further down to build FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING.
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Question Answering mapping
        ("roberta", "FlaxRobertaForQuestionAnswering"),
        ("bert", "FlaxBertForQuestionAnswering"),
        ("big_bird", "FlaxBigBirdForQuestionAnswering"),
        ("bart", "FlaxBartForQuestionAnswering"),
        ("electra", "FlaxElectraForQuestionAnswering"),
        ("mbart", "FlaxMBartForQuestionAnswering"),
    ]
)
# Flax token-classification heads, keyed by model-type string with class names
# as string values. Paired with CONFIG_MAPPING_NAMES through _LazyAutoMapping
# further down to build FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
        ("roberta", "FlaxRobertaForTokenClassification"),
        ("bert", "FlaxBertForTokenClassification"),
        ("big_bird", "FlaxBigBirdForTokenClassification"),
        ("electra", "FlaxElectraForTokenClassification"),
    ]
)
# Flax multiple-choice heads, keyed by model-type string with class names as
# string values. Paired with CONFIG_MAPPING_NAMES through _LazyAutoMapping
# further down to build FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Multiple Choice mapping
        ("roberta", "FlaxRobertaForMultipleChoice"),
        ("bert", "FlaxBertForMultipleChoice"),
        ("big_bird", "FlaxBigBirdForMultipleChoice"),
        ("electra", "FlaxElectraForMultipleChoice"),
    ]
)
# Flax next-sentence-prediction heads, keyed by model-type string with class
# names as string values. Paired with CONFIG_MAPPING_NAMES through
# _LazyAutoMapping further down to build FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        ("bert", "FlaxBertForNextSentencePrediction"),
    ]
)
# Public Flax task mappings. Every one is a _LazyAutoMapping built from the
# shared CONFIG_MAPPING_NAMES table and the matching FLAX_*_MAPPING_NAMES table
# of class names defined earlier in this module.
FLAX_MODEL_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_MAPPING_NAMES,
)
FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES,
)
FLAX_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES,
)
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
)
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
)
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
)
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
)
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
)
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
)
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
)
# Flax auto class for base models: all behaviour is inherited from
# _BaseAutoModelClass (defined in auto_factory, not visible here); this subclass
# only selects which mapping is consulted.
class FlaxAutoModel(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_MAPPING
......
......@@ -19,492 +19,298 @@ import warnings
from collections import OrderedDict
from ...utils import logging
# Add modeling imports here
from ..albert.modeling_tf_albert import (
TFAlbertForMaskedLM,
TFAlbertForMultipleChoice,
TFAlbertForPreTraining,
TFAlbertForQuestionAnswering,
TFAlbertForSequenceClassification,
TFAlbertForTokenClassification,
TFAlbertModel,
)
from ..bart.modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel
from ..bert.modeling_tf_bert import (
TFBertForMaskedLM,
TFBertForMultipleChoice,
TFBertForNextSentencePrediction,
TFBertForPreTraining,
TFBertForQuestionAnswering,
TFBertForSequenceClassification,
TFBertForTokenClassification,
TFBertLMHeadModel,
TFBertModel,
)
from ..blenderbot.modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
from ..blenderbot_small.modeling_tf_blenderbot_small import (
TFBlenderbotSmallForConditionalGeneration,
TFBlenderbotSmallModel,
)
from ..camembert.modeling_tf_camembert import (
TFCamembertForMaskedLM,
TFCamembertForMultipleChoice,
TFCamembertForQuestionAnswering,
TFCamembertForSequenceClassification,
TFCamembertForTokenClassification,
TFCamembertModel,
)
from ..convbert.modeling_tf_convbert import (
TFConvBertForMaskedLM,
TFConvBertForMultipleChoice,
TFConvBertForQuestionAnswering,
TFConvBertForSequenceClassification,
TFConvBertForTokenClassification,
TFConvBertModel,
)
from ..ctrl.modeling_tf_ctrl import TFCTRLForSequenceClassification, TFCTRLLMHeadModel, TFCTRLModel
from ..distilbert.modeling_tf_distilbert import (
TFDistilBertForMaskedLM,
TFDistilBertForMultipleChoice,
TFDistilBertForQuestionAnswering,
TFDistilBertForSequenceClassification,
TFDistilBertForTokenClassification,
TFDistilBertModel,
)
from ..dpr.modeling_tf_dpr import TFDPRQuestionEncoder
from ..electra.modeling_tf_electra import (
TFElectraForMaskedLM,
TFElectraForMultipleChoice,
TFElectraForPreTraining,
TFElectraForQuestionAnswering,
TFElectraForSequenceClassification,
TFElectraForTokenClassification,
TFElectraModel,
)
from ..flaubert.modeling_tf_flaubert import (
TFFlaubertForMultipleChoice,
TFFlaubertForQuestionAnsweringSimple,
TFFlaubertForSequenceClassification,
TFFlaubertForTokenClassification,
TFFlaubertModel,
TFFlaubertWithLMHeadModel,
)
from ..funnel.modeling_tf_funnel import (
TFFunnelBaseModel,
TFFunnelForMaskedLM,
TFFunnelForMultipleChoice,
TFFunnelForPreTraining,
TFFunnelForQuestionAnswering,
TFFunnelForSequenceClassification,
TFFunnelForTokenClassification,
TFFunnelModel,
)
from ..gpt2.modeling_tf_gpt2 import TFGPT2ForSequenceClassification, TFGPT2LMHeadModel, TFGPT2Model
from ..hubert.modeling_tf_hubert import TFHubertModel
from ..layoutlm.modeling_tf_layoutlm import (
TFLayoutLMForMaskedLM,
TFLayoutLMForSequenceClassification,
TFLayoutLMForTokenClassification,
TFLayoutLMModel,
)
from ..led.modeling_tf_led import TFLEDForConditionalGeneration, TFLEDModel
from ..longformer.modeling_tf_longformer import (
TFLongformerForMaskedLM,
TFLongformerForMultipleChoice,
TFLongformerForQuestionAnswering,
TFLongformerForSequenceClassification,
TFLongformerForTokenClassification,
TFLongformerModel,
)
from ..lxmert.modeling_tf_lxmert import TFLxmertForPreTraining, TFLxmertModel
from ..marian.modeling_tf_marian import TFMarianModel, TFMarianMTModel
from ..mbart.modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel
from ..mobilebert.modeling_tf_mobilebert import (
TFMobileBertForMaskedLM,
TFMobileBertForMultipleChoice,
TFMobileBertForNextSentencePrediction,
TFMobileBertForPreTraining,
TFMobileBertForQuestionAnswering,
TFMobileBertForSequenceClassification,
TFMobileBertForTokenClassification,
TFMobileBertModel,
)
from ..mpnet.modeling_tf_mpnet import (
TFMPNetForMaskedLM,
TFMPNetForMultipleChoice,
TFMPNetForQuestionAnswering,
TFMPNetForSequenceClassification,
TFMPNetForTokenClassification,
TFMPNetModel,
)
from ..mt5.modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model
from ..openai.modeling_tf_openai import TFOpenAIGPTForSequenceClassification, TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
from ..pegasus.modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
from ..rembert.modeling_tf_rembert import (
TFRemBertForCausalLM,
TFRemBertForMaskedLM,
TFRemBertForMultipleChoice,
TFRemBertForQuestionAnswering,
TFRemBertForSequenceClassification,
TFRemBertForTokenClassification,
TFRemBertModel,
)
from ..roberta.modeling_tf_roberta import (
TFRobertaForMaskedLM,
TFRobertaForMultipleChoice,
TFRobertaForQuestionAnswering,
TFRobertaForSequenceClassification,
TFRobertaForTokenClassification,
TFRobertaModel,
)
from ..roformer.modeling_tf_roformer import (
TFRoFormerForCausalLM,
TFRoFormerForMaskedLM,
TFRoFormerForMultipleChoice,
TFRoFormerForQuestionAnswering,
TFRoFormerForSequenceClassification,
TFRoFormerForTokenClassification,
TFRoFormerModel,
)
from ..t5.modeling_tf_t5 import TFT5ForConditionalGeneration, TFT5Model
from ..transfo_xl.modeling_tf_transfo_xl import (
TFTransfoXLForSequenceClassification,
TFTransfoXLLMHeadModel,
TFTransfoXLModel,
)
from ..wav2vec2.modeling_tf_wav2vec2 import TFWav2Vec2Model
from ..xlm.modeling_tf_xlm import (
TFXLMForMultipleChoice,
TFXLMForQuestionAnsweringSimple,
TFXLMForSequenceClassification,
TFXLMForTokenClassification,
TFXLMModel,
TFXLMWithLMHeadModel,
)
from ..xlm_roberta.modeling_tf_xlm_roberta import (
TFXLMRobertaForMaskedLM,
TFXLMRobertaForMultipleChoice,
TFXLMRobertaForQuestionAnswering,
TFXLMRobertaForSequenceClassification,
TFXLMRobertaForTokenClassification,
TFXLMRobertaModel,
)
from ..xlnet.modeling_tf_xlnet import (
TFXLNetForMultipleChoice,
TFXLNetForQuestionAnsweringSimple,
TFXLNetForSequenceClassification,
TFXLNetForTokenClassification,
TFXLNetLMHeadModel,
TFXLNetModel,
)
from .auto_factory import _BaseAutoModelClass, auto_class_update
from .configuration_auto import (
AlbertConfig,
BartConfig,
BertConfig,
BlenderbotConfig,
BlenderbotSmallConfig,
CamembertConfig,
ConvBertConfig,
CTRLConfig,
DistilBertConfig,
DPRConfig,
ElectraConfig,
FlaubertConfig,
FunnelConfig,
GPT2Config,
HubertConfig,
LayoutLMConfig,
LEDConfig,
LongformerConfig,
LxmertConfig,
MarianConfig,
MBartConfig,
MobileBertConfig,
MPNetConfig,
MT5Config,
OpenAIGPTConfig,
PegasusConfig,
RemBertConfig,
RobertaConfig,
RoFormerConfig,
T5Config,
TransfoXLConfig,
Wav2Vec2Config,
XLMConfig,
XLMRobertaConfig,
XLNetConfig,
)
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES
logger = logging.get_logger(__name__)
# TensorFlow base models, keyed by model-type string. Values are class names
# kept as plain strings (a tuple of names where one model type has several base
# classes), so defining this table needs no TF modeling import.
# NOTE(review): TF_MODEL_MAPPING itself is presumably built from this table via
# _LazyAutoMapping further down the module - confirm, that part is not visible here.
TF_MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("rembert", "TFRemBertModel"),
        ("roformer", "TFRoFormerModel"),
        ("convbert", "TFConvBertModel"),
        ("led", "TFLEDModel"),
        ("lxmert", "TFLxmertModel"),
        ("mt5", "TFMT5Model"),
        ("t5", "TFT5Model"),
        ("distilbert", "TFDistilBertModel"),
        ("albert", "TFAlbertModel"),
        ("bart", "TFBartModel"),
        ("camembert", "TFCamembertModel"),
        ("xlm-roberta", "TFXLMRobertaModel"),
        ("longformer", "TFLongformerModel"),
        ("roberta", "TFRobertaModel"),
        ("layoutlm", "TFLayoutLMModel"),
        ("bert", "TFBertModel"),
        ("openai-gpt", "TFOpenAIGPTModel"),
        ("gpt2", "TFGPT2Model"),
        ("mobilebert", "TFMobileBertModel"),
        ("transfo-xl", "TFTransfoXLModel"),
        ("xlnet", "TFXLNetModel"),
        ("flaubert", "TFFlaubertModel"),
        ("xlm", "TFXLMModel"),
        ("ctrl", "TFCTRLModel"),
        ("electra", "TFElectraModel"),
        ("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
        ("dpr", "TFDPRQuestionEncoder"),
        ("mpnet", "TFMPNetModel"),
        ("mbart", "TFMBartModel"),
        ("marian", "TFMarianModel"),
        ("pegasus", "TFPegasusModel"),
        ("blenderbot", "TFBlenderbotModel"),
        ("blenderbot-small", "TFBlenderbotSmallModel"),
        ("wav2vec2", "TFWav2Vec2Model"),
        ("hubert", "TFHubertModel"),
    ]
)
# TensorFlow pre-training heads, keyed by model-type string with class names
# as string values.
# NOTE(review): TF_MODEL_FOR_PRETRAINING_MAPPING is presumably built from this
# table via _LazyAutoMapping further down the module - confirm, not visible here.
TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        # Model for pre-training mapping
        ("lxmert", "TFLxmertForPreTraining"),
        ("t5", "TFT5ForConditionalGeneration"),
        ("distilbert", "TFDistilBertForMaskedLM"),
        ("albert", "TFAlbertForPreTraining"),
        ("bart", "TFBartForConditionalGeneration"),
        ("camembert", "TFCamembertForMaskedLM"),
        ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
        ("roberta", "TFRobertaForMaskedLM"),
        ("layoutlm", "TFLayoutLMForMaskedLM"),
        ("bert", "TFBertForPreTraining"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("mobilebert", "TFMobileBertForPreTraining"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("xlnet", "TFXLNetLMHeadModel"),
        ("flaubert", "TFFlaubertWithLMHeadModel"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("ctrl", "TFCTRLLMHeadModel"),
        ("electra", "TFElectraForPreTraining"),
        ("funnel", "TFFunnelForPreTraining"),
        ("mpnet", "TFMPNetForMaskedLM"),
    ]
)
# TensorFlow models with a language-modeling head, keyed by model-type string
# with class names as string values.
# NOTE(review): TF_MODEL_WITH_LM_HEAD_MAPPING is presumably built from this
# table via _LazyAutoMapping further down the module - confirm, not visible here.
TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    [
        # Model with LM heads mapping
        ("rembert", "TFRemBertForMaskedLM"),
        ("roformer", "TFRoFormerForMaskedLM"),
        ("convbert", "TFConvBertForMaskedLM"),
        ("led", "TFLEDForConditionalGeneration"),
        ("t5", "TFT5ForConditionalGeneration"),
        ("distilbert", "TFDistilBertForMaskedLM"),
        ("albert", "TFAlbertForMaskedLM"),
        ("marian", "TFMarianMTModel"),
        ("bart", "TFBartForConditionalGeneration"),
        ("camembert", "TFCamembertForMaskedLM"),
        ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
        ("longformer", "TFLongformerForMaskedLM"),
        ("roberta", "TFRobertaForMaskedLM"),
        ("layoutlm", "TFLayoutLMForMaskedLM"),
        ("bert", "TFBertForMaskedLM"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("mobilebert", "TFMobileBertForMaskedLM"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("xlnet", "TFXLNetLMHeadModel"),
        ("flaubert", "TFFlaubertWithLMHeadModel"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("ctrl", "TFCTRLLMHeadModel"),
        ("electra", "TFElectraForMaskedLM"),
        ("funnel", "TFFunnelForMaskedLM"),
        ("mpnet", "TFMPNetForMaskedLM"),
    ]
)
# TensorFlow causal-LM heads, keyed by model-type string with class names as
# string values.
# NOTE(review): TF_MODEL_FOR_CAUSAL_LM_MAPPING is presumably built from this
# table via _LazyAutoMapping further down the module - confirm, not visible here.
TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
        ("rembert", "TFRemBertForCausalLM"),
        ("roformer", "TFRoFormerForCausalLM"),
        ("bert", "TFBertLMHeadModel"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("xlnet", "TFXLNetLMHeadModel"),
        # XLM can be MLM and CLM => model should be split similar to BERT; leave here for now
        ("xlm", "TFXLMWithLMHeadModel"),
        ("ctrl", "TFCTRLLMHeadModel"),
    ]
)
TF_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
(RemBertConfig, TFRemBertForMaskedLM),
(RoFormerConfig, TFRoFormerForMaskedLM),
(ConvBertConfig, TFConvBertForMaskedLM),
(DistilBertConfig, TFDistilBertForMaskedLM),
(AlbertConfig, TFAlbertForMaskedLM),
(CamembertConfig, TFCamembertForMaskedLM),
(XLMRobertaConfig, TFXLMRobertaForMaskedLM),
(LongformerConfig, TFLongformerForMaskedLM),
(RobertaConfig, TFRobertaForMaskedLM),
(LayoutLMConfig, TFLayoutLMForMaskedLM),
(BertConfig, TFBertForMaskedLM),
(MobileBertConfig, TFMobileBertForMaskedLM),
(FlaubertConfig, TFFlaubertWithLMHeadModel),
(XLMConfig, TFXLMWithLMHeadModel),
(ElectraConfig, TFElectraForMaskedLM),
(FunnelConfig, TFFunnelForMaskedLM),
(MPNetConfig, TFMPNetForMaskedLM),
("rembert", "TFRemBertForMaskedLM"),
("roformer", "TFRoFormerForMaskedLM"),
("convbert", "TFConvBertForMaskedLM"),
("distilbert", "TFDistilBertForMaskedLM"),
("albert", "TFAlbertForMaskedLM"),
("camembert", "TFCamembertForMaskedLM"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
("longformer", "TFLongformerForMaskedLM"),
("roberta", "TFRobertaForMaskedLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("bert", "TFBertForMaskedLM"),
("mobilebert", "TFMobileBertForMaskedLM"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("electra", "TFElectraForMaskedLM"),
("funnel", "TFFunnelForMaskedLM"),
("mpnet", "TFMPNetForMaskedLM"),
]
)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
(LEDConfig, TFLEDForConditionalGeneration),
(MT5Config, TFMT5ForConditionalGeneration),
(T5Config, TFT5ForConditionalGeneration),
(MarianConfig, TFMarianMTModel),
(MBartConfig, TFMBartForConditionalGeneration),
(PegasusConfig, TFPegasusForConditionalGeneration),
(BlenderbotConfig, TFBlenderbotForConditionalGeneration),
(BlenderbotSmallConfig, TFBlenderbotSmallForConditionalGeneration),
(BartConfig, TFBartForConditionalGeneration),
("led", "TFLEDForConditionalGeneration"),
("mt5", "TFMT5ForConditionalGeneration"),
("t5", "TFT5ForConditionalGeneration"),
("marian", "TFMarianMTModel"),
("mbart", "TFMBartForConditionalGeneration"),
("pegasus", "TFPegasusForConditionalGeneration"),
("blenderbot", "TFBlenderbotForConditionalGeneration"),
("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"),
("bart", "TFBartForConditionalGeneration"),
]
)
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
(RemBertConfig, TFRemBertForSequenceClassification),
(RoFormerConfig, TFRoFormerForSequenceClassification),
(ConvBertConfig, TFConvBertForSequenceClassification),
(DistilBertConfig, TFDistilBertForSequenceClassification),
(AlbertConfig, TFAlbertForSequenceClassification),
(CamembertConfig, TFCamembertForSequenceClassification),
(XLMRobertaConfig, TFXLMRobertaForSequenceClassification),
(LongformerConfig, TFLongformerForSequenceClassification),
(RobertaConfig, TFRobertaForSequenceClassification),
(LayoutLMConfig, TFLayoutLMForSequenceClassification),
(BertConfig, TFBertForSequenceClassification),
(XLNetConfig, TFXLNetForSequenceClassification),
(MobileBertConfig, TFMobileBertForSequenceClassification),
(FlaubertConfig, TFFlaubertForSequenceClassification),
(XLMConfig, TFXLMForSequenceClassification),
(ElectraConfig, TFElectraForSequenceClassification),
(FunnelConfig, TFFunnelForSequenceClassification),
(GPT2Config, TFGPT2ForSequenceClassification),
(MPNetConfig, TFMPNetForSequenceClassification),
(OpenAIGPTConfig, TFOpenAIGPTForSequenceClassification),
(TransfoXLConfig, TFTransfoXLForSequenceClassification),
(CTRLConfig, TFCTRLForSequenceClassification),
("rembert", "TFRemBertForSequenceClassification"),
("roformer", "TFRoFormerForSequenceClassification"),
("convbert", "TFConvBertForSequenceClassification"),
("distilbert", "TFDistilBertForSequenceClassification"),
("albert", "TFAlbertForSequenceClassification"),
("camembert", "TFCamembertForSequenceClassification"),
("xlm-roberta", "TFXLMRobertaForSequenceClassification"),
("longformer", "TFLongformerForSequenceClassification"),
("roberta", "TFRobertaForSequenceClassification"),
("layoutlm", "TFLayoutLMForSequenceClassification"),
("bert", "TFBertForSequenceClassification"),
("xlnet", "TFXLNetForSequenceClassification"),
("mobilebert", "TFMobileBertForSequenceClassification"),
("flaubert", "TFFlaubertForSequenceClassification"),
("xlm", "TFXLMForSequenceClassification"),
("electra", "TFElectraForSequenceClassification"),
("funnel", "TFFunnelForSequenceClassification"),
("gpt2", "TFGPT2ForSequenceClassification"),
("mpnet", "TFMPNetForSequenceClassification"),
("openai-gpt", "TFOpenAIGPTForSequenceClassification"),
("transfo-xl", "TFTransfoXLForSequenceClassification"),
("ctrl", "TFCTRLForSequenceClassification"),
]
)
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
(RemBertConfig, TFRemBertForQuestionAnswering),
(RoFormerConfig, TFRoFormerForQuestionAnswering),
(ConvBertConfig, TFConvBertForQuestionAnswering),
(DistilBertConfig, TFDistilBertForQuestionAnswering),
(AlbertConfig, TFAlbertForQuestionAnswering),
(CamembertConfig, TFCamembertForQuestionAnswering),
(XLMRobertaConfig, TFXLMRobertaForQuestionAnswering),
(LongformerConfig, TFLongformerForQuestionAnswering),
(RobertaConfig, TFRobertaForQuestionAnswering),
(BertConfig, TFBertForQuestionAnswering),
(XLNetConfig, TFXLNetForQuestionAnsweringSimple),
(MobileBertConfig, TFMobileBertForQuestionAnswering),
(FlaubertConfig, TFFlaubertForQuestionAnsweringSimple),
(XLMConfig, TFXLMForQuestionAnsweringSimple),
(ElectraConfig, TFElectraForQuestionAnswering),
(FunnelConfig, TFFunnelForQuestionAnswering),
(MPNetConfig, TFMPNetForQuestionAnswering),
("rembert", "TFRemBertForQuestionAnswering"),
("roformer", "TFRoFormerForQuestionAnswering"),
("convbert", "TFConvBertForQuestionAnswering"),
("distilbert", "TFDistilBertForQuestionAnswering"),
("albert", "TFAlbertForQuestionAnswering"),
("camembert", "TFCamembertForQuestionAnswering"),
("xlm-roberta", "TFXLMRobertaForQuestionAnswering"),
("longformer", "TFLongformerForQuestionAnswering"),
("roberta", "TFRobertaForQuestionAnswering"),
("bert", "TFBertForQuestionAnswering"),
("xlnet", "TFXLNetForQuestionAnsweringSimple"),
("mobilebert", "TFMobileBertForQuestionAnswering"),
("flaubert", "TFFlaubertForQuestionAnsweringSimple"),
("xlm", "TFXLMForQuestionAnsweringSimple"),
("electra", "TFElectraForQuestionAnswering"),
("funnel", "TFFunnelForQuestionAnswering"),
("mpnet", "TFMPNetForQuestionAnswering"),
]
)
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
(RemBertConfig, TFRemBertForTokenClassification),
(RoFormerConfig, TFRoFormerForTokenClassification),
(ConvBertConfig, TFConvBertForTokenClassification),
(DistilBertConfig, TFDistilBertForTokenClassification),
(AlbertConfig, TFAlbertForTokenClassification),
(CamembertConfig, TFCamembertForTokenClassification),
(FlaubertConfig, TFFlaubertForTokenClassification),
(XLMConfig, TFXLMForTokenClassification),
(XLMRobertaConfig, TFXLMRobertaForTokenClassification),
(LongformerConfig, TFLongformerForTokenClassification),
(RobertaConfig, TFRobertaForTokenClassification),
(LayoutLMConfig, TFLayoutLMForTokenClassification),
(BertConfig, TFBertForTokenClassification),
(MobileBertConfig, TFMobileBertForTokenClassification),
(XLNetConfig, TFXLNetForTokenClassification),
(ElectraConfig, TFElectraForTokenClassification),
(FunnelConfig, TFFunnelForTokenClassification),
(MPNetConfig, TFMPNetForTokenClassification),
("rembert", "TFRemBertForTokenClassification"),
("roformer", "TFRoFormerForTokenClassification"),
("convbert", "TFConvBertForTokenClassification"),
("distilbert", "TFDistilBertForTokenClassification"),
("albert", "TFAlbertForTokenClassification"),
("camembert", "TFCamembertForTokenClassification"),
("flaubert", "TFFlaubertForTokenClassification"),
("xlm", "TFXLMForTokenClassification"),
("xlm-roberta", "TFXLMRobertaForTokenClassification"),
("longformer", "TFLongformerForTokenClassification"),
("roberta", "TFRobertaForTokenClassification"),
("layoutlm", "TFLayoutLMForTokenClassification"),
("bert", "TFBertForTokenClassification"),
("mobilebert", "TFMobileBertForTokenClassification"),
("xlnet", "TFXLNetForTokenClassification"),
("electra", "TFElectraForTokenClassification"),
("funnel", "TFFunnelForTokenClassification"),
("mpnet", "TFMPNetForTokenClassification"),
]
)
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
# Model for Multiple Choice mapping
(RemBertConfig, TFRemBertForMultipleChoice),
(RoFormerConfig, TFRoFormerForMultipleChoice),
(ConvBertConfig, TFConvBertForMultipleChoice),
(CamembertConfig, TFCamembertForMultipleChoice),
(XLMConfig, TFXLMForMultipleChoice),
(XLMRobertaConfig, TFXLMRobertaForMultipleChoice),
(LongformerConfig, TFLongformerForMultipleChoice),
(RobertaConfig, TFRobertaForMultipleChoice),
(BertConfig, TFBertForMultipleChoice),
(DistilBertConfig, TFDistilBertForMultipleChoice),
(MobileBertConfig, TFMobileBertForMultipleChoice),
(XLNetConfig, TFXLNetForMultipleChoice),
(FlaubertConfig, TFFlaubertForMultipleChoice),
(AlbertConfig, TFAlbertForMultipleChoice),
(ElectraConfig, TFElectraForMultipleChoice),
(FunnelConfig, TFFunnelForMultipleChoice),
(MPNetConfig, TFMPNetForMultipleChoice),
("rembert", "TFRemBertForMultipleChoice"),
("roformer", "TFRoFormerForMultipleChoice"),
("convbert", "TFConvBertForMultipleChoice"),
("camembert", "TFCamembertForMultipleChoice"),
("xlm", "TFXLMForMultipleChoice"),
("xlm-roberta", "TFXLMRobertaForMultipleChoice"),
("longformer", "TFLongformerForMultipleChoice"),
("roberta", "TFRobertaForMultipleChoice"),
("bert", "TFBertForMultipleChoice"),
("distilbert", "TFDistilBertForMultipleChoice"),
("mobilebert", "TFMobileBertForMultipleChoice"),
("xlnet", "TFXLNetForMultipleChoice"),
("flaubert", "TFFlaubertForMultipleChoice"),
("albert", "TFAlbertForMultipleChoice"),
("electra", "TFElectraForMultipleChoice"),
("funnel", "TFFunnelForMultipleChoice"),
("mpnet", "TFMPNetForMultipleChoice"),
]
)
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
[
(BertConfig, TFBertForNextSentencePrediction),
(MobileBertConfig, TFMobileBertForNextSentencePrediction),
("bert", "TFBertForNextSentencePrediction"),
("mobilebert", "TFMobileBertForNextSentencePrediction"),
]
)
# Public TF auto-model mappings. Each one pairs CONFIG_MAPPING_NAMES with the matching ``*_MAPPING_NAMES`` table
# (plain string entries such as ("bert", "TFBertForMaskedLM")) through _LazyAutoMapping.
# NOTE(review): _LazyAutoMapping is defined outside this view; presumably it resolves the config/model classes only
# when an entry is looked up -- confirm against auto_factory.
TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES)
TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
)
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
class TFAutoModel(_BaseAutoModelClass):
_model_mapping = TF_MODEL_MAPPING
......
......@@ -14,12 +14,12 @@
# limitations under the License.
""" Auto Tokenizer class. """
import importlib
import json
import os
from collections import OrderedDict
from typing import Dict, Optional, Union
from ... import GPTNeoConfig
from ...configuration_utils import PretrainedConfig
from ...file_utils import (
cached_path,
......@@ -29,315 +29,183 @@ from ...file_utils import (
is_tokenizers_available,
)
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging
from ..bart.tokenization_bart import BartTokenizer
from ..bert.tokenization_bert import BertTokenizer
from ..bert_japanese.tokenization_bert_japanese import BertJapaneseTokenizer
from ..bertweet.tokenization_bertweet import BertweetTokenizer
from ..blenderbot.tokenization_blenderbot import BlenderbotTokenizer
from ..blenderbot_small.tokenization_blenderbot_small import BlenderbotSmallTokenizer
from ..byt5.tokenization_byt5 import ByT5Tokenizer
from ..canine.tokenization_canine import CanineTokenizer
from ..convbert.tokenization_convbert import ConvBertTokenizer
from ..ctrl.tokenization_ctrl import CTRLTokenizer
from ..deberta.tokenization_deberta import DebertaTokenizer
from ..distilbert.tokenization_distilbert import DistilBertTokenizer
from ..dpr.tokenization_dpr import DPRQuestionEncoderTokenizer
from ..electra.tokenization_electra import ElectraTokenizer
from ..flaubert.tokenization_flaubert import FlaubertTokenizer
from ..fsmt.tokenization_fsmt import FSMTTokenizer
from ..funnel.tokenization_funnel import FunnelTokenizer
from ..gpt2.tokenization_gpt2 import GPT2Tokenizer
from ..herbert.tokenization_herbert import HerbertTokenizer
from ..layoutlm.tokenization_layoutlm import LayoutLMTokenizer
from ..led.tokenization_led import LEDTokenizer
from ..longformer.tokenization_longformer import LongformerTokenizer
from ..luke.tokenization_luke import LukeTokenizer
from ..lxmert.tokenization_lxmert import LxmertTokenizer
from ..mobilebert.tokenization_mobilebert import MobileBertTokenizer
from ..mpnet.tokenization_mpnet import MPNetTokenizer
from ..openai.tokenization_openai import OpenAIGPTTokenizer
from ..phobert.tokenization_phobert import PhobertTokenizer
from ..prophetnet.tokenization_prophetnet import ProphetNetTokenizer
from ..rag.tokenization_rag import RagTokenizer
from ..retribert.tokenization_retribert import RetriBertTokenizer
from ..roberta.tokenization_roberta import RobertaTokenizer
from ..roformer.tokenization_roformer import RoFormerTokenizer
from ..squeezebert.tokenization_squeezebert import SqueezeBertTokenizer
from ..tapas.tokenization_tapas import TapasTokenizer
from ..transfo_xl.tokenization_transfo_xl import TransfoXLTokenizer
from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
from ..xlm.tokenization_xlm import XLMTokenizer
from ..encoder_decoder import EncoderDecoderConfig
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
AlbertConfig,
CONFIG_MAPPING_NAMES,
AutoConfig,
BartConfig,
BertConfig,
BertGenerationConfig,
BigBirdConfig,
BigBirdPegasusConfig,
BlenderbotConfig,
BlenderbotSmallConfig,
CamembertConfig,
CanineConfig,
ConvBertConfig,
CTRLConfig,
DebertaConfig,
DebertaV2Config,
DistilBertConfig,
DPRConfig,
ElectraConfig,
EncoderDecoderConfig,
FlaubertConfig,
FSMTConfig,
FunnelConfig,
GPT2Config,
HubertConfig,
IBertConfig,
LayoutLMConfig,
LEDConfig,
LongformerConfig,
LukeConfig,
LxmertConfig,
M2M100Config,
MarianConfig,
MBartConfig,
MobileBertConfig,
MPNetConfig,
MT5Config,
OpenAIGPTConfig,
PegasusConfig,
ProphetNetConfig,
RagConfig,
ReformerConfig,
RetriBertConfig,
RobertaConfig,
RoFormerConfig,
Speech2TextConfig,
SqueezeBertConfig,
T5Config,
TapasConfig,
TransfoXLConfig,
Wav2Vec2Config,
XLMConfig,
XLMProphetNetConfig,
XLMRobertaConfig,
XLNetConfig,
config_class_to_model_type,
replace_list_option_in_docstrings,
)
if is_sentencepiece_available():
from ..albert.tokenization_albert import AlbertTokenizer
from ..barthez.tokenization_barthez import BarthezTokenizer
from ..bert_generation.tokenization_bert_generation import BertGenerationTokenizer
from ..big_bird.tokenization_big_bird import BigBirdTokenizer
from ..camembert.tokenization_camembert import CamembertTokenizer
from ..cpm.tokenization_cpm import CpmTokenizer
from ..deberta_v2.tokenization_deberta_v2 import DebertaV2Tokenizer
from ..m2m_100 import M2M100Tokenizer
from ..marian.tokenization_marian import MarianTokenizer
from ..mbart.tokenization_mbart import MBartTokenizer
from ..mbart.tokenization_mbart50 import MBart50Tokenizer
from ..mt5 import MT5Tokenizer
from ..pegasus.tokenization_pegasus import PegasusTokenizer
from ..reformer.tokenization_reformer import ReformerTokenizer
from ..speech_to_text import Speech2TextTokenizer
from ..t5.tokenization_t5 import T5Tokenizer
from ..xlm_prophetnet.tokenization_xlm_prophetnet import XLMProphetNetTokenizer
from ..xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer
from ..xlnet.tokenization_xlnet import XLNetTokenizer
else:
AlbertTokenizer = None
BarthezTokenizer = None
BertGenerationTokenizer = None
BigBirdTokenizer = None
CamembertTokenizer = None
CpmTokenizer = None
DebertaV2Tokenizer = None
MarianTokenizer = None
MBartTokenizer = None
MBart50Tokenizer = None
MT5Tokenizer = None
PegasusTokenizer = None
ReformerTokenizer = None
T5Tokenizer = None
XLMRobertaTokenizer = None
XLNetTokenizer = None
XLMProphetNetTokenizer = None
M2M100Tokenizer = None
Speech2TextTokenizer = None
if is_tokenizers_available():
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ..albert.tokenization_albert_fast import AlbertTokenizerFast
from ..bart.tokenization_bart_fast import BartTokenizerFast
from ..barthez.tokenization_barthez_fast import BarthezTokenizerFast
from ..bert.tokenization_bert_fast import BertTokenizerFast
from ..big_bird.tokenization_big_bird_fast import BigBirdTokenizerFast
from ..camembert.tokenization_camembert_fast import CamembertTokenizerFast
from ..convbert.tokenization_convbert_fast import ConvBertTokenizerFast
from ..cpm.tokenization_cpm_fast import CpmTokenizerFast
from ..deberta.tokenization_deberta_fast import DebertaTokenizerFast
from ..distilbert.tokenization_distilbert_fast import DistilBertTokenizerFast
from ..dpr.tokenization_dpr_fast import DPRQuestionEncoderTokenizerFast
from ..electra.tokenization_electra_fast import ElectraTokenizerFast
from ..funnel.tokenization_funnel_fast import FunnelTokenizerFast
from ..gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
from ..herbert.tokenization_herbert_fast import HerbertTokenizerFast
from ..layoutlm.tokenization_layoutlm_fast import LayoutLMTokenizerFast
from ..led.tokenization_led_fast import LEDTokenizerFast
from ..longformer.tokenization_longformer_fast import LongformerTokenizerFast
from ..lxmert.tokenization_lxmert_fast import LxmertTokenizerFast
from ..mbart.tokenization_mbart50_fast import MBart50TokenizerFast
from ..mbart.tokenization_mbart_fast import MBartTokenizerFast
from ..mobilebert.tokenization_mobilebert_fast import MobileBertTokenizerFast
from ..mpnet.tokenization_mpnet_fast import MPNetTokenizerFast
from ..mt5 import MT5TokenizerFast
from ..openai.tokenization_openai_fast import OpenAIGPTTokenizerFast
from ..pegasus.tokenization_pegasus_fast import PegasusTokenizerFast
from ..reformer.tokenization_reformer_fast import ReformerTokenizerFast
from ..retribert.tokenization_retribert_fast import RetriBertTokenizerFast
from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast
from ..roformer.tokenization_roformer_fast import RoFormerTokenizerFast
from ..squeezebert.tokenization_squeezebert_fast import SqueezeBertTokenizerFast
from ..t5.tokenization_t5_fast import T5TokenizerFast
from ..xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
from ..xlnet.tokenization_xlnet_fast import XLNetTokenizerFast
else:
AlbertTokenizerFast = None
BartTokenizerFast = None
BarthezTokenizerFast = None
BertTokenizerFast = None
BigBirdTokenizerFast = None
CamembertTokenizerFast = None
ConvBertTokenizerFast = None
CpmTokenizerFast = None
DebertaTokenizerFast = None
DistilBertTokenizerFast = None
DPRQuestionEncoderTokenizerFast = None
ElectraTokenizerFast = None
FunnelTokenizerFast = None
GPT2TokenizerFast = None
HerbertTokenizerFast = None
LayoutLMTokenizerFast = None
LEDTokenizerFast = None
LongformerTokenizerFast = None
LxmertTokenizerFast = None
MBartTokenizerFast = None
MBart50TokenizerFast = None
MobileBertTokenizerFast = None
MPNetTokenizerFast = None
MT5TokenizerFast = None
OpenAIGPTTokenizerFast = None
PegasusTokenizerFast = None
ReformerTokenizerFast = None
RetriBertTokenizerFast = None
RobertaTokenizerFast = None
RoFormerTokenizerFast = None
SqueezeBertTokenizerFast = None
T5TokenizerFast = None
XLMRobertaTokenizerFast = None
XLNetTokenizerFast = None
PreTrainedTokenizerFast = None
logger = logging.get_logger(__name__)
TOKENIZER_MAPPING = OrderedDict(
TOKENIZER_MAPPING_NAMES = OrderedDict(
[
(RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)),
(RoFormerConfig, (RoFormerTokenizer, RoFormerTokenizerFast)),
(T5Config, (T5Tokenizer, T5TokenizerFast)),
(MT5Config, (MT5Tokenizer, MT5TokenizerFast)),
(MobileBertConfig, (MobileBertTokenizer, MobileBertTokenizerFast)),
(DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
(AlbertConfig, (AlbertTokenizer, AlbertTokenizerFast)),
(CamembertConfig, (CamembertTokenizer, CamembertTokenizerFast)),
(PegasusConfig, (PegasusTokenizer, PegasusTokenizerFast)),
(MBartConfig, (MBartTokenizer, MBartTokenizerFast)),
(XLMRobertaConfig, (XLMRobertaTokenizer, XLMRobertaTokenizerFast)),
(MarianConfig, (MarianTokenizer, None)),
(BlenderbotSmallConfig, (BlenderbotSmallTokenizer, None)),
(BlenderbotConfig, (BlenderbotTokenizer, None)),
(BartConfig, (BartTokenizer, BartTokenizerFast)),
(LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
(RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)),
(ReformerConfig, (ReformerTokenizer, ReformerTokenizerFast)),
(ElectraConfig, (ElectraTokenizer, ElectraTokenizerFast)),
(FunnelConfig, (FunnelTokenizer, FunnelTokenizerFast)),
(LxmertConfig, (LxmertTokenizer, LxmertTokenizerFast)),
(LayoutLMConfig, (LayoutLMTokenizer, LayoutLMTokenizerFast)),
(DPRConfig, (DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast)),
(SqueezeBertConfig, (SqueezeBertTokenizer, SqueezeBertTokenizerFast)),
(BertConfig, (BertTokenizer, BertTokenizerFast)),
(OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)),
(GPT2Config, (GPT2Tokenizer, GPT2TokenizerFast)),
(TransfoXLConfig, (TransfoXLTokenizer, None)),
(XLNetConfig, (XLNetTokenizer, XLNetTokenizerFast)),
(FlaubertConfig, (FlaubertTokenizer, None)),
(XLMConfig, (XLMTokenizer, None)),
(CTRLConfig, (CTRLTokenizer, None)),
(FSMTConfig, (FSMTTokenizer, None)),
(BertGenerationConfig, (BertGenerationTokenizer, None)),
(DebertaConfig, (DebertaTokenizer, DebertaTokenizerFast)),
(DebertaV2Config, (DebertaV2Tokenizer, None)),
(RagConfig, (RagTokenizer, None)),
(XLMProphetNetConfig, (XLMProphetNetTokenizer, None)),
(Speech2TextConfig, (Speech2TextTokenizer, None)),
(M2M100Config, (M2M100Tokenizer, None)),
(ProphetNetConfig, (ProphetNetTokenizer, None)),
(MPNetConfig, (MPNetTokenizer, MPNetTokenizerFast)),
(TapasConfig, (TapasTokenizer, None)),
(LEDConfig, (LEDTokenizer, LEDTokenizerFast)),
(ConvBertConfig, (ConvBertTokenizer, ConvBertTokenizerFast)),
(BigBirdConfig, (BigBirdTokenizer, BigBirdTokenizerFast)),
(IBertConfig, (RobertaTokenizer, RobertaTokenizerFast)),
(Wav2Vec2Config, (Wav2Vec2CTCTokenizer, None)),
(HubertConfig, (Wav2Vec2CTCTokenizer, None)),
(GPTNeoConfig, (GPT2Tokenizer, GPT2TokenizerFast)),
(LukeConfig, (LukeTokenizer, None)),
(BigBirdPegasusConfig, (PegasusTokenizer, PegasusTokenizerFast)),
(CanineConfig, (CanineTokenizer, None)),
("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
(
"t5",
(
"T5Tokenizer" if is_sentencepiece_available() else None,
"T5TokenizerFast" if is_tokenizers_available() else None,
),
),
(
"mt5",
(
"MT5Tokenizer" if is_sentencepiece_available() else None,
"MT5TokenizerFast" if is_tokenizers_available() else None,
),
),
("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
(
"albert",
(
"AlbertTokenizer" if is_sentencepiece_available() else None,
"AlbertTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"camembert",
(
"CamembertTokenizer" if is_sentencepiece_available() else None,
"CamembertTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"pegasus",
(
"PegasusTokenizer" if is_sentencepiece_available() else None,
"PegasusTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"mbart",
(
"MBartTokenizer" if is_sentencepiece_available() else None,
"MBartTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"xlm-roberta",
(
"XLMRobertaTokenizer" if is_sentencepiece_available() else None,
"XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
),
),
("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
("blenderbot-small", ("BlenderbotSmallTokenizer", None)),
("blenderbot", ("BlenderbotTokenizer", None)),
# The fast class only exists when the `tokenizers` backend is installed, so guard it like every other fast entry.
("bart", ("BartTokenizer", "BartTokenizerFast" if is_tokenizers_available() else None)),
("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
(
"reformer",
(
"ReformerTokenizer" if is_sentencepiece_available() else None,
"ReformerTokenizerFast" if is_tokenizers_available() else None,
),
),
("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
(
"dpr",
("DPRQuestionEncoderTokenizer", "DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None),
),
("squeezebert", ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None)),
("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)),
("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("transfo-xl", ("TransfoXLTokenizer", None)),
(
"xlnet",
(
"XLNetTokenizer" if is_sentencepiece_available() else None,
"XLNetTokenizerFast" if is_tokenizers_available() else None,
),
),
("flaubert", ("FlaubertTokenizer", None)),
("xlm", ("XLMTokenizer", None)),
("ctrl", ("CTRLTokenizer", None)),
("fsmt", ("FSMTTokenizer", None)),
("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)),
("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
("deberta-v2", ("DebertaV2Tokenizer" if is_sentencepiece_available() else None, None)),
("rag", ("RagTokenizer", None)),
("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
("prophetnet", ("ProphetNetTokenizer", None)),
("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
("tapas", ("TapasTokenizer", None)),
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
(
"big_bird",
(
"BigBirdTokenizer" if is_sentencepiece_available() else None,
"BigBirdTokenizerFast" if is_tokenizers_available() else None,
),
),
("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
("hubert", ("Wav2Vec2CTCTokenizer", None)),
("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("luke", ("LukeTokenizer", None)),
(
    # Reuses the Pegasus tokenizers: the slow one is sentencepiece-based, so it gets the same guard as in the
    # "pegasus" entry.
    "bigbird_pegasus",
    (
        "PegasusTokenizer" if is_sentencepiece_available() else None,
        "PegasusTokenizerFast" if is_tokenizers_available() else None,
    ),
),
("canine", ("CanineTokenizer", None)),
("bertweet", ("BertweetTokenizer", None)),
("bert-japanese", ("BertJapaneseTokenizer", None)),
("byt5", ("ByT5Tokenizer", None)),
(
"cpm",
(
"CpmTokenizer" if is_sentencepiece_available() else None,
"CpmTokenizerFast" if is_tokenizers_available() else None,
),
),
("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
("phobert", ("PhobertTokenizer", None)),
(
"barthez",
(
"BarthezTokenizer" if is_sentencepiece_available() else None,
"BarthezTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"mbart50",
(
"MBart50Tokenizer" if is_sentencepiece_available() else None,
"MBart50TokenizerFast" if is_tokenizers_available() else None,
),
),
]
)
# For tokenizers which are not directly mapped from a config
NO_CONFIG_TOKENIZER = [
BertJapaneseTokenizer,
BertweetTokenizer,
ByT5Tokenizer,
CpmTokenizer,
CpmTokenizerFast,
HerbertTokenizer,
HerbertTokenizerFast,
PhobertTokenizer,
BarthezTokenizer,
BarthezTokenizerFast,
MBart50Tokenizer,
MBart50TokenizerFast,
PreTrainedTokenizerFast,
]
# For every key of TOKENIZER_MAPPING, keep a single tokenizer out of its (slow, fast) pair: the slow one when it is
# not None, otherwise the fast one. Keys whose pair is entirely None are left out.
SLOW_TOKENIZER_MAPPING = {
k: (v[0] if v[0] is not None else v[1])
for k, v in TOKENIZER_MAPPING.items()
if (v[0] is not None or v[1] is not None)
}
# Rebinds TOKENIZER_MAPPING to a _LazyAutoMapping built from the two name tables; the comprehension above therefore
# has to run before this line to see the previous binding.
TOKENIZER_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)
# Inverse of CONFIG_MAPPING_NAMES (its values become the keys).
# NOTE(review): presumably config class name -> model type -- confirm against configuration_auto.
CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()}
def tokenizer_class_from_name(class_name: str):
    """
    Resolve a tokenizer class from its class name.

    The name is searched in the ``(slow, fast)`` pairs of ``TOKENIZER_MAPPING_NAMES``; the model package that
    exports the class is imported only once a match is found, so no tokenizer module is loaded for names that are
    never requested.

    Args:
        class_name (:obj:`str`): Name of the tokenizer class, e.g. ``"BertTokenizerFast"``.

    Returns:
        The tokenizer class, or :obj:`None` when no entry of ``TOKENIZER_MAPPING_NAMES`` lists ``class_name``.
    """
    # The generic fast tokenizer is not tied to any model type, so it has no entry in the name table.
    if class_name == "PreTrainedTokenizerFast":
        return PreTrainedTokenizerFast

    for model_type, tokenizers in TOKENIZER_MAPPING_NAMES.items():
        if class_name in tokenizers:
            # Model types may contain "-" (e.g. "xlm-roberta") while the packages under `transformers.models` use
            # "_" (`xlm_roberta`). "openai-gpt" is the one irregular case: its package is simply `openai`.
            module_name = "openai" if model_type == "openai-gpt" else model_type.replace("-", "_")
            module = importlib.import_module(f".{module_name}", "transformers.models")
            return getattr(module, class_name)

    # Unknown class name: report "not found" instead of importing an unrelated module and failing on getattr.
    return None
def get_tokenizer_config(
......@@ -454,7 +322,7 @@ class AutoTokenizer:
)
@classmethod
@replace_list_option_in_docstrings(SLOW_TOKENIZER_MAPPING)
@replace_list_option_in_docstrings(TOKENIZER_MAPPING_NAMES)
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
r"""
Instantiate one of the tokenizer classes of the library from a pretrained model vocabulary.
......@@ -565,7 +433,8 @@ class AutoTokenizer:
)
config = config.encoder
if type(config) in TOKENIZER_MAPPING.keys():
model_type = config_class_to_model_type(type(config).__name__)
if model_type is not None:
tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)]
if tokenizer_class_fast and (use_fast or tokenizer_class_py is None):
return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
......
......@@ -33,10 +33,8 @@ _import_structure = {
if is_sentencepiece_available():
_import_structure["tokenization_mbart"] = ["MBartTokenizer"]
_import_structure["tokenization_mbart50"] = ["MBart50Tokenizer"]
if is_tokenizers_available():
_import_structure["tokenization_mbart50_fast"] = ["MBart50TokenizerFast"]
_import_structure["tokenization_mbart_fast"] = ["MBartTokenizerFast"]
if is_torch_available():
......@@ -72,10 +70,8 @@ if TYPE_CHECKING:
if is_sentencepiece_available():
from .tokenization_mbart import MBartTokenizer
from .tokenization_mbart50 import MBart50Tokenizer
if is_tokenizers_available():
from .tokenization_mbart50_fast import MBart50TokenizerFast
from .tokenization_mbart_fast import MBartTokenizerFast
if is_torch_available():
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_sentencepiece_available, is_tokenizers_available
# Table of what this package exposes: submodule name -> list of public names defined in it.
# An entry is only registered when the optional backend it depends on is installed.
_import_structure = {}
if is_sentencepiece_available():
_import_structure["tokenization_mbart50"] = ["MBart50Tokenizer"]
if is_tokenizers_available():
_import_structure["tokenization_mbart50_fast"] = ["MBart50TokenizerFast"]
# typing.TYPE_CHECKING is False at runtime, so the real imports below are only seen by static
# type checkers; they mirror the entries registered in _import_structure above.
if TYPE_CHECKING:
if is_sentencepiece_available():
from .tokenization_mbart50 import MBart50Tokenizer
if is_tokenizers_available():
from .tokenization_mbart50_fast import MBart50TokenizerFast
else:
import sys
# At runtime, replace this module's entry in sys.modules with a _LazyModule built from the
# module name, its file path and the table above.
# NOTE(review): presumably _LazyModule imports a submodule on first attribute access --
# confirm in file_utils.
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
......@@ -1781,25 +1781,22 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
if config_tokenizer_class is None:
# Third attempt. If we have not yet found the original type of the tokenizer we are loading,
# we see if we can infer it from the type of the configuration file
from .models.auto.configuration_auto import CONFIG_MAPPING # tests_ignore
from .models.auto.tokenization_auto import TOKENIZER_MAPPING # tests_ignore
from .models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES # tests_ignore
if hasattr(config, "model_type"):
config_class = CONFIG_MAPPING.get(config.model_type)
model_type = config.model_type
else:
# Fallback: use pattern matching on the string.
config_class = None
for pattern, config_class_tmp in CONFIG_MAPPING.items():
model_type = None
for pattern in TOKENIZER_MAPPING_NAMES.keys():
if pattern in str(pretrained_model_name_or_path):
config_class = config_class_tmp
model_type = pattern
break
if config_class in TOKENIZER_MAPPING.keys():
config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING[config_class]
if config_tokenizer_class is not None:
config_tokenizer_class = config_tokenizer_class.__name__
else:
config_tokenizer_class = config_tokenizer_class_fast.__name__
if model_type is not None:
config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING_NAMES[model_type]
if config_tokenizer_class is None:
config_tokenizer_class = config_tokenizer_class_fast
if config_tokenizer_class is not None:
if cls.__name__.replace("Fast", "") != config_tokenizer_class.replace("Fast", ""):
......
......@@ -74,6 +74,7 @@ from .file_utils import (
)
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, unwrap_model
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
from .optimization import Adafactor, AdamW, get_scheduler
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
......@@ -125,7 +126,6 @@ from .trainer_utils import (
)
from .training_args import ParallelMode, TrainingArguments
from .utils import logging
from .utils.modeling_auto_mapping import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
_is_torch_generator_available = False
......
......@@ -191,7 +191,7 @@ class LxmertTokenizerFast:
requires_backends(cls, ["tokenizers"])
class MBart50TokenizerFast:
class MBartTokenizerFast:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])
......@@ -200,7 +200,7 @@ class MBart50TokenizerFast:
requires_backends(cls, ["tokenizers"])
class MBartTokenizerFast:
class MBart50TokenizerFast:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])
......
# THIS FILE HAS BEEN AUTOGENERATED. To update:
# 1. modify: models/auto/modeling_auto.py
# 2. run: python utils/class_mapping_update.py
from collections import OrderedDict
# Config class name -> name of the question-answering model class for that config.
# Both sides are plain strings; OrderedDict keeps the entries in the order listed here.
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
("RemBertConfig", "RemBertForQuestionAnswering"),
("CanineConfig", "CanineForQuestionAnswering"),
("RoFormerConfig", "RoFormerForQuestionAnswering"),
("BigBirdPegasusConfig", "BigBirdPegasusForQuestionAnswering"),
("BigBirdConfig", "BigBirdForQuestionAnswering"),
("ConvBertConfig", "ConvBertForQuestionAnswering"),
("LEDConfig", "LEDForQuestionAnswering"),
("DistilBertConfig", "DistilBertForQuestionAnswering"),
("AlbertConfig", "AlbertForQuestionAnswering"),
("CamembertConfig", "CamembertForQuestionAnswering"),
("BartConfig", "BartForQuestionAnswering"),
("MBartConfig", "MBartForQuestionAnswering"),
("LongformerConfig", "LongformerForQuestionAnswering"),
("XLMRobertaConfig", "XLMRobertaForQuestionAnswering"),
("RobertaConfig", "RobertaForQuestionAnswering"),
("SqueezeBertConfig", "SqueezeBertForQuestionAnswering"),
("BertConfig", "BertForQuestionAnswering"),
("XLNetConfig", "XLNetForQuestionAnsweringSimple"),
("FlaubertConfig", "FlaubertForQuestionAnsweringSimple"),
("MegatronBertConfig", "MegatronBertForQuestionAnswering"),
("MobileBertConfig", "MobileBertForQuestionAnswering"),
("XLMConfig", "XLMForQuestionAnsweringSimple"),
("ElectraConfig", "ElectraForQuestionAnswering"),
("ReformerConfig", "ReformerForQuestionAnswering"),
("FunnelConfig", "FunnelForQuestionAnswering"),
("LxmertConfig", "LxmertForQuestionAnswering"),
("MPNetConfig", "MPNetForQuestionAnswering"),
("DebertaConfig", "DebertaForQuestionAnswering"),
("DebertaV2Config", "DebertaV2ForQuestionAnswering"),
("IBertConfig", "IBertForQuestionAnswering"),
]
)
# Config class name -> name of the causal language-modeling model class for that config.
# Both sides are plain strings; OrderedDict keeps the entries in the order listed here.
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
("RemBertConfig", "RemBertForCausalLM"),
("RoFormerConfig", "RoFormerForCausalLM"),
("BigBirdPegasusConfig", "BigBirdPegasusForCausalLM"),
("GPTNeoConfig", "GPTNeoForCausalLM"),
("BigBirdConfig", "BigBirdForCausalLM"),
("CamembertConfig", "CamembertForCausalLM"),
("XLMRobertaConfig", "XLMRobertaForCausalLM"),
("RobertaConfig", "RobertaForCausalLM"),
("BertConfig", "BertLMHeadModel"),
("OpenAIGPTConfig", "OpenAIGPTLMHeadModel"),
("GPT2Config", "GPT2LMHeadModel"),
("TransfoXLConfig", "TransfoXLLMHeadModel"),
("XLNetConfig", "XLNetLMHeadModel"),
("XLMConfig", "XLMWithLMHeadModel"),
("CTRLConfig", "CTRLLMHeadModel"),
("ReformerConfig", "ReformerModelWithLMHead"),
("BertGenerationConfig", "BertGenerationDecoder"),
("XLMProphetNetConfig", "XLMProphetNetForCausalLM"),
("ProphetNetConfig", "ProphetNetForCausalLM"),
("BartConfig", "BartForCausalLM"),
("MBartConfig", "MBartForCausalLM"),
("PegasusConfig", "PegasusForCausalLM"),
("MarianConfig", "MarianForCausalLM"),
("BlenderbotConfig", "BlenderbotForCausalLM"),
("BlenderbotSmallConfig", "BlenderbotSmallForCausalLM"),
("MegatronBertConfig", "MegatronBertForCausalLM"),
]
)
# Config class name -> name of the image-classification model class for that config.
# NOTE(review): the DeiTConfig value is not a single class name but the string form of a
# two-element tuple of names -- presumably the generator stringified a tuple of classes.
# Code that uses these values as attribute names should confirm it handles this entry.
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("ViTConfig", "ViTForImageClassification"),
("DeiTConfig", "('DeiTForImageClassification', 'DeiTForImageClassificationWithTeacher')"),
("BeitConfig", "BeitForImageClassification"),
]
)
# Config class name -> name of the masked language-modeling model class for that config.
# Both sides are plain strings; OrderedDict keeps the entries in the order listed here.
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
("RemBertConfig", "RemBertForMaskedLM"),
("RoFormerConfig", "RoFormerForMaskedLM"),
("BigBirdConfig", "BigBirdForMaskedLM"),
("Wav2Vec2Config", "Wav2Vec2ForMaskedLM"),
("ConvBertConfig", "ConvBertForMaskedLM"),
("LayoutLMConfig", "LayoutLMForMaskedLM"),
("DistilBertConfig", "DistilBertForMaskedLM"),
("AlbertConfig", "AlbertForMaskedLM"),
("BartConfig", "BartForConditionalGeneration"),
("MBartConfig", "MBartForConditionalGeneration"),
("CamembertConfig", "CamembertForMaskedLM"),
("XLMRobertaConfig", "XLMRobertaForMaskedLM"),
("LongformerConfig", "LongformerForMaskedLM"),
("RobertaConfig", "RobertaForMaskedLM"),
("SqueezeBertConfig", "SqueezeBertForMaskedLM"),
("BertConfig", "BertForMaskedLM"),
("MegatronBertConfig", "MegatronBertForMaskedLM"),
("MobileBertConfig", "MobileBertForMaskedLM"),
("FlaubertConfig", "FlaubertWithLMHeadModel"),
("XLMConfig", "XLMWithLMHeadModel"),
("ElectraConfig", "ElectraForMaskedLM"),
("ReformerConfig", "ReformerForMaskedLM"),
("FunnelConfig", "FunnelForMaskedLM"),
("MPNetConfig", "MPNetForMaskedLM"),
("TapasConfig", "TapasForMaskedLM"),
("DebertaConfig", "DebertaForMaskedLM"),
("DebertaV2Config", "DebertaV2ForMaskedLM"),
("IBertConfig", "IBertForMaskedLM"),
]
)
# Config class name -> name of the multiple-choice model class for that config.
# Both sides are plain strings; OrderedDict keeps the entries in the order listed here.
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
("RemBertConfig", "RemBertForMultipleChoice"),
("CanineConfig", "CanineForMultipleChoice"),
("RoFormerConfig", "RoFormerForMultipleChoice"),
("BigBirdConfig", "BigBirdForMultipleChoice"),
("ConvBertConfig", "ConvBertForMultipleChoice"),
("CamembertConfig", "CamembertForMultipleChoice"),
("ElectraConfig", "ElectraForMultipleChoice"),
("XLMRobertaConfig", "XLMRobertaForMultipleChoice"),
("LongformerConfig", "LongformerForMultipleChoice"),
("RobertaConfig", "RobertaForMultipleChoice"),
("SqueezeBertConfig", "SqueezeBertForMultipleChoice"),
("BertConfig", "BertForMultipleChoice"),
("DistilBertConfig", "DistilBertForMultipleChoice"),
("MegatronBertConfig", "MegatronBertForMultipleChoice"),
("MobileBertConfig", "MobileBertForMultipleChoice"),
("XLNetConfig", "XLNetForMultipleChoice"),
("AlbertConfig", "AlbertForMultipleChoice"),
("XLMConfig", "XLMForMultipleChoice"),
("FlaubertConfig", "FlaubertForMultipleChoice"),
("FunnelConfig", "FunnelForMultipleChoice"),
("MPNetConfig", "MPNetForMultipleChoice"),
("IBertConfig", "IBertForMultipleChoice"),
]
)
# Config class name -> name of the next-sentence-prediction model class for that config.
# Both sides are plain strings.
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
[
("BertConfig", "BertForNextSentencePrediction"),
("MegatronBertConfig", "MegatronBertForNextSentencePrediction"),
("MobileBertConfig", "MobileBertForNextSentencePrediction"),
]
)
# Config class name -> name of the object-detection model class for that config.
# Both sides are plain strings.
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
[
("DetrConfig", "DetrForObjectDetection"),
]
)
# Config class name -> name of the sequence-to-sequence language-modeling model class for
# that config. Both sides are plain strings; OrderedDict keeps the entries in the order
# listed here.
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
("BigBirdPegasusConfig", "BigBirdPegasusForConditionalGeneration"),
("M2M100Config", "M2M100ForConditionalGeneration"),
("LEDConfig", "LEDForConditionalGeneration"),
("BlenderbotSmallConfig", "BlenderbotSmallForConditionalGeneration"),
("MT5Config", "MT5ForConditionalGeneration"),
("T5Config", "T5ForConditionalGeneration"),
("PegasusConfig", "PegasusForConditionalGeneration"),
("MarianConfig", "MarianMTModel"),
("MBartConfig", "MBartForConditionalGeneration"),
("BlenderbotConfig", "BlenderbotForConditionalGeneration"),
("BartConfig", "BartForConditionalGeneration"),
("FSMTConfig", "FSMTForConditionalGeneration"),
("EncoderDecoderConfig", "EncoderDecoderModel"),
("XLMProphetNetConfig", "XLMProphetNetForConditionalGeneration"),
("ProphetNetConfig", "ProphetNetForConditionalGeneration"),
]
)
# Config class name -> name of the sequence-classification model class for that config.
# Both sides are plain strings; OrderedDict keeps the entries in the order listed here.
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("RemBertConfig", "RemBertForSequenceClassification"),
("CanineConfig", "CanineForSequenceClassification"),
("RoFormerConfig", "RoFormerForSequenceClassification"),
("BigBirdPegasusConfig", "BigBirdPegasusForSequenceClassification"),
("BigBirdConfig", "BigBirdForSequenceClassification"),
("ConvBertConfig", "ConvBertForSequenceClassification"),
("LEDConfig", "LEDForSequenceClassification"),
("DistilBertConfig", "DistilBertForSequenceClassification"),
("AlbertConfig", "AlbertForSequenceClassification"),
("CamembertConfig", "CamembertForSequenceClassification"),
("XLMRobertaConfig", "XLMRobertaForSequenceClassification"),
("MBartConfig", "MBartForSequenceClassification"),
("BartConfig", "BartForSequenceClassification"),
("LongformerConfig", "LongformerForSequenceClassification"),
("RobertaConfig", "RobertaForSequenceClassification"),
("SqueezeBertConfig", "SqueezeBertForSequenceClassification"),
("LayoutLMConfig", "LayoutLMForSequenceClassification"),
("BertConfig", "BertForSequenceClassification"),
("XLNetConfig", "XLNetForSequenceClassification"),
("MegatronBertConfig", "MegatronBertForSequenceClassification"),
("MobileBertConfig", "MobileBertForSequenceClassification"),
("FlaubertConfig", "FlaubertForSequenceClassification"),
("XLMConfig", "XLMForSequenceClassification"),
("ElectraConfig", "ElectraForSequenceClassification"),
("FunnelConfig", "FunnelForSequenceClassification"),
("DebertaConfig", "DebertaForSequenceClassification"),
("DebertaV2Config", "DebertaV2ForSequenceClassification"),
("GPT2Config", "GPT2ForSequenceClassification"),
("GPTNeoConfig", "GPTNeoForSequenceClassification"),
("OpenAIGPTConfig", "OpenAIGPTForSequenceClassification"),
("ReformerConfig", "ReformerForSequenceClassification"),
("CTRLConfig", "CTRLForSequenceClassification"),
("TransfoXLConfig", "TransfoXLForSequenceClassification"),
("MPNetConfig", "MPNetForSequenceClassification"),
("TapasConfig", "TapasForSequenceClassification"),
("IBertConfig", "IBertForSequenceClassification"),
]
)
# Config class name -> name of the table question-answering model class for that config.
# Both sides are plain strings.
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
("TapasConfig", "TapasForQuestionAnswering"),
]
)
# Config class name -> name of the token-classification model class for that config.
# Both sides are plain strings; OrderedDict keeps the entries in the order listed here.
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("RemBertConfig", "RemBertForTokenClassification"),
("CanineConfig", "CanineForTokenClassification"),
("RoFormerConfig", "RoFormerForTokenClassification"),
("BigBirdConfig", "BigBirdForTokenClassification"),
("ConvBertConfig", "ConvBertForTokenClassification"),
("LayoutLMConfig", "LayoutLMForTokenClassification"),
("DistilBertConfig", "DistilBertForTokenClassification"),
("CamembertConfig", "CamembertForTokenClassification"),
("FlaubertConfig", "FlaubertForTokenClassification"),
("XLMConfig", "XLMForTokenClassification"),
("XLMRobertaConfig", "XLMRobertaForTokenClassification"),
("LongformerConfig", "LongformerForTokenClassification"),
("RobertaConfig", "RobertaForTokenClassification"),
("SqueezeBertConfig", "SqueezeBertForTokenClassification"),
("BertConfig", "BertForTokenClassification"),
("MegatronBertConfig", "MegatronBertForTokenClassification"),
("MobileBertConfig", "MobileBertForTokenClassification"),
("XLNetConfig", "XLNetForTokenClassification"),
("AlbertConfig", "AlbertForTokenClassification"),
("ElectraConfig", "ElectraForTokenClassification"),
("FunnelConfig", "FunnelForTokenClassification"),
("MPNetConfig", "MPNetForTokenClassification"),
("DebertaConfig", "DebertaForTokenClassification"),
("DebertaV2Config", "DebertaV2ForTokenClassification"),
("IBertConfig", "IBertForTokenClassification"),
]
)
# Config class name -> name of the base model class for that config.
# OrderedDict keeps the entries in the order listed here.
# NOTE(review): the FunnelConfig value is not a single class name but the string form of a
# two-element tuple of names -- presumably the generator stringified a tuple of classes.
# Code that uses these values as attribute names should confirm it handles this entry.
MODEL_MAPPING_NAMES = OrderedDict(
[
("BeitConfig", "BeitModel"),
("RemBertConfig", "RemBertModel"),
("VisualBertConfig", "VisualBertModel"),
("CanineConfig", "CanineModel"),
("RoFormerConfig", "RoFormerModel"),
("CLIPConfig", "CLIPModel"),
("BigBirdPegasusConfig", "BigBirdPegasusModel"),
("DeiTConfig", "DeiTModel"),
("LukeConfig", "LukeModel"),
("DetrConfig", "DetrModel"),
("GPTNeoConfig", "GPTNeoModel"),
("BigBirdConfig", "BigBirdModel"),
("Speech2TextConfig", "Speech2TextModel"),
("ViTConfig", "ViTModel"),
("Wav2Vec2Config", "Wav2Vec2Model"),
("HubertConfig", "HubertModel"),
("M2M100Config", "M2M100Model"),
("ConvBertConfig", "ConvBertModel"),
("LEDConfig", "LEDModel"),
("BlenderbotSmallConfig", "BlenderbotSmallModel"),
("RetriBertConfig", "RetriBertModel"),
("MT5Config", "MT5Model"),
("T5Config", "T5Model"),
("PegasusConfig", "PegasusModel"),
("MarianConfig", "MarianModel"),
("MBartConfig", "MBartModel"),
("BlenderbotConfig", "BlenderbotModel"),
("DistilBertConfig", "DistilBertModel"),
("AlbertConfig", "AlbertModel"),
("CamembertConfig", "CamembertModel"),
("XLMRobertaConfig", "XLMRobertaModel"),
("BartConfig", "BartModel"),
("LongformerConfig", "LongformerModel"),
("RobertaConfig", "RobertaModel"),
("LayoutLMConfig", "LayoutLMModel"),
("SqueezeBertConfig", "SqueezeBertModel"),
("BertConfig", "BertModel"),
("OpenAIGPTConfig", "OpenAIGPTModel"),
("GPT2Config", "GPT2Model"),
("MegatronBertConfig", "MegatronBertModel"),
("MobileBertConfig", "MobileBertModel"),
("TransfoXLConfig", "TransfoXLModel"),
("XLNetConfig", "XLNetModel"),
("FlaubertConfig", "FlaubertModel"),
("FSMTConfig", "FSMTModel"),
("XLMConfig", "XLMModel"),
("CTRLConfig", "CTRLModel"),
("ElectraConfig", "ElectraModel"),
("ReformerConfig", "ReformerModel"),
("FunnelConfig", "('FunnelModel', 'FunnelBaseModel')"),
("LxmertConfig", "LxmertModel"),
("BertGenerationConfig", "BertGenerationEncoder"),
("DebertaConfig", "DebertaModel"),
("DebertaV2Config", "DebertaV2Model"),
("DPRConfig", "DPRQuestionEncoder"),
("XLMProphetNetConfig", "XLMProphetNetModel"),
("ProphetNetConfig", "ProphetNetModel"),
("MPNetConfig", "MPNetModel"),
("TapasConfig", "TapasModel"),
("IBertConfig", "IBertModel"),
]
)
# Config class name -> name of the model class with a language-modeling head for that
# config. The values mix masked-LM, causal-LM and conditional-generation classes, one per
# config. Both sides are plain strings; OrderedDict keeps the entries in the order listed.
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
[
("RemBertConfig", "RemBertForMaskedLM"),
("RoFormerConfig", "RoFormerForMaskedLM"),
("BigBirdPegasusConfig", "BigBirdPegasusForConditionalGeneration"),
("GPTNeoConfig", "GPTNeoForCausalLM"),
("BigBirdConfig", "BigBirdForMaskedLM"),
("Speech2TextConfig", "Speech2TextForConditionalGeneration"),
("Wav2Vec2Config", "Wav2Vec2ForMaskedLM"),
("M2M100Config", "M2M100ForConditionalGeneration"),
("ConvBertConfig", "ConvBertForMaskedLM"),
("LEDConfig", "LEDForConditionalGeneration"),
("BlenderbotSmallConfig", "BlenderbotSmallForConditionalGeneration"),
("LayoutLMConfig", "LayoutLMForMaskedLM"),
("T5Config", "T5ForConditionalGeneration"),
("DistilBertConfig", "DistilBertForMaskedLM"),
("AlbertConfig", "AlbertForMaskedLM"),
("CamembertConfig", "CamembertForMaskedLM"),
("XLMRobertaConfig", "XLMRobertaForMaskedLM"),
("MarianConfig", "MarianMTModel"),
("FSMTConfig", "FSMTForConditionalGeneration"),
("BartConfig", "BartForConditionalGeneration"),
("LongformerConfig", "LongformerForMaskedLM"),
("RobertaConfig", "RobertaForMaskedLM"),
("SqueezeBertConfig", "SqueezeBertForMaskedLM"),
("BertConfig", "BertForMaskedLM"),
("OpenAIGPTConfig", "OpenAIGPTLMHeadModel"),
("GPT2Config", "GPT2LMHeadModel"),
("MegatronBertConfig", "MegatronBertForCausalLM"),
("MobileBertConfig", "MobileBertForMaskedLM"),
("TransfoXLConfig", "TransfoXLLMHeadModel"),
("XLNetConfig", "XLNetLMHeadModel"),
("FlaubertConfig", "FlaubertWithLMHeadModel"),
("XLMConfig", "XLMWithLMHeadModel"),
("CTRLConfig", "CTRLLMHeadModel"),
("ElectraConfig", "ElectraForMaskedLM"),
("EncoderDecoderConfig", "EncoderDecoderModel"),
("ReformerConfig", "ReformerModelWithLMHead"),
("FunnelConfig", "FunnelForMaskedLM"),
("MPNetConfig", "MPNetForMaskedLM"),
("TapasConfig", "TapasForMaskedLM"),
("DebertaConfig", "DebertaForMaskedLM"),
("DebertaV2Config", "DebertaV2ForMaskedLM"),
("IBertConfig", "IBertForMaskedLM"),
]
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment