modeling_auto.py 42.1 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Sylvain Gugger's avatar
Sylvain Gugger committed
15
""" Auto Model class."""
thomwolf's avatar
thomwolf committed
16

17
import warnings
Julien Chaumond's avatar
Julien Chaumond committed
18
from collections import OrderedDict
thomwolf's avatar
thomwolf committed
19

Sylvain Gugger's avatar
Sylvain Gugger committed
20
from ...utils import logging
21
22
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES
thomwolf's avatar
thomwolf committed
23

thomwolf's avatar
thomwolf committed
24

Lysandre Debut's avatar
Lysandre Debut committed
25
# Module-level logger, obtained from the package's own `logging` utilities and keyed on this module's name.
logger = logging.get_logger(__name__)
thomwolf's avatar
thomwolf committed
26
27


28
# Base-model table: each key is a model-type identifier string and each value is the
# name (as a string) of the matching base model class. A value may also be a tuple of
# class names when more than one class is listed for the same key (see "funnel").
# NOTE(review): presumably consumed by `_LazyAutoMapping` further down the file — confirm.
MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("albert", "AlbertModel"),
        ("bart", "BartModel"),
        ("beit", "BeitModel"),
        ("bert", "BertModel"),
        ("bert-generation", "BertGenerationEncoder"),
        ("big_bird", "BigBirdModel"),
        ("bigbird_pegasus", "BigBirdPegasusModel"),
        ("blenderbot", "BlenderbotModel"),
        ("blenderbot-small", "BlenderbotSmallModel"),
        ("bloom", "BloomModel"),
        ("camembert", "CamembertModel"),
        ("canine", "CanineModel"),
        ("clip", "CLIPModel"),
        ("codegen", "CodeGenModel"),
        ("convbert", "ConvBertModel"),
        ("convnext", "ConvNextModel"),
        ("ctrl", "CTRLModel"),
        ("cvt", "CvtModel"),
        ("data2vec-audio", "Data2VecAudioModel"),
        ("data2vec-text", "Data2VecTextModel"),
        ("data2vec-vision", "Data2VecVisionModel"),
        ("deberta", "DebertaModel"),
        ("deberta-v2", "DebertaV2Model"),
        ("decision_transformer", "DecisionTransformerModel"),
        ("decision_transformer_gpt2", "DecisionTransformerGPT2Model"),
        ("deformable_detr", "DeformableDetrModel"),
        ("deit", "DeiTModel"),
        ("detr", "DetrModel"),
        ("distilbert", "DistilBertModel"),
        ("donut-swin", "DonutSwinModel"),
        ("dpr", "DPRQuestionEncoder"),
        ("dpt", "DPTModel"),
        ("electra", "ElectraModel"),
        ("ernie", "ErnieModel"),
        ("flaubert", "FlaubertModel"),
        ("flava", "FlavaModel"),
        ("fnet", "FNetModel"),
        ("fsmt", "FSMTModel"),
        ("funnel", ("FunnelModel", "FunnelBaseModel")),
        ("glpn", "GLPNModel"),
        ("gpt2", "GPT2Model"),
        ("gpt_neo", "GPTNeoModel"),
        ("gpt_neox", "GPTNeoXModel"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseModel"),
        ("gptj", "GPTJModel"),
        ("groupvit", "GroupViTModel"),
        ("hubert", "HubertModel"),
        ("ibert", "IBertModel"),
        ("imagegpt", "ImageGPTModel"),
        ("layoutlm", "LayoutLMModel"),
        ("layoutlmv2", "LayoutLMv2Model"),
        ("layoutlmv3", "LayoutLMv3Model"),
        ("led", "LEDModel"),
        ("levit", "LevitModel"),
        ("longformer", "LongformerModel"),
        ("longt5", "LongT5Model"),
        ("luke", "LukeModel"),
        ("lxmert", "LxmertModel"),
        ("m2m_100", "M2M100Model"),
        ("marian", "MarianModel"),
        ("maskformer", "MaskFormerModel"),
        ("mbart", "MBartModel"),
        ("mctct", "MCTCTModel"),
        ("megatron-bert", "MegatronBertModel"),
        ("mobilebert", "MobileBertModel"),
        ("mobilevit", "MobileViTModel"),
        ("mpnet", "MPNetModel"),
        ("mt5", "MT5Model"),
        ("mvp", "MvpModel"),
        ("nezha", "NezhaModel"),
        ("nllb", "M2M100Model"),
        ("nystromformer", "NystromformerModel"),
        ("openai-gpt", "OpenAIGPTModel"),
        ("opt", "OPTModel"),
        ("owlvit", "OwlViTModel"),
        ("pegasus", "PegasusModel"),
        ("pegasus_x", "PegasusXModel"),
        ("perceiver", "PerceiverModel"),
        ("plbart", "PLBartModel"),
        ("poolformer", "PoolFormerModel"),
        ("prophetnet", "ProphetNetModel"),
        ("qdqbert", "QDQBertModel"),
        ("reformer", "ReformerModel"),
        ("regnet", "RegNetModel"),
        ("rembert", "RemBertModel"),
        ("resnet", "ResNetModel"),
        ("retribert", "RetriBertModel"),
        ("roberta", "RobertaModel"),
        ("roformer", "RoFormerModel"),
        ("segformer", "SegformerModel"),
        ("sew", "SEWModel"),
        ("sew-d", "SEWDModel"),
        ("speech_to_text", "Speech2TextModel"),
        ("splinter", "SplinterModel"),
        ("squeezebert", "SqueezeBertModel"),
        ("swin", "SwinModel"),
        ("swinv2", "Swinv2Model"),
        ("t5", "T5Model"),
        ("tapas", "TapasModel"),
        ("trajectory_transformer", "TrajectoryTransformerModel"),
        ("transfo-xl", "TransfoXLModel"),
        ("unispeech", "UniSpeechModel"),
        ("unispeech-sat", "UniSpeechSatModel"),
        ("van", "VanModel"),
        ("videomae", "VideoMAEModel"),
        ("vilt", "ViltModel"),
        ("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
        ("visual_bert", "VisualBertModel"),
        ("vit", "ViTModel"),
        ("vit_mae", "ViTMAEModel"),
        ("wav2vec2", "Wav2Vec2Model"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerModel"),
        ("wavlm", "WavLMModel"),
        ("xclip", "XCLIPModel"),
        ("xglm", "XGLMModel"),
        ("xlm", "XLMModel"),
        ("xlm-prophetnet", "XLMProphetNetModel"),
        ("xlm-roberta", "XLMRobertaModel"),
        ("xlm-roberta-xl", "XLMRobertaXLModel"),
        ("xlnet", "XLNetModel"),
        ("yolos", "YolosModel"),
        ("yoso", "YosoModel"),
    ]
)

156
# Pre-training table: each key is a model-type identifier string and each value is the
# name (as a string) of the class listed for pre-training that model type.
# NOTE(review): presumably consumed by `_LazyAutoMapping` further down the file — confirm.
MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        # Model for pre-training mapping
        ("albert", "AlbertForPreTraining"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForPreTraining"),
        ("big_bird", "BigBirdForPreTraining"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForPreTraining"),
        ("ernie", "ErnieForPreTraining"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("flava", "FlavaForPreTraining"),
        ("fnet", "FNetForPreTraining"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("funnel", "FunnelForPreTraining"),
        ("gpt2", "GPT2LMHeadModel"),
        ("ibert", "IBertForMaskedLM"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("longformer", "LongformerForMaskedLM"),
        ("luke", "LukeForMaskedLM"),
        ("lxmert", "LxmertForPreTraining"),
        ("megatron-bert", "MegatronBertForPreTraining"),
        ("mobilebert", "MobileBertForPreTraining"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForPreTraining"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("retribert", "RetriBertModel"),
        ("roberta", "RobertaForMaskedLM"),
        ("splinter", "SplinterForPreTraining"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("t5", "T5ForConditionalGeneration"),
        ("tapas", "TapasForMaskedLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("unispeech", "UniSpeechForPreTraining"),
        ("unispeech-sat", "UniSpeechSatForPreTraining"),
        ("videomae", "VideoMAEForPreTraining"),
        ("visual_bert", "VisualBertForPreTraining"),
        ("vit_mae", "ViTMAEForPreTraining"),
        ("wav2vec2", "Wav2Vec2ForPreTraining"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForPreTraining"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("xlnet", "XLNetLMHeadModel"),
    ]
)

210
# LM-head table: each key is a model-type identifier string and each value is the name
# (as a string) of the class listed as that model type's language-modeling-head variant.
# NOTE(review): presumably consumed by `_LazyAutoMapping` further down the file — confirm.
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    [
        # Model with LM heads mapping
        ("albert", "AlbertForMaskedLM"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForMaskedLM"),
        ("big_bird", "BigBirdForMaskedLM"),
        ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
        ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("codegen", "CodeGenForCausalLM"),
        ("convbert", "ConvBertForMaskedLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForMaskedLM"),
        ("encoder-decoder", "EncoderDecoderModel"),
        ("ernie", "ErnieForMaskedLM"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("fnet", "FNetForMaskedLM"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("funnel", "FunnelForMaskedLM"),
        ("gpt2", "GPT2LMHeadModel"),
        ("gpt_neo", "GPTNeoForCausalLM"),
        ("gpt_neox", "GPTNeoXForCausalLM"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
        ("gptj", "GPTJForCausalLM"),
        ("ibert", "IBertForMaskedLM"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("led", "LEDForConditionalGeneration"),
        ("longformer", "LongformerForMaskedLM"),
        ("longt5", "LongT5ForConditionalGeneration"),
        ("luke", "LukeForMaskedLM"),
        ("m2m_100", "M2M100ForConditionalGeneration"),
        ("marian", "MarianMTModel"),
        ("megatron-bert", "MegatronBertForCausalLM"),
        ("mobilebert", "MobileBertForMaskedLM"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForMaskedLM"),
        ("nllb", "M2M100ForConditionalGeneration"),
        ("nystromformer", "NystromformerForMaskedLM"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("pegasus_x", "PegasusXForConditionalGeneration"),
        ("plbart", "PLBartForConditionalGeneration"),
        ("qdqbert", "QDQBertForMaskedLM"),
        ("reformer", "ReformerModelWithLMHead"),
        ("rembert", "RemBertForMaskedLM"),
        ("roberta", "RobertaForMaskedLM"),
        ("roformer", "RoFormerForMaskedLM"),
        ("speech_to_text", "Speech2TextForConditionalGeneration"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("t5", "T5ForConditionalGeneration"),
        ("tapas", "TapasForMaskedLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("wav2vec2", "Wav2Vec2ForMaskedLM"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("xlnet", "XLNetLMHeadModel"),
        ("yoso", "YosoForMaskedLM"),
    ]
)

277
# Causal-LM table: each key is a model-type identifier string and each value is the
# name (as a string) of the class listed for causal language modeling.
# NOTE(review): presumably consumed by `_LazyAutoMapping` further down the file — confirm.
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
        ("bart", "BartForCausalLM"),
        ("bert", "BertLMHeadModel"),
        ("bert-generation", "BertGenerationDecoder"),
        ("big_bird", "BigBirdForCausalLM"),
        ("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
        ("blenderbot", "BlenderbotForCausalLM"),
        ("blenderbot-small", "BlenderbotSmallForCausalLM"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForCausalLM"),
        ("codegen", "CodeGenForCausalLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForCausalLM"),
        ("electra", "ElectraForCausalLM"),
        ("ernie", "ErnieForCausalLM"),
        ("gpt2", "GPT2LMHeadModel"),
        ("gpt_neo", "GPTNeoForCausalLM"),
        ("gpt_neox", "GPTNeoXForCausalLM"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
        ("gptj", "GPTJForCausalLM"),
        ("marian", "MarianForCausalLM"),
        ("mbart", "MBartForCausalLM"),
        ("megatron-bert", "MegatronBertForCausalLM"),
        ("mvp", "MvpForCausalLM"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("opt", "OPTForCausalLM"),
        ("pegasus", "PegasusForCausalLM"),
        ("plbart", "PLBartForCausalLM"),
        ("prophetnet", "ProphetNetForCausalLM"),
        ("qdqbert", "QDQBertLMHeadModel"),
        ("reformer", "ReformerModelWithLMHead"),
        ("rembert", "RemBertForCausalLM"),
        ("roberta", "RobertaForCausalLM"),
        ("roformer", "RoFormerForCausalLM"),
        ("speech_to_text_2", "Speech2Text2ForCausalLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("trocr", "TrOCRForCausalLM"),
        ("xglm", "XGLMForCausalLM"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-prophetnet", "XLMProphetNetForCausalLM"),
        ("xlm-roberta", "XLMRobertaForCausalLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForCausalLM"),
        ("xlnet", "XLNetLMHeadModel"),
    ]
)

NielsRogge's avatar
NielsRogge committed
325
326
327
328
# Masked-image-modeling table: each key is a model-type identifier string and each
# value is the name (as a string) of the class listed for masked image modeling.
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked Image Modeling mapping
        ("deit", "DeiTForMaskedImageModeling"),
        ("swin", "SwinForMaskedImageModeling"),
        ("swinv2", "Swinv2ForMaskedImageModeling"),
        ("vit", "ViTForMaskedImageModeling"),
    ]
)


NielsRogge's avatar
NielsRogge committed
335
336
337
338
339
340
341
# Causal-image-modeling table (model-type string -> class-name string); one entry so far.
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    (
        # Model for Causal Image Modeling mapping
        ("imagegpt", "ImageGPTForCausalImageModeling"),
    )
)

342
# Image-classification table: each key is a model-type identifier string; each value is
# either one class-name string or a tuple of class-name strings when several
# image-classification classes are listed for that model type (deit, levit, perceiver).
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image Classification mapping
        ("beit", "BeitForImageClassification"),
        ("convnext", "ConvNextForImageClassification"),
        ("cvt", "CvtForImageClassification"),
        ("data2vec-vision", "Data2VecVisionForImageClassification"),
        ("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")),
        ("imagegpt", "ImageGPTForImageClassification"),
        ("levit", ("LevitForImageClassification", "LevitForImageClassificationWithTeacher")),
        ("mobilevit", "MobileViTForImageClassification"),
        (
            "perceiver",
            (
                "PerceiverForImageClassificationLearned",
                "PerceiverForImageClassificationFourier",
                "PerceiverForImageClassificationConvProcessing",
            ),
        ),
        ("poolformer", "PoolFormerForImageClassification"),
        ("regnet", "RegNetForImageClassification"),
        ("resnet", "ResNetForImageClassification"),
        ("segformer", "SegformerForImageClassification"),
        ("swin", "SwinForImageClassification"),
        ("swinv2", "Swinv2ForImageClassification"),
        ("van", "VanForImageClassification"),
        ("vit", "ViTForImageClassification"),
    ]
)

372
373
# Image-segmentation table (model-type string -> class-name string). Per the inline
# note below, this table is frozen: new models are not to be added here.
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Do not add new models here, this class will be deprecated in the future.
        # Model for Image Segmentation mapping
        ("detr", "DetrForSegmentation"),
    ]
)

380
381
382
383
# Semantic-segmentation table: each key is a model-type identifier string and each
# value is the name (as a string) of the class listed for semantic segmentation.
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Semantic Segmentation mapping
        ("beit", "BeitForSemanticSegmentation"),
        ("data2vec-vision", "Data2VecVisionForSemanticSegmentation"),
        ("dpt", "DPTForSemanticSegmentation"),
        ("mobilevit", "MobileViTForSemanticSegmentation"),
        ("segformer", "SegformerForSemanticSegmentation"),
    ]
)

391
392
393
394
395
396
397
# Instance-segmentation table (model-type string -> class-name string); one entry so far.
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    (
        # Model for Instance Segmentation mapping
        ("maskformer", "MaskFormerForInstanceSegmentation"),
    )
)

NielsRogge's avatar
NielsRogge committed
398
399
400
401
402
403
# Video-classification table (model-type string -> class-name string); one entry so far.
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    (
        ("videomae", "VideoMAEForVideoClassification"),
    )
)

404
405
406
407
408
409
# Vision-to-sequence table (model-type string -> class-name string); one entry so far.
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    (
        ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
    )
)

410
# Masked-LM table: each key is a model-type identifier string and each value is the
# name (as a string) of the class listed for masked language modeling.
# NOTE(review): presumably consumed by `_LazyAutoMapping` further down the file — confirm.
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked LM mapping
        ("albert", "AlbertForMaskedLM"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForMaskedLM"),
        ("big_bird", "BigBirdForMaskedLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("convbert", "ConvBertForMaskedLM"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForMaskedLM"),
        ("ernie", "ErnieForMaskedLM"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("fnet", "FNetForMaskedLM"),
        ("funnel", "FunnelForMaskedLM"),
        ("ibert", "IBertForMaskedLM"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("longformer", "LongformerForMaskedLM"),
        ("luke", "LukeForMaskedLM"),
        ("mbart", "MBartForConditionalGeneration"),
        ("megatron-bert", "MegatronBertForMaskedLM"),
        ("mobilebert", "MobileBertForMaskedLM"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForMaskedLM"),
        ("nystromformer", "NystromformerForMaskedLM"),
        ("perceiver", "PerceiverForMaskedLM"),
        ("qdqbert", "QDQBertForMaskedLM"),
        ("reformer", "ReformerForMaskedLM"),
        ("rembert", "RemBertForMaskedLM"),
        ("roberta", "RobertaForMaskedLM"),
        ("roformer", "RoFormerForMaskedLM"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("tapas", "TapasForMaskedLM"),
        ("wav2vec2", "Wav2Vec2ForMaskedLM"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("yoso", "YosoForMaskedLM"),
    ]
)

455
# Object-detection table: each key is a model-type identifier string and each value is
# the name (as a string) of the class listed for object detection.
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Object Detection mapping
        ("deformable_detr", "DeformableDetrForObjectDetection"),
        ("detr", "DetrForObjectDetection"),
        ("yolos", "YolosForObjectDetection"),
    ]
)

464
# Seq2seq-LM table: each key is a model-type identifier string and each value is the
# name (as a string) of the class listed for sequence-to-sequence language modeling.
# NOTE(review): presumably consumed by `_LazyAutoMapping` further down the file — confirm.
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        ("bart", "BartForConditionalGeneration"),
        ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
        ("blenderbot", "BlenderbotForConditionalGeneration"),
        ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
        ("encoder-decoder", "EncoderDecoderModel"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("led", "LEDForConditionalGeneration"),
        ("longt5", "LongT5ForConditionalGeneration"),
        ("m2m_100", "M2M100ForConditionalGeneration"),
        ("marian", "MarianMTModel"),
        ("mbart", "MBartForConditionalGeneration"),
        ("mt5", "MT5ForConditionalGeneration"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nllb", "M2M100ForConditionalGeneration"),
        ("pegasus", "PegasusForConditionalGeneration"),
        ("pegasus_x", "PegasusXForConditionalGeneration"),
        ("plbart", "PLBartForConditionalGeneration"),
        ("prophetnet", "ProphetNetForConditionalGeneration"),
        ("t5", "T5ForConditionalGeneration"),
        ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"),
    ]
)

490
491
492
493
494
495
496
# Speech seq2seq table (model-type string -> class-name string).
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    (
        ("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
        ("speech_to_text", "Speech2TextForConditionalGeneration"),
    )
)

497
# Sequence-classification table: each key is a model-type identifier string and each
# value is the name (as a string) of the class listed for sequence classification.
# NOTE(review): presumably consumed by `_LazyAutoMapping` further down the file — confirm.
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Sequence Classification mapping
        ("albert", "AlbertForSequenceClassification"),
        ("bart", "BartForSequenceClassification"),
        ("bert", "BertForSequenceClassification"),
        ("big_bird", "BigBirdForSequenceClassification"),
        ("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
        ("bloom", "BloomForSequenceClassification"),
        ("camembert", "CamembertForSequenceClassification"),
        ("canine", "CanineForSequenceClassification"),
        ("convbert", "ConvBertForSequenceClassification"),
        ("ctrl", "CTRLForSequenceClassification"),
        ("data2vec-text", "Data2VecTextForSequenceClassification"),
        ("deberta", "DebertaForSequenceClassification"),
        ("deberta-v2", "DebertaV2ForSequenceClassification"),
        ("distilbert", "DistilBertForSequenceClassification"),
        ("electra", "ElectraForSequenceClassification"),
        ("ernie", "ErnieForSequenceClassification"),
        ("flaubert", "FlaubertForSequenceClassification"),
        ("fnet", "FNetForSequenceClassification"),
        ("funnel", "FunnelForSequenceClassification"),
        ("gpt2", "GPT2ForSequenceClassification"),
        ("gpt_neo", "GPTNeoForSequenceClassification"),
        ("gptj", "GPTJForSequenceClassification"),
        ("ibert", "IBertForSequenceClassification"),
        ("layoutlm", "LayoutLMForSequenceClassification"),
        ("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
        ("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
        ("led", "LEDForSequenceClassification"),
        ("longformer", "LongformerForSequenceClassification"),
        ("luke", "LukeForSequenceClassification"),
        ("mbart", "MBartForSequenceClassification"),
        ("megatron-bert", "MegatronBertForSequenceClassification"),
        ("mobilebert", "MobileBertForSequenceClassification"),
        ("mpnet", "MPNetForSequenceClassification"),
        ("mvp", "MvpForSequenceClassification"),
        ("nezha", "NezhaForSequenceClassification"),
        ("nystromformer", "NystromformerForSequenceClassification"),
        ("openai-gpt", "OpenAIGPTForSequenceClassification"),
        ("opt", "OPTForSequenceClassification"),
        ("perceiver", "PerceiverForSequenceClassification"),
        ("plbart", "PLBartForSequenceClassification"),
        ("qdqbert", "QDQBertForSequenceClassification"),
        ("reformer", "ReformerForSequenceClassification"),
        ("rembert", "RemBertForSequenceClassification"),
        ("roberta", "RobertaForSequenceClassification"),
        ("roformer", "RoFormerForSequenceClassification"),
        ("squeezebert", "SqueezeBertForSequenceClassification"),
        ("tapas", "TapasForSequenceClassification"),
        ("transfo-xl", "TransfoXLForSequenceClassification"),
        ("xlm", "XLMForSequenceClassification"),
        ("xlm-roberta", "XLMRobertaForSequenceClassification"),
        ("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"),
        ("xlnet", "XLNetForSequenceClassification"),
        ("yoso", "YosoForSequenceClassification"),
    ]
)

556
# Question-answering table: each key is a model-type identifier string and each value
# is the name (as a string) of the class listed for question answering.
# NOTE(review): presumably consumed by `_LazyAutoMapping` further down the file — confirm.
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Question Answering mapping
        ("albert", "AlbertForQuestionAnswering"),
        ("bart", "BartForQuestionAnswering"),
        ("bert", "BertForQuestionAnswering"),
        ("big_bird", "BigBirdForQuestionAnswering"),
        ("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"),
        ("camembert", "CamembertForQuestionAnswering"),
        ("canine", "CanineForQuestionAnswering"),
        ("convbert", "ConvBertForQuestionAnswering"),
        ("data2vec-text", "Data2VecTextForQuestionAnswering"),
        ("deberta", "DebertaForQuestionAnswering"),
        ("deberta-v2", "DebertaV2ForQuestionAnswering"),
        ("distilbert", "DistilBertForQuestionAnswering"),
        ("electra", "ElectraForQuestionAnswering"),
        ("ernie", "ErnieForQuestionAnswering"),
        ("flaubert", "FlaubertForQuestionAnsweringSimple"),
        ("fnet", "FNetForQuestionAnswering"),
        ("funnel", "FunnelForQuestionAnswering"),
        ("gptj", "GPTJForQuestionAnswering"),
        ("ibert", "IBertForQuestionAnswering"),
        ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
        ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
        ("led", "LEDForQuestionAnswering"),
        ("longformer", "LongformerForQuestionAnswering"),
        ("luke", "LukeForQuestionAnswering"),
        ("lxmert", "LxmertForQuestionAnswering"),
        ("mbart", "MBartForQuestionAnswering"),
        ("megatron-bert", "MegatronBertForQuestionAnswering"),
        ("mobilebert", "MobileBertForQuestionAnswering"),
        ("mpnet", "MPNetForQuestionAnswering"),
        ("mvp", "MvpForQuestionAnswering"),
        ("nezha", "NezhaForQuestionAnswering"),
        ("nystromformer", "NystromformerForQuestionAnswering"),
        ("qdqbert", "QDQBertForQuestionAnswering"),
        ("reformer", "ReformerForQuestionAnswering"),
        ("rembert", "RemBertForQuestionAnswering"),
        ("roberta", "RobertaForQuestionAnswering"),
        ("roformer", "RoFormerForQuestionAnswering"),
        ("splinter", "SplinterForQuestionAnswering"),
        ("squeezebert", "SqueezeBertForQuestionAnswering"),
        ("xlm", "XLMForQuestionAnsweringSimple"),
        ("xlm-roberta", "XLMRobertaForQuestionAnswering"),
        ("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"),
        ("xlnet", "XLNetForQuestionAnsweringSimple"),
        ("yoso", "YosoForQuestionAnswering"),
    ]
)

606
# Table-question-answering table (model-type string -> class-name string); one entry so far.
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Table Question Answering mapping
        ("tapas", "TapasForQuestionAnswering"),
    ]
)

613
614
615
616
617
618
# Visual-question-answering table (model-type string -> class-name string); one entry so far.
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    (
        ("vilt", "ViltForQuestionAnswering"),
    )
)

619
620
621
622
623
624
625
626
# Document-question-answering table (model-type string -> class-name string).
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    (
        ("layoutlm", "LayoutLMForQuestionAnswering"),
        ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
        ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
    )
)

627
# Token-classification table: each key is a model-type identifier string and each
# value is the name (as a string) of the class listed for token classification.
# NOTE(review): presumably consumed by `_LazyAutoMapping` further down the file — confirm.
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
        ("albert", "AlbertForTokenClassification"),
        ("bert", "BertForTokenClassification"),
        ("big_bird", "BigBirdForTokenClassification"),
        ("bloom", "BloomForTokenClassification"),
        ("camembert", "CamembertForTokenClassification"),
        ("canine", "CanineForTokenClassification"),
        ("convbert", "ConvBertForTokenClassification"),
        ("data2vec-text", "Data2VecTextForTokenClassification"),
        ("deberta", "DebertaForTokenClassification"),
        ("deberta-v2", "DebertaV2ForTokenClassification"),
        ("distilbert", "DistilBertForTokenClassification"),
        ("electra", "ElectraForTokenClassification"),
        ("ernie", "ErnieForTokenClassification"),
        ("flaubert", "FlaubertForTokenClassification"),
        ("fnet", "FNetForTokenClassification"),
        ("funnel", "FunnelForTokenClassification"),
        ("gpt2", "GPT2ForTokenClassification"),
        ("ibert", "IBertForTokenClassification"),
        ("layoutlm", "LayoutLMForTokenClassification"),
        ("layoutlmv2", "LayoutLMv2ForTokenClassification"),
        ("layoutlmv3", "LayoutLMv3ForTokenClassification"),
        ("longformer", "LongformerForTokenClassification"),
        ("luke", "LukeForTokenClassification"),
        ("megatron-bert", "MegatronBertForTokenClassification"),
        ("mobilebert", "MobileBertForTokenClassification"),
        ("mpnet", "MPNetForTokenClassification"),
        ("nezha", "NezhaForTokenClassification"),
        ("nystromformer", "NystromformerForTokenClassification"),
        ("qdqbert", "QDQBertForTokenClassification"),
        ("rembert", "RemBertForTokenClassification"),
        ("roberta", "RobertaForTokenClassification"),
        ("roformer", "RoFormerForTokenClassification"),
        ("squeezebert", "SqueezeBertForTokenClassification"),
        ("xlm", "XLMForTokenClassification"),
        ("xlm-roberta", "XLMRobertaForTokenClassification"),
        ("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"),
        ("xlnet", "XLNetForTokenClassification"),
        ("yoso", "YosoForTokenClassification"),
    ]
)

671
# Maps a model-type key to the name of its multiple-choice class.
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Multiple Choice mapping
        ("albert", "AlbertForMultipleChoice"),
        ("bert", "BertForMultipleChoice"),
        ("big_bird", "BigBirdForMultipleChoice"),
        ("camembert", "CamembertForMultipleChoice"),
        ("canine", "CanineForMultipleChoice"),
        ("convbert", "ConvBertForMultipleChoice"),
        ("data2vec-text", "Data2VecTextForMultipleChoice"),
        ("deberta-v2", "DebertaV2ForMultipleChoice"),
        ("distilbert", "DistilBertForMultipleChoice"),
        ("electra", "ElectraForMultipleChoice"),
        ("ernie", "ErnieForMultipleChoice"),
        ("flaubert", "FlaubertForMultipleChoice"),
        ("fnet", "FNetForMultipleChoice"),
        ("funnel", "FunnelForMultipleChoice"),
        ("ibert", "IBertForMultipleChoice"),
        ("longformer", "LongformerForMultipleChoice"),
        ("luke", "LukeForMultipleChoice"),
        ("megatron-bert", "MegatronBertForMultipleChoice"),
        ("mobilebert", "MobileBertForMultipleChoice"),
        ("mpnet", "MPNetForMultipleChoice"),
        ("nezha", "NezhaForMultipleChoice"),
        ("nystromformer", "NystromformerForMultipleChoice"),
        ("qdqbert", "QDQBertForMultipleChoice"),
        ("rembert", "RemBertForMultipleChoice"),
        ("roberta", "RobertaForMultipleChoice"),
        ("roformer", "RoFormerForMultipleChoice"),
        ("squeezebert", "SqueezeBertForMultipleChoice"),
        ("xlm", "XLMForMultipleChoice"),
        ("xlm-roberta", "XLMRobertaForMultipleChoice"),
        ("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"),
        ("xlnet", "XLNetForMultipleChoice"),
        ("yoso", "YosoForMultipleChoice"),
    ]
)

709
# Maps a model-type key to the name of its next-sentence-prediction class.
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Next Sentence Prediction mapping
        ("bert", "BertForNextSentencePrediction"),
        ("ernie", "ErnieForNextSentencePrediction"),
        ("fnet", "FNetForNextSentencePrediction"),
        ("megatron-bert", "MegatronBertForNextSentencePrediction"),
        ("mobilebert", "MobileBertForNextSentencePrediction"),
        ("nezha", "NezhaForNextSentencePrediction"),
        ("qdqbert", "QDQBertForNextSentencePrediction"),
    ]
)

721
722
723
# Maps a model-type key to the name of its audio (sequence) classification class.
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio Classification mapping
        ("data2vec-audio", "Data2VecAudioForSequenceClassification"),
        ("hubert", "HubertForSequenceClassification"),
        ("sew", "SEWForSequenceClassification"),
        ("sew-d", "SEWDForSequenceClassification"),
        ("unispeech", "UniSpeechForSequenceClassification"),
        ("unispeech-sat", "UniSpeechSatForSequenceClassification"),
        ("wav2vec2", "Wav2Vec2ForSequenceClassification"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForSequenceClassification"),
        ("wavlm", "WavLMForSequenceClassification"),
    ]
)

736
737
738
# Maps a model-type key to the name of its CTC-head class.
MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
    [
        # Model for Connectionist temporal classification (CTC) mapping
        ("data2vec-audio", "Data2VecAudioForCTC"),
        ("hubert", "HubertForCTC"),
        ("mctct", "MCTCTForCTC"),
        ("sew", "SEWForCTC"),
        ("sew-d", "SEWDForCTC"),
        ("unispeech", "UniSpeechForCTC"),
        ("unispeech-sat", "UniSpeechSatForCTC"),
        ("wav2vec2", "Wav2Vec2ForCTC"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForCTC"),
        ("wavlm", "WavLMForCTC"),
    ]
)

752
753
754
# Maps a model-type key to the name of its audio-frame-classification class.
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio Frame Classification mapping
        ("data2vec-audio", "Data2VecAudioForAudioFrameClassification"),
        ("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
        ("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForAudioFrameClassification"),
        ("wavlm", "WavLMForAudioFrameClassification"),
    ]
)

# Maps a model-type key to the name of its x-vector head class.
MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio XVector mapping
        ("data2vec-audio", "Data2VecAudioForXVector"),
        ("unispeech-sat", "UniSpeechSatForXVector"),
        ("wav2vec2", "Wav2Vec2ForXVector"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForXVector"),
        ("wavlm", "WavLMForXVector"),
    ]
)

774
775
776
777
# Per-task mappings. Each one is built by `_LazyAutoMapping` from the shared CONFIG_MAPPING_NAMES table and
# the task-specific `*_MAPPING_NAMES` table of the same stem.
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES
)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
)
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES)
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES)
836

837

Sylvain Gugger's avatar
Sylvain Gugger committed
838
839
840
841
842
843
844
845
846
847
848
849
# Auto class for base models; `_model_mapping` points at MODEL_MAPPING.
class AutoModel(_BaseAutoModelClass):
    _model_mapping = MODEL_MAPPING


# The name is rebound to whatever `auto_class_update` (imported from .auto_factory) returns.
# NOTE(review): presumably it fills in the generated documentation -- confirm in auto_factory.
AutoModel = auto_class_update(AutoModel)


# Auto class for pretraining heads; `_model_mapping` points at MODEL_FOR_PRETRAINING_MAPPING.
class AutoModelForPreTraining(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_PRETRAINING_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` presumably names the head in the generated docs -- confirm in auto_factory.
AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining")
850

thomwolf's avatar
thomwolf committed
851

852
# Private on purpose, the public class will add the deprecation warnings.
# `_model_mapping` points at MODEL_WITH_LM_HEAD_MAPPING.
class _AutoModelWithLMHead(_BaseAutoModelClass):
    _model_mapping = MODEL_WITH_LM_HEAD_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` presumably names the head in the generated docs -- confirm in auto_factory.
_AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling")
thomwolf's avatar
thomwolf committed
858
859


Sylvain Gugger's avatar
Sylvain Gugger committed
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
# Auto class for causal language modeling; `_model_mapping` points at MODEL_FOR_CAUSAL_LM_MAPPING.
class AutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` presumably names the head in the generated docs -- confirm in auto_factory.
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")


# Auto class for masked language modeling; `_model_mapping` points at MODEL_FOR_MASKED_LM_MAPPING.
class AutoModelForMaskedLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` presumably names the head in the generated docs -- confirm in auto_factory.
AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling")


# Auto class for sequence-to-sequence language modeling; `_model_mapping` points at
# MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.
class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` / `checkpoint_for_example` presumably feed the generated docs -- confirm in
# auto_factory.
AutoModelForSeq2SeqLM = auto_class_update(
    AutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
)
thomwolf's avatar
thomwolf committed
881

Sylvain Gugger's avatar
Sylvain Gugger committed
882
883
884
885
886
887
888

# Auto class for sequence classification; `_model_mapping` points at
# MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.
class AutoModelForSequenceClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` presumably names the head in the generated docs -- confirm in auto_factory.
AutoModelForSequenceClassification = auto_class_update(
    AutoModelForSequenceClassification, head_doc="sequence classification"
)
thomwolf's avatar
thomwolf committed
890

Sylvain Gugger's avatar
Sylvain Gugger committed
891
892
893
894
895
896
897
898
899
900
901
902
903
904

# Auto class for question answering; `_model_mapping` points at MODEL_FOR_QUESTION_ANSWERING_MAPPING.
class AutoModelForQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` presumably names the head in the generated docs -- confirm in auto_factory.
AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering")


# Auto class for table question answering; `_model_mapping` points at
# MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.
class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` / `checkpoint_for_example` presumably feed the generated docs -- confirm in
# auto_factory.
AutoModelForTableQuestionAnswering = auto_class_update(
    AutoModelForTableQuestionAnswering,
    head_doc="table question answering",
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
)
thomwolf's avatar
thomwolf committed
908
909


910
911
912
913
914
915
916
917
918
919
920
# Auto class for visual question answering; `_model_mapping` points at
# MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING.
class AutoModelForVisualQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` / `checkpoint_for_example` presumably feed the generated docs -- confirm in
# auto_factory.
AutoModelForVisualQuestionAnswering = auto_class_update(
    AutoModelForVisualQuestionAnswering,
    head_doc="visual question answering",
    checkpoint_for_example="dandelin/vilt-b32-finetuned-vqa",
)


921
922
923
924
925
926
927
# Auto class for document question answering; `_model_mapping` points at
# MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING.
class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): the single-quoted `checkpoint_for_example` value embeds `", revision="` on purpose, it seems:
# presumably the value is pasted between double quotes in a generated example so that the example also pins
# a revision. Confirm against auto_factory before "tidying" the quoting.
AutoModelForDocumentQuestionAnswering = auto_class_update(
    AutoModelForDocumentQuestionAnswering,
    head_doc="document question answering",
    checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
)


Sylvain Gugger's avatar
Sylvain Gugger committed
932
933
# Auto class for token classification; `_model_mapping` points at MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.
class AutoModelForTokenClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` presumably names the head in the generated docs -- confirm in auto_factory.
AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification")


# Auto class for multiple choice; `_model_mapping` points at MODEL_FOR_MULTIPLE_CHOICE_MAPPING.
class AutoModelForMultipleChoice(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` presumably names the head in the generated docs -- confirm in auto_factory.
AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice")


# Auto class for next sentence prediction; `_model_mapping` points at
# MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.
class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` presumably names the head in the generated docs -- confirm in auto_factory.
AutoModelForNextSentencePrediction = auto_class_update(
    AutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)
953

954

Sylvain Gugger's avatar
Sylvain Gugger committed
955
956
957
958
959
960
961
# Auto class for image classification; `_model_mapping` points at MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.
class AutoModelForImageClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` presumably names the head in the generated docs -- confirm in auto_factory.
AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")


962
963
964
965
966
967
968
# Auto class for image segmentation; `_model_mapping` points at MODEL_FOR_IMAGE_SEGMENTATION_MAPPING.
class AutoModelForImageSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` presumably names the head in the generated docs -- confirm in auto_factory.
AutoModelForImageSegmentation = auto_class_update(AutoModelForImageSegmentation, head_doc="image segmentation")


969
970
971
972
973
974
975
976
977
# Auto class for semantic segmentation; `_model_mapping` points at MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING.
class AutoModelForSemanticSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` presumably names the head in the generated docs -- confirm in auto_factory.
AutoModelForSemanticSegmentation = auto_class_update(
    AutoModelForSemanticSegmentation, head_doc="semantic segmentation"
)


978
979
980
981
982
983
984
985
986
# Auto class for instance segmentation; `_model_mapping` points at MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING.
class AutoModelForInstanceSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` presumably names the head in the generated docs -- confirm in auto_factory.
AutoModelForInstanceSegmentation = auto_class_update(
    AutoModelForInstanceSegmentation, head_doc="instance segmentation"
)


987
988
989
990
991
992
993
# Auto class for object detection; `_model_mapping` points at MODEL_FOR_OBJECT_DETECTION_MAPPING.
class AutoModelForObjectDetection(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` presumably names the head in the generated docs -- confirm in auto_factory.
AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")


NielsRogge's avatar
NielsRogge committed
994
995
996
997
998
999
1000
# Auto class for video classification; `_model_mapping` points at MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING.
class AutoModelForVideoClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` presumably names the head in the generated docs -- confirm in auto_factory.
AutoModelForVideoClassification = auto_class_update(AutoModelForVideoClassification, head_doc="video classification")


1001
1002
1003
1004
1005
1006
1007
# Auto class for vision-to-text modeling; `_model_mapping` points at MODEL_FOR_VISION_2_SEQ_MAPPING.
class AutoModelForVision2Seq(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` presumably names the head in the generated docs -- confirm in auto_factory.
AutoModelForVision2Seq = auto_class_update(AutoModelForVision2Seq, head_doc="vision-to-text modeling")


1008
1009
1010
1011
1012
1013
1014
# Auto class for audio classification; `_model_mapping` points at MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING.
class AutoModelForAudioClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` presumably names the head in the generated docs -- confirm in auto_factory.
AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification")


1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
# Auto class for connectionist temporal classification (CTC); `_model_mapping` points at MODEL_FOR_CTC_MAPPING.
class AutoModelForCTC(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CTC_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` presumably names the head in the generated docs -- confirm in auto_factory.
AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification")


# Auto class for sequence-to-sequence speech-to-text modeling; `_model_mapping` points at
# MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.
class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` presumably names the head in the generated docs -- confirm in auto_factory.
AutoModelForSpeechSeq2Seq = auto_class_update(
    AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)


1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
# Auto class for audio frame (token) classification; `_model_mapping` points at
# MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING.
class AutoModelForAudioFrameClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` presumably names the head in the generated docs -- confirm in auto_factory.
AutoModelForAudioFrameClassification = auto_class_update(
    AutoModelForAudioFrameClassification, head_doc="audio frame (token) classification"
)


# Auto class for audio retrieval via x-vector; `_model_mapping` points at MODEL_FOR_AUDIO_XVECTOR_MAPPING.
class AutoModelForAudioXVector(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` presumably names the head in the generated docs -- confirm in auto_factory.
AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector")


NielsRogge's avatar
NielsRogge committed
1047
1048
1049
1050
1051
1052
1053
# Auto class for masked image modeling; `_model_mapping` points at MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING.
class AutoModelForMaskedImageModeling(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING


# The name is rebound to the result of `auto_class_update`.
# NOTE(review): `head_doc` presumably names the head in the generated docs -- confirm in auto_factory.
AutoModelForMaskedImageModeling = auto_class_update(AutoModelForMaskedImageModeling, head_doc="masked image modeling")


1054
class AutoModelWithLMHead(_AutoModelWithLMHead):
    # Deprecated public wrapper around `_AutoModelWithLMHead`: both constructors emit a `FutureWarning` and
    # then delegate, with unchanged arguments, to the parent implementation.
    # `#` comments are used instead of docstrings so that `__doc__` of the class and of the overriding
    # methods is left exactly as it was.

    # Single copy of the message so the two entry points cannot drift apart.
    _DEPRECATION_MESSAGE = (
        "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
        "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
        "`AutoModelForSeq2SeqLM` for encoder-decoder models."
    )

    @classmethod
    def from_config(cls, config):
        # stacklevel=2 attributes the warning to the caller's line instead of this module.
        warnings.warn(cls._DEPRECATION_MESSAGE, FutureWarning, stacklevel=2)
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # stacklevel=2 attributes the warning to the caller's line instead of this module.
        warnings.warn(cls._DEPRECATION_MESSAGE, FutureWarning, stacklevel=2)
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)