modeling_auto.py 46.2 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Sylvain Gugger's avatar
Sylvain Gugger committed
15
""" Auto Model class."""
thomwolf's avatar
thomwolf committed
16

17
import warnings
Julien Chaumond's avatar
Julien Chaumond committed
18
from collections import OrderedDict
thomwolf's avatar
thomwolf committed
19

Sylvain Gugger's avatar
Sylvain Gugger committed
20
from ...utils import logging
21
22
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES
thomwolf's avatar
thomwolf committed
23

thomwolf's avatar
thomwolf committed
24

Lysandre Debut's avatar
Lysandre Debut committed
25
# Module-level logger created through the package's own `logging` helper (imported from `...utils`),
# not the standard-library `logging` module.
logger = logging.get_logger(__name__)
thomwolf's avatar
thomwolf committed
26
27


28
# Maps each model identifier to the *name* of its base model class (a string, not the class itself).
# A few identifiers (e.g. "funnel") map to a tuple of several class names. Entries are kept sorted by key.
MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("albert", "AlbertModel"),
        ("bart", "BartModel"),
        ("beit", "BeitModel"),
        ("bert", "BertModel"),
        ("bert-generation", "BertGenerationEncoder"),
        ("big_bird", "BigBirdModel"),
        ("bigbird_pegasus", "BigBirdPegasusModel"),
        ("blenderbot", "BlenderbotModel"),
        ("blenderbot-small", "BlenderbotSmallModel"),
        ("bloom", "BloomModel"),
        ("camembert", "CamembertModel"),
        ("canine", "CanineModel"),
        ("clip", "CLIPModel"),
        ("clipseg", "CLIPSegModel"),
        ("codegen", "CodeGenModel"),
        ("conditional_detr", "ConditionalDetrModel"),
        ("convbert", "ConvBertModel"),
        ("convnext", "ConvNextModel"),
        ("ctrl", "CTRLModel"),
        ("cvt", "CvtModel"),
        ("data2vec-audio", "Data2VecAudioModel"),
        ("data2vec-text", "Data2VecTextModel"),
        ("data2vec-vision", "Data2VecVisionModel"),
        ("deberta", "DebertaModel"),
        ("deberta-v2", "DebertaV2Model"),
        ("decision_transformer", "DecisionTransformerModel"),
        ("decision_transformer_gpt2", "DecisionTransformerGPT2Model"),
        ("deformable_detr", "DeformableDetrModel"),
        ("deit", "DeiTModel"),
        ("detr", "DetrModel"),
        ("dinat", "DinatModel"),
        ("distilbert", "DistilBertModel"),
        ("donut-swin", "DonutSwinModel"),
        ("dpr", "DPRQuestionEncoder"),
        ("dpt", "DPTModel"),
        ("electra", "ElectraModel"),
        ("ernie", "ErnieModel"),
        ("esm", "EsmModel"),
        ("flaubert", "FlaubertModel"),
        ("flava", "FlavaModel"),
        ("fnet", "FNetModel"),
        ("fsmt", "FSMTModel"),
        ("funnel", ("FunnelModel", "FunnelBaseModel")),
        ("glpn", "GLPNModel"),
        ("gpt2", "GPT2Model"),
        ("gpt_neo", "GPTNeoModel"),
        ("gpt_neox", "GPTNeoXModel"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseModel"),
        ("gptj", "GPTJModel"),
        ("groupvit", "GroupViTModel"),
        ("hubert", "HubertModel"),
        ("ibert", "IBertModel"),
        ("imagegpt", "ImageGPTModel"),
        ("jukebox", "JukeboxModel"),
        ("layoutlm", "LayoutLMModel"),
        ("layoutlmv2", "LayoutLMv2Model"),
        ("layoutlmv3", "LayoutLMv3Model"),
        ("led", "LEDModel"),
        ("levit", "LevitModel"),
        ("lilt", "LiltModel"),
        ("longformer", "LongformerModel"),
        ("longt5", "LongT5Model"),
        ("luke", "LukeModel"),
        ("lxmert", "LxmertModel"),
        ("m2m_100", "M2M100Model"),
        ("marian", "MarianModel"),
        ("markuplm", "MarkupLMModel"),
        ("maskformer", "MaskFormerModel"),
        ("mbart", "MBartModel"),
        ("mctct", "MCTCTModel"),
        ("megatron-bert", "MegatronBertModel"),
        ("mobilebert", "MobileBertModel"),
        ("mobilenet_v1", "MobileNetV1Model"),
        ("mobilenet_v2", "MobileNetV2Model"),
        ("mobilevit", "MobileViTModel"),
        ("mpnet", "MPNetModel"),
        ("mt5", "MT5Model"),
        ("mvp", "MvpModel"),
        ("nat", "NatModel"),
        ("nezha", "NezhaModel"),
        # NLLB reuses the M2M100 architecture class.
        ("nllb", "M2M100Model"),
        ("nystromformer", "NystromformerModel"),
        ("openai-gpt", "OpenAIGPTModel"),
        ("opt", "OPTModel"),
        ("owlvit", "OwlViTModel"),
        ("pegasus", "PegasusModel"),
        ("pegasus_x", "PegasusXModel"),
        ("perceiver", "PerceiverModel"),
        ("plbart", "PLBartModel"),
        ("poolformer", "PoolFormerModel"),
        ("prophetnet", "ProphetNetModel"),
        ("qdqbert", "QDQBertModel"),
        ("reformer", "ReformerModel"),
        ("regnet", "RegNetModel"),
        ("rembert", "RemBertModel"),
        ("resnet", "ResNetModel"),
        ("retribert", "RetriBertModel"),
        ("roberta", "RobertaModel"),
        ("roc_bert", "RoCBertModel"),
        ("roformer", "RoFormerModel"),
        ("segformer", "SegformerModel"),
        ("sew", "SEWModel"),
        ("sew-d", "SEWDModel"),
        ("speech_to_text", "Speech2TextModel"),
        ("splinter", "SplinterModel"),
        ("squeezebert", "SqueezeBertModel"),
        ("swin", "SwinModel"),
        ("swinv2", "Swinv2Model"),
        ("switch_transformers", "SwitchTransformersModel"),
        ("t5", "T5Model"),
        ("table-transformer", "TableTransformerModel"),
        ("tapas", "TapasModel"),
        ("time_series_transformer", "TimeSeriesTransformerModel"),
        ("trajectory_transformer", "TrajectoryTransformerModel"),
        ("transfo-xl", "TransfoXLModel"),
        ("unispeech", "UniSpeechModel"),
        ("unispeech-sat", "UniSpeechSatModel"),
        ("van", "VanModel"),
        ("videomae", "VideoMAEModel"),
        ("vilt", "ViltModel"),
        ("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
        ("visual_bert", "VisualBertModel"),
        ("vit", "ViTModel"),
        ("vit_mae", "ViTMAEModel"),
        ("vit_msn", "ViTMSNModel"),
        ("wav2vec2", "Wav2Vec2Model"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerModel"),
        ("wavlm", "WavLMModel"),
        ("whisper", "WhisperModel"),
        ("xclip", "XCLIPModel"),
        ("xglm", "XGLMModel"),
        ("xlm", "XLMModel"),
        ("xlm-prophetnet", "XLMProphetNetModel"),
        ("xlm-roberta", "XLMRobertaModel"),
        ("xlm-roberta-xl", "XLMRobertaXLModel"),
        ("xlnet", "XLNetModel"),
        ("yolos", "YolosModel"),
        ("yoso", "YosoModel"),
    ]
)

172
# Model identifier -> class name used for pre-training. Several identifiers reuse their masked-LM,
# causal-LM or conditional-generation class here (e.g. "camembert", "bloom", "bart").
MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        # Model for pre-training mapping
        ("albert", "AlbertForPreTraining"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForPreTraining"),
        ("big_bird", "BigBirdForPreTraining"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForPreTraining"),
        ("ernie", "ErnieForPreTraining"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("flava", "FlavaForPreTraining"),
        ("fnet", "FNetForPreTraining"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("funnel", "FunnelForPreTraining"),
        ("gpt2", "GPT2LMHeadModel"),
        ("ibert", "IBertForMaskedLM"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("longformer", "LongformerForMaskedLM"),
        ("luke", "LukeForMaskedLM"),
        ("lxmert", "LxmertForPreTraining"),
        ("megatron-bert", "MegatronBertForPreTraining"),
        ("mobilebert", "MobileBertForPreTraining"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForPreTraining"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("retribert", "RetriBertModel"),
        ("roberta", "RobertaForMaskedLM"),
        ("roc_bert", "RoCBertForPreTraining"),
        ("splinter", "SplinterForPreTraining"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
        ("t5", "T5ForConditionalGeneration"),
        ("tapas", "TapasForMaskedLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("unispeech", "UniSpeechForPreTraining"),
        ("unispeech-sat", "UniSpeechSatForPreTraining"),
        ("videomae", "VideoMAEForPreTraining"),
        ("visual_bert", "VisualBertForPreTraining"),
        ("vit_mae", "ViTMAEForPreTraining"),
        ("wav2vec2", "Wav2Vec2ForPreTraining"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForPreTraining"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("xlnet", "XLNetLMHeadModel"),
    ]
)

228
# Model identifier -> class name of a model carrying a language-modeling head. Masked-LM, causal-LM
# and conditional-generation classes are mixed here; "encoder-decoder" maps to the generic wrapper.
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    [
        # Model with LM heads mapping
        ("albert", "AlbertForMaskedLM"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForMaskedLM"),
        ("big_bird", "BigBirdForMaskedLM"),
        ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
        ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("codegen", "CodeGenForCausalLM"),
        ("convbert", "ConvBertForMaskedLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForMaskedLM"),
        ("encoder-decoder", "EncoderDecoderModel"),
        ("ernie", "ErnieForMaskedLM"),
        ("esm", "EsmForMaskedLM"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("fnet", "FNetForMaskedLM"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("funnel", "FunnelForMaskedLM"),
        ("gpt2", "GPT2LMHeadModel"),
        ("gpt_neo", "GPTNeoForCausalLM"),
        ("gpt_neox", "GPTNeoXForCausalLM"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
        ("gptj", "GPTJForCausalLM"),
        ("ibert", "IBertForMaskedLM"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("led", "LEDForConditionalGeneration"),
        ("longformer", "LongformerForMaskedLM"),
        ("longt5", "LongT5ForConditionalGeneration"),
        ("luke", "LukeForMaskedLM"),
        ("m2m_100", "M2M100ForConditionalGeneration"),
        ("marian", "MarianMTModel"),
        ("megatron-bert", "MegatronBertForCausalLM"),
        ("mobilebert", "MobileBertForMaskedLM"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForMaskedLM"),
        ("nllb", "M2M100ForConditionalGeneration"),
        ("nystromformer", "NystromformerForMaskedLM"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("pegasus_x", "PegasusXForConditionalGeneration"),
        ("plbart", "PLBartForConditionalGeneration"),
        ("qdqbert", "QDQBertForMaskedLM"),
        ("reformer", "ReformerModelWithLMHead"),
        ("rembert", "RemBertForMaskedLM"),
        ("roberta", "RobertaForMaskedLM"),
        ("roc_bert", "RoCBertForMaskedLM"),
        ("roformer", "RoFormerForMaskedLM"),
        ("speech_to_text", "Speech2TextForConditionalGeneration"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
        ("t5", "T5ForConditionalGeneration"),
        ("tapas", "TapasForMaskedLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("wav2vec2", "Wav2Vec2ForMaskedLM"),
        ("whisper", "WhisperForConditionalGeneration"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("xlnet", "XLNetLMHeadModel"),
        ("yoso", "YosoForMaskedLM"),
    ]
)

299
# Model identifier -> class name of the causal (left-to-right) language-modeling variant.
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
        ("bart", "BartForCausalLM"),
        ("bert", "BertLMHeadModel"),
        ("bert-generation", "BertGenerationDecoder"),
        ("big_bird", "BigBirdForCausalLM"),
        ("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
        ("blenderbot", "BlenderbotForCausalLM"),
        ("blenderbot-small", "BlenderbotSmallForCausalLM"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForCausalLM"),
        ("codegen", "CodeGenForCausalLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForCausalLM"),
        ("electra", "ElectraForCausalLM"),
        ("ernie", "ErnieForCausalLM"),
        ("gpt2", "GPT2LMHeadModel"),
        ("gpt_neo", "GPTNeoForCausalLM"),
        ("gpt_neox", "GPTNeoXForCausalLM"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
        ("gptj", "GPTJForCausalLM"),
        ("marian", "MarianForCausalLM"),
        ("mbart", "MBartForCausalLM"),
        ("megatron-bert", "MegatronBertForCausalLM"),
        ("mvp", "MvpForCausalLM"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("opt", "OPTForCausalLM"),
        ("pegasus", "PegasusForCausalLM"),
        ("plbart", "PLBartForCausalLM"),
        ("prophetnet", "ProphetNetForCausalLM"),
        ("qdqbert", "QDQBertLMHeadModel"),
        ("reformer", "ReformerModelWithLMHead"),
        ("rembert", "RemBertForCausalLM"),
        ("roberta", "RobertaForCausalLM"),
        ("roc_bert", "RoCBertForCausalLM"),
        ("roformer", "RoFormerForCausalLM"),
        ("speech_to_text_2", "Speech2Text2ForCausalLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("trocr", "TrOCRForCausalLM"),
        ("xglm", "XGLMForCausalLM"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-prophetnet", "XLMProphetNetForCausalLM"),
        ("xlm-roberta", "XLMRobertaForCausalLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForCausalLM"),
        ("xlnet", "XLNetLMHeadModel"),
    ]
)

NielsRogge's avatar
NielsRogge committed
348
349
350
351
# Model identifier -> class name of the masked-image-modeling variant.
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked Image Modeling mapping
        ("deit", "DeiTForMaskedImageModeling"),
        ("swin", "SwinForMaskedImageModeling"),
        ("swinv2", "Swinv2ForMaskedImageModeling"),
        ("vit", "ViTForMaskedImageModeling"),
    ]
)


NielsRogge's avatar
NielsRogge committed
358
359
360
361
362
363
364
# Model identifier -> class name of the causal (autoregressive) image-modeling variant.
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal Image Modeling mapping
        ("imagegpt", "ImageGPTForCausalImageModeling"),
    ]
)

365
# Model identifier -> image-classification class name. "deit", "levit" and "perceiver" map to a
# tuple of several class names instead of a single string.
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image Classification mapping
        ("beit", "BeitForImageClassification"),
        ("convnext", "ConvNextForImageClassification"),
        ("cvt", "CvtForImageClassification"),
        ("data2vec-vision", "Data2VecVisionForImageClassification"),
        ("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")),
        ("dinat", "DinatForImageClassification"),
        ("imagegpt", "ImageGPTForImageClassification"),
        ("levit", ("LevitForImageClassification", "LevitForImageClassificationWithTeacher")),
        ("mobilenet_v1", "MobileNetV1ForImageClassification"),
        ("mobilenet_v2", "MobileNetV2ForImageClassification"),
        ("mobilevit", "MobileViTForImageClassification"),
        ("nat", "NatForImageClassification"),
        (
            "perceiver",
            (
                "PerceiverForImageClassificationLearned",
                "PerceiverForImageClassificationFourier",
                "PerceiverForImageClassificationConvProcessing",
            ),
        ),
        ("poolformer", "PoolFormerForImageClassification"),
        ("regnet", "RegNetForImageClassification"),
        ("resnet", "ResNetForImageClassification"),
        ("segformer", "SegformerForImageClassification"),
        ("swin", "SwinForImageClassification"),
        ("swinv2", "Swinv2ForImageClassification"),
        ("van", "VanForImageClassification"),
        ("vit", "ViTForImageClassification"),
        ("vit_msn", "ViTMSNForImageClassification"),
    ]
)

400
401
# Model identifier -> generic image-segmentation class name (frozen; see the deprecation note below).
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Do not add new models here, this class will be deprecated in the future.
        # Model for Image Segmentation mapping
        ("detr", "DetrForSegmentation"),
    ]
)

408
409
410
411
# Model identifier -> semantic-segmentation class name.
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Semantic Segmentation mapping
        ("beit", "BeitForSemanticSegmentation"),
        ("data2vec-vision", "Data2VecVisionForSemanticSegmentation"),
        ("dpt", "DPTForSemanticSegmentation"),
        ("mobilenet_v2", "MobileNetV2ForSemanticSegmentation"),
        ("mobilevit", "MobileViTForSemanticSegmentation"),
        ("segformer", "SegformerForSemanticSegmentation"),
    ]
)

420
421
422
423
424
425
426
# Model for Instance Segmentation mapping: model identifier -> instance-segmentation class name.
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    (
        ("maskformer", "MaskFormerForInstanceSegmentation"),
    )
)

NielsRogge's avatar
NielsRogge committed
427
428
429
430
431
432
# Model for Video Classification mapping: model identifier -> video-classification class name.
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    (
        ("videomae", "VideoMAEForVideoClassification"),
    )
)

433
434
435
436
437
438
# Model for Vision-to-Sequence mapping: model identifier -> vision-to-text class name.
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    (
        ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
    )
)

439
# Model identifier -> masked-language-modeling class name. Some seq2seq identifiers ("bart", "mbart",
# "mvp") map to their conditional-generation class instead of a dedicated *ForMaskedLM class.
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked LM mapping
        ("albert", "AlbertForMaskedLM"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForMaskedLM"),
        ("big_bird", "BigBirdForMaskedLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("convbert", "ConvBertForMaskedLM"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForMaskedLM"),
        ("ernie", "ErnieForMaskedLM"),
        ("esm", "EsmForMaskedLM"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("fnet", "FNetForMaskedLM"),
        ("funnel", "FunnelForMaskedLM"),
        ("ibert", "IBertForMaskedLM"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("longformer", "LongformerForMaskedLM"),
        ("luke", "LukeForMaskedLM"),
        ("mbart", "MBartForConditionalGeneration"),
        ("megatron-bert", "MegatronBertForMaskedLM"),
        ("mobilebert", "MobileBertForMaskedLM"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForMaskedLM"),
        ("nystromformer", "NystromformerForMaskedLM"),
        ("perceiver", "PerceiverForMaskedLM"),
        ("qdqbert", "QDQBertForMaskedLM"),
        ("reformer", "ReformerForMaskedLM"),
        ("rembert", "RemBertForMaskedLM"),
        ("roberta", "RobertaForMaskedLM"),
        ("roc_bert", "RoCBertForMaskedLM"),
        ("roformer", "RoFormerForMaskedLM"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("tapas", "TapasForMaskedLM"),
        ("wav2vec2", "Wav2Vec2ForMaskedLM"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("yoso", "YosoForMaskedLM"),
    ]
)

486
# Model identifier -> object-detection class name.
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Object Detection mapping
        ("conditional_detr", "ConditionalDetrForObjectDetection"),
        ("deformable_detr", "DeformableDetrForObjectDetection"),
        ("detr", "DetrForObjectDetection"),
        ("table-transformer", "TableTransformerForObjectDetection"),
        ("yolos", "YolosForObjectDetection"),
    ]
)

497
498
499
500
501
502
503
# Model identifier -> zero-shot (text-conditioned) object-detection class name.
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Zero Shot Object Detection mapping
        # Trailing comma kept like every other entry so the next addition is a one-line diff.
        ("owlvit", "OwlViTForObjectDetection"),
    ]
)

504
505
506
507
508
509
510
# Model for depth estimation mapping: model identifier -> depth-estimation class name.
MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict(
    (
        ("dpt", "DPTForDepthEstimation"),
        ("glpn", "GLPNForDepthEstimation"),
    )
)
511
# Model identifier -> sequence-to-sequence generation class name; "encoder-decoder" maps to the
# generic encoder-decoder wrapper.
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        ("bart", "BartForConditionalGeneration"),
        ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
        ("blenderbot", "BlenderbotForConditionalGeneration"),
        ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
        ("encoder-decoder", "EncoderDecoderModel"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("led", "LEDForConditionalGeneration"),
        ("longt5", "LongT5ForConditionalGeneration"),
        ("m2m_100", "M2M100ForConditionalGeneration"),
        ("marian", "MarianMTModel"),
        ("mbart", "MBartForConditionalGeneration"),
        ("mt5", "MT5ForConditionalGeneration"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nllb", "M2M100ForConditionalGeneration"),
        ("pegasus", "PegasusForConditionalGeneration"),
        ("pegasus_x", "PegasusXForConditionalGeneration"),
        ("plbart", "PLBartForConditionalGeneration"),
        ("prophetnet", "ProphetNetForConditionalGeneration"),
        ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
        ("t5", "T5ForConditionalGeneration"),
        ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"),
    ]
)

538
539
540
541
# Model identifier -> speech-to-text sequence-to-sequence class name.
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        # Model for Speech Seq2Seq mapping
        ("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
        ("speech_to_text", "Speech2TextForConditionalGeneration"),
        ("whisper", "WhisperForConditionalGeneration"),
    ]
)

546
# Model identifier -> sequence-classification class name.
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Sequence Classification mapping
        ("albert", "AlbertForSequenceClassification"),
        ("bart", "BartForSequenceClassification"),
        ("bert", "BertForSequenceClassification"),
        ("big_bird", "BigBirdForSequenceClassification"),
        ("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
        ("bloom", "BloomForSequenceClassification"),
        ("camembert", "CamembertForSequenceClassification"),
        ("canine", "CanineForSequenceClassification"),
        ("convbert", "ConvBertForSequenceClassification"),
        ("ctrl", "CTRLForSequenceClassification"),
        ("data2vec-text", "Data2VecTextForSequenceClassification"),
        ("deberta", "DebertaForSequenceClassification"),
        ("deberta-v2", "DebertaV2ForSequenceClassification"),
        ("distilbert", "DistilBertForSequenceClassification"),
        ("electra", "ElectraForSequenceClassification"),
        ("ernie", "ErnieForSequenceClassification"),
        ("esm", "EsmForSequenceClassification"),
        ("flaubert", "FlaubertForSequenceClassification"),
        ("fnet", "FNetForSequenceClassification"),
        ("funnel", "FunnelForSequenceClassification"),
        ("gpt2", "GPT2ForSequenceClassification"),
        ("gpt_neo", "GPTNeoForSequenceClassification"),
        ("gptj", "GPTJForSequenceClassification"),
        ("ibert", "IBertForSequenceClassification"),
        ("layoutlm", "LayoutLMForSequenceClassification"),
        ("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
        ("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
        ("led", "LEDForSequenceClassification"),
        ("lilt", "LiltForSequenceClassification"),
        ("longformer", "LongformerForSequenceClassification"),
        ("luke", "LukeForSequenceClassification"),
        ("markuplm", "MarkupLMForSequenceClassification"),
        ("mbart", "MBartForSequenceClassification"),
        ("megatron-bert", "MegatronBertForSequenceClassification"),
        ("mobilebert", "MobileBertForSequenceClassification"),
        ("mpnet", "MPNetForSequenceClassification"),
        ("mvp", "MvpForSequenceClassification"),
        ("nezha", "NezhaForSequenceClassification"),
        ("nystromformer", "NystromformerForSequenceClassification"),
        ("openai-gpt", "OpenAIGPTForSequenceClassification"),
        ("opt", "OPTForSequenceClassification"),
        ("perceiver", "PerceiverForSequenceClassification"),
        ("plbart", "PLBartForSequenceClassification"),
        ("qdqbert", "QDQBertForSequenceClassification"),
        ("reformer", "ReformerForSequenceClassification"),
        ("rembert", "RemBertForSequenceClassification"),
        ("roberta", "RobertaForSequenceClassification"),
        ("roc_bert", "RoCBertForSequenceClassification"),
        ("roformer", "RoFormerForSequenceClassification"),
        ("squeezebert", "SqueezeBertForSequenceClassification"),
        ("tapas", "TapasForSequenceClassification"),
        ("transfo-xl", "TransfoXLForSequenceClassification"),
        ("xlm", "XLMForSequenceClassification"),
        ("xlm-roberta", "XLMRobertaForSequenceClassification"),
        ("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"),
        ("xlnet", "XLNetForSequenceClassification"),
        ("yoso", "YosoForSequenceClassification"),
    ]
)

609
# Model identifier -> extractive question-answering class name ("flaubert", "xlm" and "xlnet" use
# their *ForQuestionAnsweringSimple variant).
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Question Answering mapping
        ("albert", "AlbertForQuestionAnswering"),
        ("bart", "BartForQuestionAnswering"),
        ("bert", "BertForQuestionAnswering"),
        ("big_bird", "BigBirdForQuestionAnswering"),
        ("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"),
        ("bloom", "BloomForQuestionAnswering"),
        ("camembert", "CamembertForQuestionAnswering"),
        ("canine", "CanineForQuestionAnswering"),
        ("convbert", "ConvBertForQuestionAnswering"),
        ("data2vec-text", "Data2VecTextForQuestionAnswering"),
        ("deberta", "DebertaForQuestionAnswering"),
        ("deberta-v2", "DebertaV2ForQuestionAnswering"),
        ("distilbert", "DistilBertForQuestionAnswering"),
        ("electra", "ElectraForQuestionAnswering"),
        ("ernie", "ErnieForQuestionAnswering"),
        ("flaubert", "FlaubertForQuestionAnsweringSimple"),
        ("fnet", "FNetForQuestionAnswering"),
        ("funnel", "FunnelForQuestionAnswering"),
        ("gptj", "GPTJForQuestionAnswering"),
        ("ibert", "IBertForQuestionAnswering"),
        ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
        ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
        ("led", "LEDForQuestionAnswering"),
        ("lilt", "LiltForQuestionAnswering"),
        ("longformer", "LongformerForQuestionAnswering"),
        ("luke", "LukeForQuestionAnswering"),
        ("lxmert", "LxmertForQuestionAnswering"),
        ("markuplm", "MarkupLMForQuestionAnswering"),
        ("mbart", "MBartForQuestionAnswering"),
        ("megatron-bert", "MegatronBertForQuestionAnswering"),
        ("mobilebert", "MobileBertForQuestionAnswering"),
        ("mpnet", "MPNetForQuestionAnswering"),
        ("mvp", "MvpForQuestionAnswering"),
        ("nezha", "NezhaForQuestionAnswering"),
        ("nystromformer", "NystromformerForQuestionAnswering"),
        ("opt", "OPTForQuestionAnswering"),
        ("qdqbert", "QDQBertForQuestionAnswering"),
        ("reformer", "ReformerForQuestionAnswering"),
        ("rembert", "RemBertForQuestionAnswering"),
        ("roberta", "RobertaForQuestionAnswering"),
        ("roc_bert", "RoCBertForQuestionAnswering"),
        ("roformer", "RoFormerForQuestionAnswering"),
        ("splinter", "SplinterForQuestionAnswering"),
        ("squeezebert", "SqueezeBertForQuestionAnswering"),
        ("xlm", "XLMForQuestionAnsweringSimple"),
        ("xlm-roberta", "XLMRobertaForQuestionAnswering"),
        ("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"),
        ("xlnet", "XLNetForQuestionAnsweringSimple"),
        ("yoso", "YosoForQuestionAnswering"),
    ]
)

664
# Model type -> table question-answering head class. Wrapped as MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
# and used by AutoModelForTableQuestionAnswering.
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Table Question Answering mapping
        ("tapas", "TapasForQuestionAnswering"),
    ]
)

# Model type -> visual question-answering head class. Wrapped as MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING
# and used by AutoModelForVisualQuestionAnswering.
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Visual Question Answering mapping
        ("vilt", "ViltForQuestionAnswering"),
    ]
)

# Model type -> document question-answering head class. Wrapped as
# MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING and used by AutoModelForDocumentQuestionAnswering.
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Document Question Answering mapping
        ("layoutlm", "LayoutLMForQuestionAnswering"),
        ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
        ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
    ]
)

# Model type -> token-classification head class. Wrapped as MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING and used
# by AutoModelForTokenClassification. Entries are kept in alphabetical order of model type.
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
        ("albert", "AlbertForTokenClassification"),
        ("bert", "BertForTokenClassification"),
        ("big_bird", "BigBirdForTokenClassification"),
        ("bloom", "BloomForTokenClassification"),
        ("camembert", "CamembertForTokenClassification"),
        ("canine", "CanineForTokenClassification"),
        ("convbert", "ConvBertForTokenClassification"),
        ("data2vec-text", "Data2VecTextForTokenClassification"),
        ("deberta", "DebertaForTokenClassification"),
        ("deberta-v2", "DebertaV2ForTokenClassification"),
        ("distilbert", "DistilBertForTokenClassification"),
        ("electra", "ElectraForTokenClassification"),
        ("ernie", "ErnieForTokenClassification"),
        ("esm", "EsmForTokenClassification"),
        ("flaubert", "FlaubertForTokenClassification"),
        ("fnet", "FNetForTokenClassification"),
        ("funnel", "FunnelForTokenClassification"),
        ("gpt2", "GPT2ForTokenClassification"),
        ("ibert", "IBertForTokenClassification"),
        ("layoutlm", "LayoutLMForTokenClassification"),
        ("layoutlmv2", "LayoutLMv2ForTokenClassification"),
        ("layoutlmv3", "LayoutLMv3ForTokenClassification"),
        ("lilt", "LiltForTokenClassification"),
        ("longformer", "LongformerForTokenClassification"),
        ("luke", "LukeForTokenClassification"),
        ("markuplm", "MarkupLMForTokenClassification"),
        ("megatron-bert", "MegatronBertForTokenClassification"),
        ("mobilebert", "MobileBertForTokenClassification"),
        ("mpnet", "MPNetForTokenClassification"),
        ("nezha", "NezhaForTokenClassification"),
        ("nystromformer", "NystromformerForTokenClassification"),
        ("qdqbert", "QDQBertForTokenClassification"),
        ("rembert", "RemBertForTokenClassification"),
        ("roberta", "RobertaForTokenClassification"),
        ("roc_bert", "RoCBertForTokenClassification"),
        ("roformer", "RoFormerForTokenClassification"),
        ("squeezebert", "SqueezeBertForTokenClassification"),
        ("xlm", "XLMForTokenClassification"),
        ("xlm-roberta", "XLMRobertaForTokenClassification"),
        ("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"),
        ("xlnet", "XLNetForTokenClassification"),
        ("yoso", "YosoForTokenClassification"),
    ]
)

# Model type -> multiple-choice head class. Wrapped as MODEL_FOR_MULTIPLE_CHOICE_MAPPING and used by
# AutoModelForMultipleChoice. Entries are kept in alphabetical order of model type; note that only
# "deberta-v2" (not "deberta") has an entry here.
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Multiple Choice mapping
        ("albert", "AlbertForMultipleChoice"),
        ("bert", "BertForMultipleChoice"),
        ("big_bird", "BigBirdForMultipleChoice"),
        ("camembert", "CamembertForMultipleChoice"),
        ("canine", "CanineForMultipleChoice"),
        ("convbert", "ConvBertForMultipleChoice"),
        ("data2vec-text", "Data2VecTextForMultipleChoice"),
        ("deberta-v2", "DebertaV2ForMultipleChoice"),
        ("distilbert", "DistilBertForMultipleChoice"),
        ("electra", "ElectraForMultipleChoice"),
        ("ernie", "ErnieForMultipleChoice"),
        ("flaubert", "FlaubertForMultipleChoice"),
        ("fnet", "FNetForMultipleChoice"),
        ("funnel", "FunnelForMultipleChoice"),
        ("ibert", "IBertForMultipleChoice"),
        ("longformer", "LongformerForMultipleChoice"),
        ("luke", "LukeForMultipleChoice"),
        ("megatron-bert", "MegatronBertForMultipleChoice"),
        ("mobilebert", "MobileBertForMultipleChoice"),
        ("mpnet", "MPNetForMultipleChoice"),
        ("nezha", "NezhaForMultipleChoice"),
        ("nystromformer", "NystromformerForMultipleChoice"),
        ("qdqbert", "QDQBertForMultipleChoice"),
        ("rembert", "RemBertForMultipleChoice"),
        ("roberta", "RobertaForMultipleChoice"),
        ("roc_bert", "RoCBertForMultipleChoice"),
        ("roformer", "RoFormerForMultipleChoice"),
        ("squeezebert", "SqueezeBertForMultipleChoice"),
        ("xlm", "XLMForMultipleChoice"),
        ("xlm-roberta", "XLMRobertaForMultipleChoice"),
        ("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"),
        ("xlnet", "XLNetForMultipleChoice"),
        ("yoso", "YosoForMultipleChoice"),
    ]
)

# Model type -> next-sentence-prediction head class. Wrapped as MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
# and used by AutoModelForNextSentencePrediction.
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Next Sentence Prediction mapping
        ("bert", "BertForNextSentencePrediction"),
        ("ernie", "ErnieForNextSentencePrediction"),
        ("fnet", "FNetForNextSentencePrediction"),
        ("megatron-bert", "MegatronBertForNextSentencePrediction"),
        ("mobilebert", "MobileBertForNextSentencePrediction"),
        ("nezha", "NezhaForNextSentencePrediction"),
        ("qdqbert", "QDQBertForNextSentencePrediction"),
    ]
)

# Model type -> utterance-level audio classification head. Every entry reuses the speech model's
# `*ForSequenceClassification` class. Wrapped as MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING and used by
# AutoModelForAudioClassification.
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio Classification mapping
        ("data2vec-audio", "Data2VecAudioForSequenceClassification"),
        ("hubert", "HubertForSequenceClassification"),
        ("sew", "SEWForSequenceClassification"),
        ("sew-d", "SEWDForSequenceClassification"),
        ("unispeech", "UniSpeechForSequenceClassification"),
        ("unispeech-sat", "UniSpeechSatForSequenceClassification"),
        ("wav2vec2", "Wav2Vec2ForSequenceClassification"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForSequenceClassification"),
        ("wavlm", "WavLMForSequenceClassification"),
    ]
)

# Model type -> connectionist temporal classification (CTC) head class. Wrapped as MODEL_FOR_CTC_MAPPING
# and used by AutoModelForCTC.
MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
    [
        # Model for Connectionist temporal classification (CTC) mapping
        ("data2vec-audio", "Data2VecAudioForCTC"),
        ("hubert", "HubertForCTC"),
        ("mctct", "MCTCTForCTC"),
        ("sew", "SEWForCTC"),
        ("sew-d", "SEWDForCTC"),
        ("unispeech", "UniSpeechForCTC"),
        ("unispeech-sat", "UniSpeechSatForCTC"),
        ("wav2vec2", "Wav2Vec2ForCTC"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForCTC"),
        ("wavlm", "WavLMForCTC"),
    ]
)

# Model type -> per-frame audio classification head class. Wrapped as
# MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING and used by AutoModelForAudioFrameClassification.
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio Frame Classification mapping
        ("data2vec-audio", "Data2VecAudioForAudioFrameClassification"),
        ("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
        ("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForAudioFrameClassification"),
        ("wavlm", "WavLMForAudioFrameClassification"),
    ]
)

# Model type -> x-vector head class. Wrapped as MODEL_FOR_AUDIO_XVECTOR_MAPPING and used by
# AutoModelForAudioXVector.
MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio X-Vector mapping
        ("data2vec-audio", "Data2VecAudioForXVector"),
        ("unispeech-sat", "UniSpeechSatForXVector"),
        ("wav2vec2", "Wav2Vec2ForXVector"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForXVector"),
        ("wavlm", "WavLMForXVector"),
    ]
)

_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Zero Shot Image Classification mapping
        ("clip", "CLIPModel"),
NielsRogge's avatar
NielsRogge committed
841
        ("clipseg", "CLIPSegModel"),
842
843
844
    ]
)

845
846
847
848
849
850
851
# Model type -> backbone class. Wrapped as MODEL_FOR_BACKBONE_MAPPING and used by AutoBackbone.
MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
    [
        # Backbone mapping
        ("resnet", "ResNetBackbone"),
    ]
)

# Each `*_MAPPING_NAMES` table is paired with `CONFIG_MAPPING_NAMES` in a `_LazyAutoMapping`. The auto classes
# below bind these objects as their `_model_mapping`.
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES
)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
)
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
)
MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES)
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES)
MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES)


# Each auto class sets `_model_mapping` to one of the `_LazyAutoMapping` objects and is then rebound to the
# value returned by `auto_class_update` (which receives an optional `head_doc` task description).


# Resolves to the base model class registered in MODEL_MAPPING.
class AutoModel(_BaseAutoModelClass):
    _model_mapping = MODEL_MAPPING


AutoModel = auto_class_update(AutoModel)


# Resolves to the pretraining class registered in MODEL_FOR_PRETRAINING_MAPPING.
class AutoModelForPreTraining(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_PRETRAINING_MAPPING


AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining")


# Private on purpose, the public class will add the deprecation warnings.
# `AutoModelWithLMHead` (end of this file) subclasses it and emits a FutureWarning on every load.
class _AutoModelWithLMHead(_BaseAutoModelClass):
    _model_mapping = MODEL_WITH_LM_HEAD_MAPPING


_AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling")


# Language-modeling auto classes: causal, masked and sequence-to-sequence heads.
class AutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING


AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")


class AutoModelForMaskedLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING


AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling")


class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# `checkpoint_for_example` presumably overrides the checkpoint shown in the generated usage examples.
AutoModelForSeq2SeqLM = auto_class_update(
    AutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
)


# Resolves to the sequence-classification head registered in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.
class AutoModelForSequenceClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


AutoModelForSequenceClassification = auto_class_update(
    AutoModelForSequenceClassification, head_doc="sequence classification"
)


# Question-answering auto classes: extractive, table, visual and document QA.
class AutoModelForQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING


AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering")


class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


AutoModelForTableQuestionAnswering = auto_class_update(
    AutoModelForTableQuestionAnswering,
    head_doc="table question answering",
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
)


class AutoModelForVisualQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING


AutoModelForVisualQuestionAnswering = auto_class_update(
    AutoModelForVisualQuestionAnswering,
    head_doc="visual question answering",
    checkpoint_for_example="dandelin/vilt-b32-finetuned-vqa",
)


class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING


AutoModelForDocumentQuestionAnswering = auto_class_update(
    AutoModelForDocumentQuestionAnswering,
    head_doc="document question answering",
    # NOTE(review): the embedded `", revision="` looks deliberate. If the docstring template wraps this value in
    # double quotes (e.g. `from_pretrained("...")`), it expands to a call pinned to revision 52e01b3. Confirm
    # against `auto_class_update` before "fixing" the quoting.
    checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
)


# Token-level and sentence-pair auto classes: token classification, multiple choice, next sentence prediction.
class AutoModelForTokenClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification")


class AutoModelForMultipleChoice(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING


AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice")


class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


AutoModelForNextSentencePrediction = auto_class_update(
    AutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)


# Vision auto classes: classification, segmentation, detection, depth, video and vision-to-text heads.
class AutoModelForImageClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")


class AutoModelForImageSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING


AutoModelForImageSegmentation = auto_class_update(AutoModelForImageSegmentation, head_doc="image segmentation")


class AutoModelForSemanticSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING


AutoModelForSemanticSegmentation = auto_class_update(
    AutoModelForSemanticSegmentation, head_doc="semantic segmentation"
)


class AutoModelForInstanceSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING


AutoModelForInstanceSegmentation = auto_class_update(
    AutoModelForInstanceSegmentation, head_doc="instance segmentation"
)


class AutoModelForObjectDetection(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING


AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")


class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING


AutoModelForZeroShotObjectDetection = auto_class_update(
    AutoModelForZeroShotObjectDetection, head_doc="zero-shot object detection"
)


class AutoModelForDepthEstimation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING


AutoModelForDepthEstimation = auto_class_update(AutoModelForDepthEstimation, head_doc="depth estimation")


class AutoModelForVideoClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING


AutoModelForVideoClassification = auto_class_update(AutoModelForVideoClassification, head_doc="video classification")


class AutoModelForVision2Seq(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING


AutoModelForVision2Seq = auto_class_update(AutoModelForVision2Seq, head_doc="vision-to-text modeling")


# Speech auto classes: utterance classification, CTC and sequence-to-sequence speech-to-text.
class AutoModelForAudioClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification")


class AutoModelForCTC(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CTC_MAPPING


AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification")


class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


AutoModelForSpeechSeq2Seq = auto_class_update(
    AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)


# Per-frame and x-vector audio heads, the backbone auto class, and masked image modeling.
class AutoModelForAudioFrameClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING


AutoModelForAudioFrameClassification = auto_class_update(
    AutoModelForAudioFrameClassification, head_doc="audio frame (token) classification"
)


class AutoModelForAudioXVector(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING


# Kept directly after its class, like every other auto class; previously AutoBackbone sat in between.
AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector")


# NOTE(review): AutoBackbone is the only auto class in this file that is not passed through
# `auto_class_update`. Confirm this is intended.
class AutoBackbone(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_BACKBONE_MAPPING


class AutoModelForMaskedImageModeling(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING


AutoModelForMaskedImageModeling = auto_class_update(AutoModelForMaskedImageModeling, head_doc="masked image modeling")


# Shared by both loaders below so the two warnings cannot drift apart.
_AUTO_MODEL_WITH_LM_HEAD_DEPRECATION_MESSAGE = (
    "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
    "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
    "`AutoModelForSeq2SeqLM` for encoder-decoder models."
)


class AutoModelWithLMHead(_AutoModelWithLMHead):
    """
    Deprecated public wrapper around `_AutoModelWithLMHead`.

    Both `from_config` and `from_pretrained` emit a `FutureWarning` pointing users to `AutoModelForCausalLM`,
    `AutoModelForMaskedLM` or `AutoModelForSeq2SeqLM`, then delegate unchanged to the parent implementation.
    """

    @classmethod
    def from_config(cls, config):
        # stacklevel=2 attributes the warning to the caller's line instead of this module.
        warnings.warn(_AUTO_MODEL_WITH_LM_HEAD_DEPRECATION_MESSAGE, FutureWarning, stacklevel=2)
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # stacklevel=2 attributes the warning to the caller's line instead of this module.
        warnings.warn(_AUTO_MODEL_WITH_LM_HEAD_DEPRECATION_MESSAGE, FutureWarning, stacklevel=2)
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)