modeling_auto.py 46.4 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Sylvain Gugger's avatar
Sylvain Gugger committed
15
""" Auto Model class."""
thomwolf's avatar
thomwolf committed
16

17
import warnings
Julien Chaumond's avatar
Julien Chaumond committed
18
from collections import OrderedDict
thomwolf's avatar
thomwolf committed
19

Sylvain Gugger's avatar
Sylvain Gugger committed
20
from ...utils import logging
21
22
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES
thomwolf's avatar
thomwolf committed
23

thomwolf's avatar
thomwolf committed
24

Lysandre Debut's avatar
Lysandre Debut committed
25
# Module-level logger, named after this module and obtained through the
# package's own `logging` utility imported above.
logger = logging.get_logger(__name__)
thomwolf's avatar
thomwolf committed
26
27


28
# Lookup table for base models: model-type identifier -> class name.
# A value is normally a single class-name string; it is a tuple of class names
# when several classes are listed for one model type (see "funnel").
# Entry order is preserved by the OrderedDict and is kept exactly as authored.
MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("albert", "AlbertModel"),
        ("audio-spectrogram-transformer", "ASTModel"),
        ("bart", "BartModel"),
        ("beit", "BeitModel"),
        ("bert", "BertModel"),
        ("bert-generation", "BertGenerationEncoder"),
        ("big_bird", "BigBirdModel"),
        ("bigbird_pegasus", "BigBirdPegasusModel"),
        ("blenderbot", "BlenderbotModel"),
        ("blenderbot-small", "BlenderbotSmallModel"),
        ("bloom", "BloomModel"),
        ("camembert", "CamembertModel"),
        ("canine", "CanineModel"),
        ("clip", "CLIPModel"),
        ("clipseg", "CLIPSegModel"),
        ("codegen", "CodeGenModel"),
        ("conditional_detr", "ConditionalDetrModel"),
        ("convbert", "ConvBertModel"),
        ("convnext", "ConvNextModel"),
        ("ctrl", "CTRLModel"),
        ("cvt", "CvtModel"),
        ("data2vec-audio", "Data2VecAudioModel"),
        ("data2vec-text", "Data2VecTextModel"),
        ("data2vec-vision", "Data2VecVisionModel"),
        ("deberta", "DebertaModel"),
        ("deberta-v2", "DebertaV2Model"),
        ("decision_transformer", "DecisionTransformerModel"),
        ("decision_transformer_gpt2", "DecisionTransformerGPT2Model"),
        ("deformable_detr", "DeformableDetrModel"),
        ("deit", "DeiTModel"),
        ("detr", "DetrModel"),
        ("dinat", "DinatModel"),
        ("distilbert", "DistilBertModel"),
        ("donut-swin", "DonutSwinModel"),
        ("dpr", "DPRQuestionEncoder"),
        ("dpt", "DPTModel"),
        ("electra", "ElectraModel"),
        ("ernie", "ErnieModel"),
        ("esm", "EsmModel"),
        ("flaubert", "FlaubertModel"),
        ("flava", "FlavaModel"),
        ("fnet", "FNetModel"),
        ("fsmt", "FSMTModel"),
        ("funnel", ("FunnelModel", "FunnelBaseModel")),
        ("glpn", "GLPNModel"),
        ("gpt2", "GPT2Model"),
        ("gpt_neo", "GPTNeoModel"),
        ("gpt_neox", "GPTNeoXModel"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseModel"),
        ("gptj", "GPTJModel"),
        ("groupvit", "GroupViTModel"),
        ("hubert", "HubertModel"),
        ("ibert", "IBertModel"),
        ("imagegpt", "ImageGPTModel"),
        ("jukebox", "JukeboxModel"),
        ("layoutlm", "LayoutLMModel"),
        ("layoutlmv2", "LayoutLMv2Model"),
        ("layoutlmv3", "LayoutLMv3Model"),
        ("led", "LEDModel"),
        ("levit", "LevitModel"),
        ("lilt", "LiltModel"),
        ("longformer", "LongformerModel"),
        ("longt5", "LongT5Model"),
        ("luke", "LukeModel"),
        ("lxmert", "LxmertModel"),
        ("m2m_100", "M2M100Model"),
        ("marian", "MarianModel"),
        ("markuplm", "MarkupLMModel"),
        ("maskformer", "MaskFormerModel"),
        ("maskformer-swin", "MaskFormerSwinModel"),
        ("mbart", "MBartModel"),
        ("mctct", "MCTCTModel"),
        ("megatron-bert", "MegatronBertModel"),
        ("mobilebert", "MobileBertModel"),
        ("mobilenet_v1", "MobileNetV1Model"),
        ("mobilenet_v2", "MobileNetV2Model"),
        ("mobilevit", "MobileViTModel"),
        ("mpnet", "MPNetModel"),
        ("mt5", "MT5Model"),
        ("mvp", "MvpModel"),
        ("nat", "NatModel"),
        ("nezha", "NezhaModel"),
        ("nllb", "M2M100Model"),
        ("nystromformer", "NystromformerModel"),
        ("openai-gpt", "OpenAIGPTModel"),
        ("opt", "OPTModel"),
        ("owlvit", "OwlViTModel"),
        ("pegasus", "PegasusModel"),
        ("pegasus_x", "PegasusXModel"),
        ("perceiver", "PerceiverModel"),
        ("plbart", "PLBartModel"),
        ("poolformer", "PoolFormerModel"),
        ("prophetnet", "ProphetNetModel"),
        ("qdqbert", "QDQBertModel"),
        ("reformer", "ReformerModel"),
        ("regnet", "RegNetModel"),
        ("rembert", "RemBertModel"),
        ("resnet", "ResNetModel"),
        ("retribert", "RetriBertModel"),
        ("roberta", "RobertaModel"),
        ("roc_bert", "RoCBertModel"),
        ("roformer", "RoFormerModel"),
        ("segformer", "SegformerModel"),
        ("sew", "SEWModel"),
        ("sew-d", "SEWDModel"),
        ("speech_to_text", "Speech2TextModel"),
        ("splinter", "SplinterModel"),
        ("squeezebert", "SqueezeBertModel"),
        ("swin", "SwinModel"),
        ("swinv2", "Swinv2Model"),
        ("switch_transformers", "SwitchTransformersModel"),
        ("t5", "T5Model"),
        ("table-transformer", "TableTransformerModel"),
        ("tapas", "TapasModel"),
        ("time_series_transformer", "TimeSeriesTransformerModel"),
        ("trajectory_transformer", "TrajectoryTransformerModel"),
        ("transfo-xl", "TransfoXLModel"),
        ("unispeech", "UniSpeechModel"),
        ("unispeech-sat", "UniSpeechSatModel"),
        ("van", "VanModel"),
        ("videomae", "VideoMAEModel"),
        ("vilt", "ViltModel"),
        ("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
        ("visual_bert", "VisualBertModel"),
        ("vit", "ViTModel"),
        ("vit_mae", "ViTMAEModel"),
        ("vit_msn", "ViTMSNModel"),
        ("wav2vec2", "Wav2Vec2Model"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerModel"),
        ("wavlm", "WavLMModel"),
        ("whisper", "WhisperModel"),
        ("xclip", "XCLIPModel"),
        ("xglm", "XGLMModel"),
        ("xlm", "XLMModel"),
        ("xlm-prophetnet", "XLMProphetNetModel"),
        ("xlm-roberta", "XLMRobertaModel"),
        ("xlm-roberta-xl", "XLMRobertaXLModel"),
        ("xlnet", "XLNetModel"),
        ("yolos", "YolosModel"),
        ("yoso", "YosoModel"),
    ]
)

174
# Lookup table for pre-training heads: model-type identifier -> class name.
# Entry order is preserved by the OrderedDict and is kept exactly as authored.
MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        # Model for pre-training mapping
        ("albert", "AlbertForPreTraining"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForPreTraining"),
        ("big_bird", "BigBirdForPreTraining"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForPreTraining"),
        ("ernie", "ErnieForPreTraining"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("flava", "FlavaForPreTraining"),
        ("fnet", "FNetForPreTraining"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("funnel", "FunnelForPreTraining"),
        ("gpt2", "GPT2LMHeadModel"),
        ("ibert", "IBertForMaskedLM"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("longformer", "LongformerForMaskedLM"),
        ("luke", "LukeForMaskedLM"),
        ("lxmert", "LxmertForPreTraining"),
        ("megatron-bert", "MegatronBertForPreTraining"),
        ("mobilebert", "MobileBertForPreTraining"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForPreTraining"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("retribert", "RetriBertModel"),
        ("roberta", "RobertaForMaskedLM"),
        ("roc_bert", "RoCBertForPreTraining"),
        ("splinter", "SplinterForPreTraining"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
        ("t5", "T5ForConditionalGeneration"),
        ("tapas", "TapasForMaskedLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("unispeech", "UniSpeechForPreTraining"),
        ("unispeech-sat", "UniSpeechSatForPreTraining"),
        ("videomae", "VideoMAEForPreTraining"),
        ("visual_bert", "VisualBertForPreTraining"),
        ("vit_mae", "ViTMAEForPreTraining"),
        ("wav2vec2", "Wav2Vec2ForPreTraining"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForPreTraining"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("xlnet", "XLNetLMHeadModel"),
    ]
)

230
# Lookup table for models with a language-modeling head:
# model-type identifier -> class name.
# Entry order is preserved by the OrderedDict and is kept exactly as authored.
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    [
        # Model with LM heads mapping
        ("albert", "AlbertForMaskedLM"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForMaskedLM"),
        ("big_bird", "BigBirdForMaskedLM"),
        ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
        ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("codegen", "CodeGenForCausalLM"),
        ("convbert", "ConvBertForMaskedLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForMaskedLM"),
        ("encoder-decoder", "EncoderDecoderModel"),
        ("ernie", "ErnieForMaskedLM"),
        ("esm", "EsmForMaskedLM"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("fnet", "FNetForMaskedLM"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("funnel", "FunnelForMaskedLM"),
        ("gpt2", "GPT2LMHeadModel"),
        ("gpt_neo", "GPTNeoForCausalLM"),
        ("gpt_neox", "GPTNeoXForCausalLM"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
        ("gptj", "GPTJForCausalLM"),
        ("ibert", "IBertForMaskedLM"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("led", "LEDForConditionalGeneration"),
        ("longformer", "LongformerForMaskedLM"),
        ("longt5", "LongT5ForConditionalGeneration"),
        ("luke", "LukeForMaskedLM"),
        ("m2m_100", "M2M100ForConditionalGeneration"),
        ("marian", "MarianMTModel"),
        ("megatron-bert", "MegatronBertForCausalLM"),
        ("mobilebert", "MobileBertForMaskedLM"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForMaskedLM"),
        ("nllb", "M2M100ForConditionalGeneration"),
        ("nystromformer", "NystromformerForMaskedLM"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("pegasus_x", "PegasusXForConditionalGeneration"),
        ("plbart", "PLBartForConditionalGeneration"),
        ("qdqbert", "QDQBertForMaskedLM"),
        ("reformer", "ReformerModelWithLMHead"),
        ("rembert", "RemBertForMaskedLM"),
        ("roberta", "RobertaForMaskedLM"),
        ("roc_bert", "RoCBertForMaskedLM"),
        ("roformer", "RoFormerForMaskedLM"),
        ("speech_to_text", "Speech2TextForConditionalGeneration"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
        ("t5", "T5ForConditionalGeneration"),
        ("tapas", "TapasForMaskedLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("wav2vec2", "Wav2Vec2ForMaskedLM"),
        ("whisper", "WhisperForConditionalGeneration"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("xlnet", "XLNetLMHeadModel"),
        ("yoso", "YosoForMaskedLM"),
    ]
)

301
# Lookup table for causal language-modeling heads:
# model-type identifier -> class name.
# Entry order is preserved by the OrderedDict and is kept exactly as authored.
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
        ("bart", "BartForCausalLM"),
        ("bert", "BertLMHeadModel"),
        ("bert-generation", "BertGenerationDecoder"),
        ("big_bird", "BigBirdForCausalLM"),
        ("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
        ("blenderbot", "BlenderbotForCausalLM"),
        ("blenderbot-small", "BlenderbotSmallForCausalLM"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForCausalLM"),
        ("codegen", "CodeGenForCausalLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForCausalLM"),
        ("electra", "ElectraForCausalLM"),
        ("ernie", "ErnieForCausalLM"),
        ("gpt2", "GPT2LMHeadModel"),
        ("gpt_neo", "GPTNeoForCausalLM"),
        ("gpt_neox", "GPTNeoXForCausalLM"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
        ("gptj", "GPTJForCausalLM"),
        ("marian", "MarianForCausalLM"),
        ("mbart", "MBartForCausalLM"),
        ("megatron-bert", "MegatronBertForCausalLM"),
        ("mvp", "MvpForCausalLM"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("opt", "OPTForCausalLM"),
        ("pegasus", "PegasusForCausalLM"),
        ("plbart", "PLBartForCausalLM"),
        ("prophetnet", "ProphetNetForCausalLM"),
        ("qdqbert", "QDQBertLMHeadModel"),
        ("reformer", "ReformerModelWithLMHead"),
        ("rembert", "RemBertForCausalLM"),
        ("roberta", "RobertaForCausalLM"),
        ("roc_bert", "RoCBertForCausalLM"),
        ("roformer", "RoFormerForCausalLM"),
        ("speech_to_text_2", "Speech2Text2ForCausalLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("trocr", "TrOCRForCausalLM"),
        ("xglm", "XGLMForCausalLM"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-prophetnet", "XLMProphetNetForCausalLM"),
        ("xlm-roberta", "XLMRobertaForCausalLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForCausalLM"),
        ("xlnet", "XLNetLMHeadModel"),
    ]
)

NielsRogge's avatar
NielsRogge committed
350
351
352
353
# Lookup table for masked image modeling heads:
# model-type identifier -> class name.
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    [
        ("deit", "DeiTForMaskedImageModeling"),
        ("swin", "SwinForMaskedImageModeling"),
        ("swinv2", "Swinv2ForMaskedImageModeling"),
        ("vit", "ViTForMaskedImageModeling"),
    ]
)


NielsRogge's avatar
NielsRogge committed
360
361
362
363
364
365
366
# Model for Causal Image Modeling mapping: model-type identifier -> class name.
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    (("imagegpt", "ImageGPTForCausalImageModeling"),)
)

367
# Lookup table for image-classification heads: model-type identifier -> class
# name, or a tuple of class names when several classes are listed for one
# model type (see "deit", "levit", "perceiver").
# Entry order is preserved by the OrderedDict and is kept exactly as authored.
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image Classification mapping
        ("beit", "BeitForImageClassification"),
        ("convnext", "ConvNextForImageClassification"),
        ("cvt", "CvtForImageClassification"),
        ("data2vec-vision", "Data2VecVisionForImageClassification"),
        ("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")),
        ("dinat", "DinatForImageClassification"),
        ("imagegpt", "ImageGPTForImageClassification"),
        ("levit", ("LevitForImageClassification", "LevitForImageClassificationWithTeacher")),
        ("mobilenet_v1", "MobileNetV1ForImageClassification"),
        ("mobilenet_v2", "MobileNetV2ForImageClassification"),
        ("mobilevit", "MobileViTForImageClassification"),
        ("nat", "NatForImageClassification"),
        (
            "perceiver",
            (
                "PerceiverForImageClassificationLearned",
                "PerceiverForImageClassificationFourier",
                "PerceiverForImageClassificationConvProcessing",
            ),
        ),
        ("poolformer", "PoolFormerForImageClassification"),
        ("regnet", "RegNetForImageClassification"),
        ("resnet", "ResNetForImageClassification"),
        ("segformer", "SegformerForImageClassification"),
        ("swin", "SwinForImageClassification"),
        ("swinv2", "Swinv2ForImageClassification"),
        ("van", "VanForImageClassification"),
        ("vit", "ViTForImageClassification"),
        ("vit_msn", "ViTMSNForImageClassification"),
    ]
)

402
403
# Lookup table for image-segmentation heads: model-type identifier -> class name.
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Do not add new models here, this class will be deprecated in the future.
        # Model for Image Segmentation mapping
        ("detr", "DetrForSegmentation"),
    ]
)

410
411
412
413
# Lookup table for semantic-segmentation heads:
# model-type identifier -> class name.
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Semantic Segmentation mapping
        ("beit", "BeitForSemanticSegmentation"),
        ("data2vec-vision", "Data2VecVisionForSemanticSegmentation"),
        ("dpt", "DPTForSemanticSegmentation"),
        ("mobilenet_v2", "MobileNetV2ForSemanticSegmentation"),
        ("mobilevit", "MobileViTForSemanticSegmentation"),
        ("segformer", "SegformerForSemanticSegmentation"),
    ]
)

422
423
424
425
426
427
428
# Model for Instance Segmentation mapping: model-type identifier -> class name.
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    (("maskformer", "MaskFormerForInstanceSegmentation"),)
)

NielsRogge's avatar
NielsRogge committed
429
430
431
432
433
434
# Video-classification heads: model-type identifier -> class name.
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    (("videomae", "VideoMAEForVideoClassification"),)
)

435
436
437
438
439
440
# Vision-to-sequence models: model-type identifier -> class name.
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    (("vision-encoder-decoder", "VisionEncoderDecoderModel"),)
)

441
# Lookup table for masked language-modeling heads:
# model-type identifier -> class name.
# Entry order is preserved by the OrderedDict and is kept exactly as authored.
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked LM mapping
        ("albert", "AlbertForMaskedLM"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForMaskedLM"),
        ("big_bird", "BigBirdForMaskedLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("convbert", "ConvBertForMaskedLM"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForMaskedLM"),
        ("ernie", "ErnieForMaskedLM"),
        ("esm", "EsmForMaskedLM"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("fnet", "FNetForMaskedLM"),
        ("funnel", "FunnelForMaskedLM"),
        ("ibert", "IBertForMaskedLM"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("longformer", "LongformerForMaskedLM"),
        ("luke", "LukeForMaskedLM"),
        ("mbart", "MBartForConditionalGeneration"),
        ("megatron-bert", "MegatronBertForMaskedLM"),
        ("mobilebert", "MobileBertForMaskedLM"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForMaskedLM"),
        ("nystromformer", "NystromformerForMaskedLM"),
        ("perceiver", "PerceiverForMaskedLM"),
        ("qdqbert", "QDQBertForMaskedLM"),
        ("reformer", "ReformerForMaskedLM"),
        ("rembert", "RemBertForMaskedLM"),
        ("roberta", "RobertaForMaskedLM"),
        ("roc_bert", "RoCBertForMaskedLM"),
        ("roformer", "RoFormerForMaskedLM"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("tapas", "TapasForMaskedLM"),
        ("wav2vec2", "Wav2Vec2ForMaskedLM"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("yoso", "YosoForMaskedLM"),
    ]
)

488
# Lookup table for object-detection heads: model-type identifier -> class name.
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Object Detection mapping
        ("conditional_detr", "ConditionalDetrForObjectDetection"),
        ("deformable_detr", "DeformableDetrForObjectDetection"),
        ("detr", "DetrForObjectDetection"),
        ("table-transformer", "TableTransformerForObjectDetection"),
        ("yolos", "YolosForObjectDetection"),
    ]
)

499
500
501
502
503
504
505
# Model for Zero Shot Object Detection mapping: model-type identifier -> class name.
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
    (("owlvit", "OwlViTForObjectDetection"),)
)

506
507
508
509
510
511
512
# Model for depth estimation mapping: model-type identifier -> class name.
MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict(
    (
        ("dpt", "DPTForDepthEstimation"),
        ("glpn", "GLPNForDepthEstimation"),
    )
)
513
# Lookup table for sequence-to-sequence language-modeling heads:
# model-type identifier -> class name.
# Entry order is preserved by the OrderedDict and is kept exactly as authored.
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        ("bart", "BartForConditionalGeneration"),
        ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
        ("blenderbot", "BlenderbotForConditionalGeneration"),
        ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
        ("encoder-decoder", "EncoderDecoderModel"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("led", "LEDForConditionalGeneration"),
        ("longt5", "LongT5ForConditionalGeneration"),
        ("m2m_100", "M2M100ForConditionalGeneration"),
        ("marian", "MarianMTModel"),
        ("mbart", "MBartForConditionalGeneration"),
        ("mt5", "MT5ForConditionalGeneration"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nllb", "M2M100ForConditionalGeneration"),
        ("pegasus", "PegasusForConditionalGeneration"),
        ("pegasus_x", "PegasusXForConditionalGeneration"),
        ("plbart", "PLBartForConditionalGeneration"),
        ("prophetnet", "ProphetNetForConditionalGeneration"),
        ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
        ("t5", "T5ForConditionalGeneration"),
        ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"),
    ]
)

540
541
542
543
# Lookup table for speech sequence-to-sequence models:
# model-type identifier -> class name.
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
        ("speech_to_text", "Speech2TextForConditionalGeneration"),
        ("whisper", "WhisperForConditionalGeneration"),
    ]
)

548
# Lookup table for sequence-classification heads:
# model-type identifier -> class name.
# Entry order is preserved by the OrderedDict and is kept exactly as authored.
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Sequence Classification mapping
        ("albert", "AlbertForSequenceClassification"),
        ("bart", "BartForSequenceClassification"),
        ("bert", "BertForSequenceClassification"),
        ("big_bird", "BigBirdForSequenceClassification"),
        ("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
        ("bloom", "BloomForSequenceClassification"),
        ("camembert", "CamembertForSequenceClassification"),
        ("canine", "CanineForSequenceClassification"),
        ("convbert", "ConvBertForSequenceClassification"),
        ("ctrl", "CTRLForSequenceClassification"),
        ("data2vec-text", "Data2VecTextForSequenceClassification"),
        ("deberta", "DebertaForSequenceClassification"),
        ("deberta-v2", "DebertaV2ForSequenceClassification"),
        ("distilbert", "DistilBertForSequenceClassification"),
        ("electra", "ElectraForSequenceClassification"),
        ("ernie", "ErnieForSequenceClassification"),
        ("esm", "EsmForSequenceClassification"),
        ("flaubert", "FlaubertForSequenceClassification"),
        ("fnet", "FNetForSequenceClassification"),
        ("funnel", "FunnelForSequenceClassification"),
        ("gpt2", "GPT2ForSequenceClassification"),
        ("gpt_neo", "GPTNeoForSequenceClassification"),
        ("gptj", "GPTJForSequenceClassification"),
        ("ibert", "IBertForSequenceClassification"),
        ("layoutlm", "LayoutLMForSequenceClassification"),
        ("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
        ("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
        ("led", "LEDForSequenceClassification"),
        ("lilt", "LiltForSequenceClassification"),
        ("longformer", "LongformerForSequenceClassification"),
        ("luke", "LukeForSequenceClassification"),
        ("markuplm", "MarkupLMForSequenceClassification"),
        ("mbart", "MBartForSequenceClassification"),
        ("megatron-bert", "MegatronBertForSequenceClassification"),
        ("mobilebert", "MobileBertForSequenceClassification"),
        ("mpnet", "MPNetForSequenceClassification"),
        ("mvp", "MvpForSequenceClassification"),
        ("nezha", "NezhaForSequenceClassification"),
        ("nystromformer", "NystromformerForSequenceClassification"),
        ("openai-gpt", "OpenAIGPTForSequenceClassification"),
        ("opt", "OPTForSequenceClassification"),
        ("perceiver", "PerceiverForSequenceClassification"),
        ("plbart", "PLBartForSequenceClassification"),
        ("qdqbert", "QDQBertForSequenceClassification"),
        ("reformer", "ReformerForSequenceClassification"),
        ("rembert", "RemBertForSequenceClassification"),
        ("roberta", "RobertaForSequenceClassification"),
        ("roc_bert", "RoCBertForSequenceClassification"),
        ("roformer", "RoFormerForSequenceClassification"),
        ("squeezebert", "SqueezeBertForSequenceClassification"),
        ("tapas", "TapasForSequenceClassification"),
        ("transfo-xl", "TransfoXLForSequenceClassification"),
        ("xlm", "XLMForSequenceClassification"),
        ("xlm-roberta", "XLMRobertaForSequenceClassification"),
        ("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"),
        ("xlnet", "XLNetForSequenceClassification"),
        ("yoso", "YosoForSequenceClassification"),
    ]
)

611
# Lookup table for question-answering heads:
# model-type identifier -> class name.
# Entry order is preserved by the OrderedDict and is kept exactly as authored.
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Question Answering mapping
        ("albert", "AlbertForQuestionAnswering"),
        ("bart", "BartForQuestionAnswering"),
        ("bert", "BertForQuestionAnswering"),
        ("big_bird", "BigBirdForQuestionAnswering"),
        ("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"),
        ("bloom", "BloomForQuestionAnswering"),
        ("camembert", "CamembertForQuestionAnswering"),
        ("canine", "CanineForQuestionAnswering"),
        ("convbert", "ConvBertForQuestionAnswering"),
        ("data2vec-text", "Data2VecTextForQuestionAnswering"),
        ("deberta", "DebertaForQuestionAnswering"),
        ("deberta-v2", "DebertaV2ForQuestionAnswering"),
        ("distilbert", "DistilBertForQuestionAnswering"),
        ("electra", "ElectraForQuestionAnswering"),
        ("ernie", "ErnieForQuestionAnswering"),
        ("flaubert", "FlaubertForQuestionAnsweringSimple"),
        ("fnet", "FNetForQuestionAnswering"),
        ("funnel", "FunnelForQuestionAnswering"),
        ("gptj", "GPTJForQuestionAnswering"),
        ("ibert", "IBertForQuestionAnswering"),
        ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
        ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
        ("led", "LEDForQuestionAnswering"),
        ("lilt", "LiltForQuestionAnswering"),
        ("longformer", "LongformerForQuestionAnswering"),
        ("luke", "LukeForQuestionAnswering"),
        ("lxmert", "LxmertForQuestionAnswering"),
        ("markuplm", "MarkupLMForQuestionAnswering"),
        ("mbart", "MBartForQuestionAnswering"),
        ("megatron-bert", "MegatronBertForQuestionAnswering"),
        ("mobilebert", "MobileBertForQuestionAnswering"),
        ("mpnet", "MPNetForQuestionAnswering"),
        ("mvp", "MvpForQuestionAnswering"),
        ("nezha", "NezhaForQuestionAnswering"),
        ("nystromformer", "NystromformerForQuestionAnswering"),
        ("opt", "OPTForQuestionAnswering"),
        ("qdqbert", "QDQBertForQuestionAnswering"),
        ("reformer", "ReformerForQuestionAnswering"),
        ("rembert", "RemBertForQuestionAnswering"),
        ("roberta", "RobertaForQuestionAnswering"),
        ("roc_bert", "RoCBertForQuestionAnswering"),
        ("roformer", "RoFormerForQuestionAnswering"),
        ("splinter", "SplinterForQuestionAnswering"),
        ("squeezebert", "SqueezeBertForQuestionAnswering"),
        ("xlm", "XLMForQuestionAnsweringSimple"),
        ("xlm-roberta", "XLMRobertaForQuestionAnswering"),
        ("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"),
        ("xlnet", "XLNetForQuestionAnsweringSimple"),
        ("yoso", "YosoForQuestionAnswering"),
    ]
)

666
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
667
668
    [
        # Model for Table Question Answering mapping
669
        ("tapas", "TapasForQuestionAnswering"),
670
671
672
    ]
)

673
674
675
676
677
678
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Visual Question Answering mapping
        # (model-type key -> name of the model class, both as strings)
        ("vilt", "ViltForQuestionAnswering"),
    ]
)

679
680
681
682
683
684
685
686
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Document Question Answering mapping
        # (model-type key -> name of the model class, both as strings)
        ("layoutlm", "LayoutLMForQuestionAnswering"),
        ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
        ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
    ]
)

687
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
Julien Chaumond's avatar
Julien Chaumond committed
688
    [
689
        # Model for Token Classification mapping
690
691
        ("albert", "AlbertForTokenClassification"),
        ("bert", "BertForTokenClassification"),
692
        ("big_bird", "BigBirdForTokenClassification"),
693
        ("bloom", "BloomForTokenClassification"),
694
695
        ("camembert", "CamembertForTokenClassification"),
        ("canine", "CanineForTokenClassification"),
696
        ("convbert", "ConvBertForTokenClassification"),
697
698
699
        ("data2vec-text", "Data2VecTextForTokenClassification"),
        ("deberta", "DebertaForTokenClassification"),
        ("deberta-v2", "DebertaV2ForTokenClassification"),
700
        ("distilbert", "DistilBertForTokenClassification"),
701
        ("electra", "ElectraForTokenClassification"),
702
        ("ernie", "ErnieForTokenClassification"),
703
        ("esm", "EsmForTokenClassification"),
704
        ("flaubert", "FlaubertForTokenClassification"),
705
706
707
708
709
710
        ("fnet", "FNetForTokenClassification"),
        ("funnel", "FunnelForTokenClassification"),
        ("gpt2", "GPT2ForTokenClassification"),
        ("ibert", "IBertForTokenClassification"),
        ("layoutlm", "LayoutLMForTokenClassification"),
        ("layoutlmv2", "LayoutLMv2ForTokenClassification"),
NielsRogge's avatar
NielsRogge committed
711
        ("layoutlmv3", "LayoutLMv3ForTokenClassification"),
NielsRogge's avatar
NielsRogge committed
712
        ("lilt", "LiltForTokenClassification"),
713
        ("longformer", "LongformerForTokenClassification"),
714
        ("luke", "LukeForTokenClassification"),
NielsRogge's avatar
NielsRogge committed
715
        ("markuplm", "MarkupLMForTokenClassification"),
716
717
718
        ("megatron-bert", "MegatronBertForTokenClassification"),
        ("mobilebert", "MobileBertForTokenClassification"),
        ("mpnet", "MPNetForTokenClassification"),
719
        ("nezha", "NezhaForTokenClassification"),
720
721
722
723
        ("nystromformer", "NystromformerForTokenClassification"),
        ("qdqbert", "QDQBertForTokenClassification"),
        ("rembert", "RemBertForTokenClassification"),
        ("roberta", "RobertaForTokenClassification"),
Weiwe Shi's avatar
Weiwe Shi committed
724
        ("roc_bert", "RoCBertForTokenClassification"),
725
726
727
728
729
730
731
        ("roformer", "RoFormerForTokenClassification"),
        ("squeezebert", "SqueezeBertForTokenClassification"),
        ("xlm", "XLMForTokenClassification"),
        ("xlm-roberta", "XLMRobertaForTokenClassification"),
        ("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"),
        ("xlnet", "XLNetForTokenClassification"),
        ("yoso", "YosoForTokenClassification"),
Julien Chaumond's avatar
Julien Chaumond committed
732
733
734
    ]
)

735
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
Julien Chaumond's avatar
Julien Chaumond committed
736
    [
737
        # Model for Multiple Choice mapping
738
739
        ("albert", "AlbertForMultipleChoice"),
        ("bert", "BertForMultipleChoice"),
740
741
        ("big_bird", "BigBirdForMultipleChoice"),
        ("camembert", "CamembertForMultipleChoice"),
742
743
        ("canine", "CanineForMultipleChoice"),
        ("convbert", "ConvBertForMultipleChoice"),
744
        ("data2vec-text", "Data2VecTextForMultipleChoice"),
745
        ("deberta-v2", "DebertaV2ForMultipleChoice"),
746
        ("distilbert", "DistilBertForMultipleChoice"),
747
        ("electra", "ElectraForMultipleChoice"),
748
        ("ernie", "ErnieForMultipleChoice"),
749
        ("flaubert", "FlaubertForMultipleChoice"),
750
        ("fnet", "FNetForMultipleChoice"),
751
752
        ("funnel", "FunnelForMultipleChoice"),
        ("ibert", "IBertForMultipleChoice"),
753
        ("longformer", "LongformerForMultipleChoice"),
754
        ("luke", "LukeForMultipleChoice"),
755
756
757
        ("megatron-bert", "MegatronBertForMultipleChoice"),
        ("mobilebert", "MobileBertForMultipleChoice"),
        ("mpnet", "MPNetForMultipleChoice"),
758
        ("nezha", "NezhaForMultipleChoice"),
759
760
761
762
        ("nystromformer", "NystromformerForMultipleChoice"),
        ("qdqbert", "QDQBertForMultipleChoice"),
        ("rembert", "RemBertForMultipleChoice"),
        ("roberta", "RobertaForMultipleChoice"),
Weiwe Shi's avatar
Weiwe Shi committed
763
        ("roc_bert", "RoCBertForMultipleChoice"),
764
765
766
767
768
769
770
        ("roformer", "RoFormerForMultipleChoice"),
        ("squeezebert", "SqueezeBertForMultipleChoice"),
        ("xlm", "XLMForMultipleChoice"),
        ("xlm-roberta", "XLMRobertaForMultipleChoice"),
        ("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"),
        ("xlnet", "XLNetForMultipleChoice"),
        ("yoso", "YosoForMultipleChoice"),
Julien Chaumond's avatar
Julien Chaumond committed
771
772
773
    ]
)

774
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
775
    [
776
        ("bert", "BertForNextSentencePrediction"),
777
        ("ernie", "ErnieForNextSentencePrediction"),
Gunjan Chhablani's avatar
Gunjan Chhablani committed
778
        ("fnet", "FNetForNextSentencePrediction"),
779
780
        ("megatron-bert", "MegatronBertForNextSentencePrediction"),
        ("mobilebert", "MobileBertForNextSentencePrediction"),
781
        ("nezha", "NezhaForNextSentencePrediction"),
782
        ("qdqbert", "QDQBertForNextSentencePrediction"),
783
784
785
    ]
)

786
787
788
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio Classification mapping
789
        ("audio-spectrogram-transformer", "ASTForAudioClassification"),
790
        ("data2vec-audio", "Data2VecAudioForSequenceClassification"),
791
        ("hubert", "HubertForSequenceClassification"),
792
793
        ("sew", "SEWForSequenceClassification"),
        ("sew-d", "SEWDForSequenceClassification"),
794
795
796
        ("unispeech", "UniSpeechForSequenceClassification"),
        ("unispeech-sat", "UniSpeechSatForSequenceClassification"),
        ("wav2vec2", "Wav2Vec2ForSequenceClassification"),
797
        ("wav2vec2-conformer", "Wav2Vec2ConformerForSequenceClassification"),
Patrick von Platen's avatar
Patrick von Platen committed
798
        ("wavlm", "WavLMForSequenceClassification"),
799
800
801
    ]
)

802
803
804
MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
    [
        # Model for Connectionist temporal classification (CTC) mapping
805
        ("data2vec-audio", "Data2VecAudioForCTC"),
806
        ("hubert", "HubertForCTC"),
Chan Woo Kim's avatar
Chan Woo Kim committed
807
        ("mctct", "MCTCTForCTC"),
808
809
        ("sew", "SEWForCTC"),
        ("sew-d", "SEWDForCTC"),
810
811
812
        ("unispeech", "UniSpeechForCTC"),
        ("unispeech-sat", "UniSpeechSatForCTC"),
        ("wav2vec2", "Wav2Vec2ForCTC"),
813
        ("wav2vec2-conformer", "Wav2Vec2ConformerForCTC"),
Patrick von Platen's avatar
Patrick von Platen committed
814
        ("wavlm", "WavLMForCTC"),
815
816
817
    ]
)

818
819
820
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio Frame Classification mapping
821
        ("data2vec-audio", "Data2VecAudioForAudioFrameClassification"),
822
        ("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
823
        ("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
824
        ("wav2vec2-conformer", "Wav2Vec2ConformerForAudioFrameClassification"),
825
        ("wavlm", "WavLMForAudioFrameClassification"),
826
827
828
829
830
831
    ]
)

MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio XVector mapping
832
        ("data2vec-audio", "Data2VecAudioForXVector"),
833
        ("unispeech-sat", "UniSpeechSatForXVector"),
834
        ("wav2vec2", "Wav2Vec2ForXVector"),
835
        ("wav2vec2-conformer", "Wav2Vec2ConformerForXVector"),
836
        ("wavlm", "WavLMForXVector"),
837
838
839
    ]
)

840
841
842
843
_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Zero Shot Image Classification mapping
        ("clip", "CLIPModel"),
NielsRogge's avatar
NielsRogge committed
844
        ("clipseg", "CLIPSegModel"),
845
846
847
    ]
)

848
849
850
MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
    [
        # Backbone mapping
851
        ("maskformer-swin", "MaskFormerSwinBackbone"),
852
853
854
855
        ("resnet", "ResNetBackbone"),
    ]
)

856
857
858
859
# Each `*_MAPPING` below pairs `CONFIG_MAPPING_NAMES` with the matching `*_MAPPING_NAMES` table of
# class-name strings through `_LazyAutoMapping`.
# NOTE(review): presumably the model classes are resolved lazily on lookup, as the helper's name
# suggests — confirm in `auto_factory`.
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
NielsRogge's avatar
NielsRogge committed
860
861
862
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES
)
863
864
865
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
866
867
868
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
)
869
870
871
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
)
872
873
874
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES
)
NielsRogge's avatar
NielsRogge committed
875
876
877
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES
)
878
MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
879
880
881
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES
)
882
883
884
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
)
885
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
NielsRogge's avatar
NielsRogge committed
886
887
888
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
)
889
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
890
891
892
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
)
893
MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES)
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
# Text-task mappings: every entry combines `CONFIG_MAPPING_NAMES` with the task-specific
# `*_MAPPING_NAMES` table; the resulting objects are what the auto classes store in `_model_mapping`.
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES)
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
913
914
915
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
916
917
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
918
919
920
921
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES)
922

923
924
MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES)

925

Sylvain Gugger's avatar
Sylvain Gugger committed
926
927
928
929
930
931
932
933
934
935
936
937
# Auto class for the base models; its lookup table is `MODEL_MAPPING`.
class AutoModel(_BaseAutoModelClass):
    _model_mapping = MODEL_MAPPING


# Rebind the name to whatever `auto_class_update` returns; no `head_doc` is passed for the base class.
AutoModel = auto_class_update(AutoModel)


# Auto class whose lookup table is `MODEL_FOR_PRETRAINING_MAPPING`.
class AutoModelForPreTraining(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_PRETRAINING_MAPPING


# "pretraining" is the `head_doc` label handed to `auto_class_update`.
AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining")
938

thomwolf's avatar
thomwolf committed
939

940
# Private on purpose, the public class will add the deprecation warnings.
Sylvain Gugger's avatar
Sylvain Gugger committed
941
942
class _AutoModelWithLMHead(_BaseAutoModelClass):
    _model_mapping = MODEL_WITH_LM_HEAD_MAPPING
thomwolf's avatar
thomwolf committed
943
944


Sylvain Gugger's avatar
Sylvain Gugger committed
945
_AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling")
thomwolf's avatar
thomwolf committed
946
947


Sylvain Gugger's avatar
Sylvain Gugger committed
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
# Auto class whose lookup table is `MODEL_FOR_CAUSAL_LM_MAPPING`.
class AutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING


# Registered with the `head_doc` label "causal language modeling".
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")


# Auto class whose lookup table is `MODEL_FOR_MASKED_LM_MAPPING`.
class AutoModelForMaskedLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING


# Registered with the `head_doc` label "masked language modeling".
AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling")


class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


AutoModelForSeq2SeqLM = auto_class_update(
    AutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
968
)
thomwolf's avatar
thomwolf committed
969

Sylvain Gugger's avatar
Sylvain Gugger committed
970
971
972
973
974
975
976

class AutoModelForSequenceClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


AutoModelForSequenceClassification = auto_class_update(
    AutoModelForSequenceClassification, head_doc="sequence classification"
977
)
thomwolf's avatar
thomwolf committed
978

Sylvain Gugger's avatar
Sylvain Gugger committed
979
980
981
982
983
984
985
986
987
988
989
990
991
992

# Auto class for question answering; its lookup table is `MODEL_FOR_QUESTION_ANSWERING_MAPPING`, which is
# built from `MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES`.
class AutoModelForQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING


# Registered with the `head_doc` label "question answering".
AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering")


class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


AutoModelForTableQuestionAnswering = auto_class_update(
    AutoModelForTableQuestionAnswering,
993
994
995
    head_doc="table question answering",
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
)
thomwolf's avatar
thomwolf committed
996
997


998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
# Auto class for visual question answering; its lookup table is
# `MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING`.
class AutoModelForVisualQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING


# `checkpoint_for_example` overrides the default with a ViLT VQA checkpoint.
# NOTE(review): presumably this is the checkpoint shown in the generated usage example — confirm in
# `auto_class_update`.
AutoModelForVisualQuestionAnswering = auto_class_update(
    AutoModelForVisualQuestionAnswering,
    head_doc="visual question answering",
    checkpoint_for_example="dandelin/vilt-b32-finetuned-vqa",
)


1009
1010
1011
1012
1013
1014
1015
class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING


AutoModelForDocumentQuestionAnswering = auto_class_update(
    AutoModelForDocumentQuestionAnswering,
    head_doc="document question answering",
1016
    checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
1017
1018
1019
)


Sylvain Gugger's avatar
Sylvain Gugger committed
1020
1021
class AutoModelForTokenClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
1022

1023

Sylvain Gugger's avatar
Sylvain Gugger committed
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification")


# Auto class for multiple choice; its lookup table is `MODEL_FOR_MULTIPLE_CHOICE_MAPPING`, which is built
# from `MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES`.
class AutoModelForMultipleChoice(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING


# Registered with the `head_doc` label "multiple choice".
AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice")


class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


AutoModelForNextSentencePrediction = auto_class_update(
    AutoModelForNextSentencePrediction, head_doc="next sentence prediction"
1040
)
1041

1042

Sylvain Gugger's avatar
Sylvain Gugger committed
1043
1044
1045
1046
1047
1048
1049
# Auto class whose lookup table is `MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING`.
class AutoModelForImageClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


# Registered with the `head_doc` label "image classification".
AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")


1050
1051
1052
1053
1054
1055
1056
# Auto class whose lookup table is `MODEL_FOR_IMAGE_SEGMENTATION_MAPPING`.
class AutoModelForImageSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING


# Registered with the `head_doc` label "image segmentation".
AutoModelForImageSegmentation = auto_class_update(AutoModelForImageSegmentation, head_doc="image segmentation")


1057
1058
1059
1060
1061
1062
1063
1064
1065
# Auto class whose lookup table is `MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING`.
class AutoModelForSemanticSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING


# Registered with the `head_doc` label "semantic segmentation".
AutoModelForSemanticSegmentation = auto_class_update(
    AutoModelForSemanticSegmentation, head_doc="semantic segmentation"
)


1066
1067
1068
1069
1070
1071
1072
1073
1074
# Auto class whose lookup table is `MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING`.
class AutoModelForInstanceSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING


# Registered with the `head_doc` label "instance segmentation".
AutoModelForInstanceSegmentation = auto_class_update(
    AutoModelForInstanceSegmentation, head_doc="instance segmentation"
)


1075
1076
1077
1078
1079
1080
1081
# Auto class whose lookup table is `MODEL_FOR_OBJECT_DETECTION_MAPPING`.
class AutoModelForObjectDetection(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING


# Registered with the `head_doc` label "object detection".
AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")


1082
1083
1084
1085
1086
1087
1088
1089
1090
# Auto class whose lookup table is `MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING`.
class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING


# Registered with the `head_doc` label "zero-shot object detection".
AutoModelForZeroShotObjectDetection = auto_class_update(
    AutoModelForZeroShotObjectDetection, head_doc="zero-shot object detection"
)


1091
1092
1093
1094
1095
1096
1097
# Auto class whose lookup table is `MODEL_FOR_DEPTH_ESTIMATION_MAPPING`.
class AutoModelForDepthEstimation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING


# Registered with the `head_doc` label "depth estimation".
AutoModelForDepthEstimation = auto_class_update(AutoModelForDepthEstimation, head_doc="depth estimation")


NielsRogge's avatar
NielsRogge committed
1098
1099
1100
1101
1102
1103
1104
# Auto class whose lookup table is `MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING`.
class AutoModelForVideoClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING


# Registered with the `head_doc` label "video classification".
AutoModelForVideoClassification = auto_class_update(AutoModelForVideoClassification, head_doc="video classification")


1105
1106
1107
1108
1109
1110
1111
# Auto class whose lookup table is `MODEL_FOR_VISION_2_SEQ_MAPPING`.
class AutoModelForVision2Seq(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING


# Registered with the `head_doc` label "vision-to-text modeling".
AutoModelForVision2Seq = auto_class_update(AutoModelForVision2Seq, head_doc="vision-to-text modeling")


1112
1113
1114
1115
1116
1117
1118
# Auto class for audio classification; its lookup table is `MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING`, which is
# built from `MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES`.
class AutoModelForAudioClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


# Registered with the `head_doc` label "audio classification".
AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification")


1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
# Auto class for CTC models; its lookup table is `MODEL_FOR_CTC_MAPPING`, which is built from
# `MODEL_FOR_CTC_MAPPING_NAMES`.
class AutoModelForCTC(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CTC_MAPPING


# Registered with the `head_doc` label "connectionist temporal classification".
AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification")


class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


AutoModelForSpeechSeq2Seq = auto_class_update(
Joao Gante's avatar
Joao Gante committed
1131
    AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
1132
1133
1134
)


1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
# Auto class for audio frame classification; its lookup table is
# `MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING`.
class AutoModelForAudioFrameClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING


# Registered with the `head_doc` label "audio frame (token) classification".
AutoModelForAudioFrameClassification = auto_class_update(
    AutoModelForAudioFrameClassification, head_doc="audio frame (token) classification"
)


# Auto class whose lookup table is `MODEL_FOR_AUDIO_XVECTOR_MAPPING`.
# NOTE: the matching `auto_class_update` call is not directly below; it comes after the `AutoBackbone` class.
class AutoModelForAudioXVector(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING


1148
1149
1150
1151
# Auto class whose lookup table is `MODEL_FOR_BACKBONE_MAPPING`.
# NOTE(review): unlike the neighbouring auto classes, no `auto_class_update(AutoBackbone, ...)` call follows
# this class in the visible code — confirm whether that is intentional.
class AutoBackbone(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_BACKBONE_MAPPING


1152
1153
1154
# Belongs to the `AutoModelForAudioXVector` class, which is defined above `AutoBackbone`.
AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector")


NielsRogge's avatar
NielsRogge committed
1155
1156
1157
1158
1159
1160
1161
# Auto class whose lookup table is `MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING`.
class AutoModelForMaskedImageModeling(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING


# Registered with the `head_doc` label "masked image modeling".
AutoModelForMaskedImageModeling = auto_class_update(AutoModelForMaskedImageModeling, head_doc="masked image modeling")


1162
class AutoModelWithLMHead(_AutoModelWithLMHead):
    # Deprecated public entry point. Both constructors emit a `FutureWarning` and then defer to the
    # private `_AutoModelWithLMHead` implementation, so behaviour is otherwise unchanged.

    # Single copy of the deprecation text so the two constructors cannot drift apart.
    _DEPRECATION_MESSAGE = (
        "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
        "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
        "`AutoModelForSeq2SeqLM` for encoder-decoder models."
    )

    @classmethod
    def from_config(cls, config):
        # Warn, then build the model from `config` exactly as the parent class does.
        # stacklevel=2 attributes the warning to the caller's line instead of this module.
        warnings.warn(cls._DEPRECATION_MESSAGE, FutureWarning, stacklevel=2)
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # Warn, then forward every argument untouched to the parent's `from_pretrained`.
        # stacklevel=2 attributes the warning to the caller's line instead of this module.
        warnings.warn(cls._DEPRECATION_MESSAGE, FutureWarning, stacklevel=2)
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)