modeling_auto.py 44.7 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Sylvain Gugger's avatar
Sylvain Gugger committed
15
""" Auto Model class."""
thomwolf's avatar
thomwolf committed
16

17
import warnings
Julien Chaumond's avatar
Julien Chaumond committed
18
from collections import OrderedDict
thomwolf's avatar
thomwolf committed
19

Sylvain Gugger's avatar
Sylvain Gugger committed
20
from ...utils import logging
21
22
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES
thomwolf's avatar
thomwolf committed
23

thomwolf's avatar
thomwolf committed
24

Lysandre Debut's avatar
Lysandre Debut committed
25
# Module-level logger, created through the package's own `logging` utility (imported above from `...utils`).
logger = logging.get_logger(__name__)
thomwolf's avatar
thomwolf committed
26
27


28
# Ordered table: model-type key -> name of the base model class, stored as a string.
# A value is a tuple when several classes are registered under one key (see "funnel").
# NOTE(review): keys presumably mirror CONFIG_MAPPING_NAMES — confirm against configuration_auto.
MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("albert", "AlbertModel"),
        ("bart", "BartModel"),
        ("beit", "BeitModel"),
        ("bert", "BertModel"),
        ("bert-generation", "BertGenerationEncoder"),
        ("big_bird", "BigBirdModel"),
        ("bigbird_pegasus", "BigBirdPegasusModel"),
        ("blenderbot", "BlenderbotModel"),
        ("blenderbot-small", "BlenderbotSmallModel"),
        ("bloom", "BloomModel"),
        ("camembert", "CamembertModel"),
        ("canine", "CanineModel"),
        ("clip", "CLIPModel"),
        ("clipseg", "CLIPSegModel"),
        ("codegen", "CodeGenModel"),
        ("conditional_detr", "ConditionalDetrModel"),
        ("convbert", "ConvBertModel"),
        ("convnext", "ConvNextModel"),
        ("ctrl", "CTRLModel"),
        ("cvt", "CvtModel"),
        ("data2vec-audio", "Data2VecAudioModel"),
        ("data2vec-text", "Data2VecTextModel"),
        ("data2vec-vision", "Data2VecVisionModel"),
        ("deberta", "DebertaModel"),
        ("deberta-v2", "DebertaV2Model"),
        ("decision_transformer", "DecisionTransformerModel"),
        ("decision_transformer_gpt2", "DecisionTransformerGPT2Model"),
        ("deformable_detr", "DeformableDetrModel"),
        ("deit", "DeiTModel"),
        ("detr", "DetrModel"),
        ("distilbert", "DistilBertModel"),
        ("donut-swin", "DonutSwinModel"),
        ("dpr", "DPRQuestionEncoder"),
        ("dpt", "DPTModel"),
        ("electra", "ElectraModel"),
        ("ernie", "ErnieModel"),
        ("esm", "EsmModel"),
        ("flaubert", "FlaubertModel"),
        ("flava", "FlavaModel"),
        ("fnet", "FNetModel"),
        ("fsmt", "FSMTModel"),
        ("funnel", ("FunnelModel", "FunnelBaseModel")),
        ("glpn", "GLPNModel"),
        ("gpt2", "GPT2Model"),
        ("gpt_neo", "GPTNeoModel"),
        ("gpt_neox", "GPTNeoXModel"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseModel"),
        ("gptj", "GPTJModel"),
        ("groupvit", "GroupViTModel"),
        ("hubert", "HubertModel"),
        ("ibert", "IBertModel"),
        ("imagegpt", "ImageGPTModel"),
        ("layoutlm", "LayoutLMModel"),
        ("layoutlmv2", "LayoutLMv2Model"),
        ("layoutlmv3", "LayoutLMv3Model"),
        ("led", "LEDModel"),
        ("levit", "LevitModel"),
        ("lilt", "LiltModel"),
        ("longformer", "LongformerModel"),
        ("longt5", "LongT5Model"),
        ("luke", "LukeModel"),
        ("lxmert", "LxmertModel"),
        ("m2m_100", "M2M100Model"),
        ("marian", "MarianModel"),
        ("markuplm", "MarkupLMModel"),
        ("maskformer", "MaskFormerModel"),
        ("mbart", "MBartModel"),
        ("mctct", "MCTCTModel"),
        ("megatron-bert", "MegatronBertModel"),
        ("mobilebert", "MobileBertModel"),
        ("mobilevit", "MobileViTModel"),
        ("mpnet", "MPNetModel"),
        ("mt5", "MT5Model"),
        ("mvp", "MvpModel"),
        ("nezha", "NezhaModel"),
        ("nllb", "M2M100Model"),
        ("nystromformer", "NystromformerModel"),
        ("openai-gpt", "OpenAIGPTModel"),
        ("opt", "OPTModel"),
        ("owlvit", "OwlViTModel"),
        ("pegasus", "PegasusModel"),
        ("pegasus_x", "PegasusXModel"),
        ("perceiver", "PerceiverModel"),
        ("plbart", "PLBartModel"),
        ("poolformer", "PoolFormerModel"),
        ("prophetnet", "ProphetNetModel"),
        ("qdqbert", "QDQBertModel"),
        ("reformer", "ReformerModel"),
        ("regnet", "RegNetModel"),
        ("rembert", "RemBertModel"),
        ("resnet", "ResNetModel"),
        ("retribert", "RetriBertModel"),
        ("roberta", "RobertaModel"),
        ("roformer", "RoFormerModel"),
        ("segformer", "SegformerModel"),
        ("sew", "SEWModel"),
        ("sew-d", "SEWDModel"),
        ("speech_to_text", "Speech2TextModel"),
        ("splinter", "SplinterModel"),
        ("squeezebert", "SqueezeBertModel"),
        ("swin", "SwinModel"),
        ("swinv2", "Swinv2Model"),
        ("t5", "T5Model"),
        ("table-transformer", "TableTransformerModel"),
        ("tapas", "TapasModel"),
        ("time_series_transformer", "TimeSeriesTransformerModel"),
        ("trajectory_transformer", "TrajectoryTransformerModel"),
        ("transfo-xl", "TransfoXLModel"),
        ("unispeech", "UniSpeechModel"),
        ("unispeech-sat", "UniSpeechSatModel"),
        ("van", "VanModel"),
        ("videomae", "VideoMAEModel"),
        ("vilt", "ViltModel"),
        ("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
        ("visual_bert", "VisualBertModel"),
        ("vit", "ViTModel"),
        ("vit_mae", "ViTMAEModel"),
        ("vit_msn", "ViTMSNModel"),
        ("wav2vec2", "Wav2Vec2Model"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerModel"),
        ("wavlm", "WavLMModel"),
        ("whisper", "WhisperModel"),
        ("xclip", "XCLIPModel"),
        ("xglm", "XGLMModel"),
        ("xlm", "XLMModel"),
        ("xlm-prophetnet", "XLMProphetNetModel"),
        ("xlm-roberta", "XLMRobertaModel"),
        ("xlm-roberta-xl", "XLMRobertaXLModel"),
        ("xlnet", "XLNetModel"),
        ("yolos", "YolosModel"),
        ("yoso", "YosoModel"),
    ]
)

165
# Ordered table: model-type key -> name (string) of the class registered for pre-training.
MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        # Model for pre-training mapping
        ("albert", "AlbertForPreTraining"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForPreTraining"),
        ("big_bird", "BigBirdForPreTraining"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForPreTraining"),
        ("ernie", "ErnieForPreTraining"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("flava", "FlavaForPreTraining"),
        ("fnet", "FNetForPreTraining"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("funnel", "FunnelForPreTraining"),
        ("gpt2", "GPT2LMHeadModel"),
        ("ibert", "IBertForMaskedLM"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("longformer", "LongformerForMaskedLM"),
        ("luke", "LukeForMaskedLM"),
        ("lxmert", "LxmertForPreTraining"),
        ("megatron-bert", "MegatronBertForPreTraining"),
        ("mobilebert", "MobileBertForPreTraining"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForPreTraining"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("retribert", "RetriBertModel"),
        ("roberta", "RobertaForMaskedLM"),
        ("splinter", "SplinterForPreTraining"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("t5", "T5ForConditionalGeneration"),
        ("tapas", "TapasForMaskedLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("unispeech", "UniSpeechForPreTraining"),
        ("unispeech-sat", "UniSpeechSatForPreTraining"),
        ("videomae", "VideoMAEForPreTraining"),
        ("visual_bert", "VisualBertForPreTraining"),
        ("vit_mae", "ViTMAEForPreTraining"),
        ("wav2vec2", "Wav2Vec2ForPreTraining"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForPreTraining"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("xlnet", "XLNetLMHeadModel"),
    ]
)

219
# Ordered table: model-type key -> name (string) of the class carrying a language-modeling head.
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    [
        # Model with LM heads mapping
        ("albert", "AlbertForMaskedLM"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForMaskedLM"),
        ("big_bird", "BigBirdForMaskedLM"),
        ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
        ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("codegen", "CodeGenForCausalLM"),
        ("convbert", "ConvBertForMaskedLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForMaskedLM"),
        ("encoder-decoder", "EncoderDecoderModel"),
        ("ernie", "ErnieForMaskedLM"),
        ("esm", "EsmForMaskedLM"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("fnet", "FNetForMaskedLM"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("funnel", "FunnelForMaskedLM"),
        ("gpt2", "GPT2LMHeadModel"),
        ("gpt_neo", "GPTNeoForCausalLM"),
        ("gpt_neox", "GPTNeoXForCausalLM"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
        ("gptj", "GPTJForCausalLM"),
        ("ibert", "IBertForMaskedLM"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("led", "LEDForConditionalGeneration"),
        ("longformer", "LongformerForMaskedLM"),
        ("longt5", "LongT5ForConditionalGeneration"),
        ("luke", "LukeForMaskedLM"),
        ("m2m_100", "M2M100ForConditionalGeneration"),
        ("marian", "MarianMTModel"),
        ("megatron-bert", "MegatronBertForCausalLM"),
        ("mobilebert", "MobileBertForMaskedLM"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForMaskedLM"),
        ("nllb", "M2M100ForConditionalGeneration"),
        ("nystromformer", "NystromformerForMaskedLM"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("pegasus_x", "PegasusXForConditionalGeneration"),
        ("plbart", "PLBartForConditionalGeneration"),
        ("qdqbert", "QDQBertForMaskedLM"),
        ("reformer", "ReformerModelWithLMHead"),
        ("rembert", "RemBertForMaskedLM"),
        ("roberta", "RobertaForMaskedLM"),
        ("roformer", "RoFormerForMaskedLM"),
        ("speech_to_text", "Speech2TextForConditionalGeneration"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("t5", "T5ForConditionalGeneration"),
        ("tapas", "TapasForMaskedLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("wav2vec2", "Wav2Vec2ForMaskedLM"),
        ("whisper", "WhisperForConditionalGeneration"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("xlnet", "XLNetLMHeadModel"),
        ("yoso", "YosoForMaskedLM"),
    ]
)

288
# Ordered table: model-type key -> name (string) of the causal language-modeling class.
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
        ("bart", "BartForCausalLM"),
        ("bert", "BertLMHeadModel"),
        ("bert-generation", "BertGenerationDecoder"),
        ("big_bird", "BigBirdForCausalLM"),
        ("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
        ("blenderbot", "BlenderbotForCausalLM"),
        ("blenderbot-small", "BlenderbotSmallForCausalLM"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForCausalLM"),
        ("codegen", "CodeGenForCausalLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForCausalLM"),
        ("electra", "ElectraForCausalLM"),
        ("ernie", "ErnieForCausalLM"),
        ("gpt2", "GPT2LMHeadModel"),
        ("gpt_neo", "GPTNeoForCausalLM"),
        ("gpt_neox", "GPTNeoXForCausalLM"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
        ("gptj", "GPTJForCausalLM"),
        ("marian", "MarianForCausalLM"),
        ("mbart", "MBartForCausalLM"),
        ("megatron-bert", "MegatronBertForCausalLM"),
        ("mvp", "MvpForCausalLM"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("opt", "OPTForCausalLM"),
        ("pegasus", "PegasusForCausalLM"),
        ("plbart", "PLBartForCausalLM"),
        ("prophetnet", "ProphetNetForCausalLM"),
        ("qdqbert", "QDQBertLMHeadModel"),
        ("reformer", "ReformerModelWithLMHead"),
        ("rembert", "RemBertForCausalLM"),
        ("roberta", "RobertaForCausalLM"),
        ("roformer", "RoFormerForCausalLM"),
        ("speech_to_text_2", "Speech2Text2ForCausalLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("trocr", "TrOCRForCausalLM"),
        ("xglm", "XGLMForCausalLM"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-prophetnet", "XLMProphetNetForCausalLM"),
        ("xlm-roberta", "XLMRobertaForCausalLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForCausalLM"),
        ("xlnet", "XLNetLMHeadModel"),
    ]
)

NielsRogge's avatar
NielsRogge committed
336
337
338
339
# Ordered table: model-type key -> name (string) of the masked-image-modeling class.
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked Image Modeling mapping
        ("deit", "DeiTForMaskedImageModeling"),
        ("swin", "SwinForMaskedImageModeling"),
        ("swinv2", "Swinv2ForMaskedImageModeling"),
        ("vit", "ViTForMaskedImageModeling"),
    ]
)


NielsRogge's avatar
NielsRogge committed
346
347
348
349
350
351
352
# Ordered table: model-type key -> name (string) of the causal-image-modeling class.
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal Image Modeling mapping
        ("imagegpt", "ImageGPTForCausalImageModeling"),
    ]
)

353
# Ordered table: model-type key -> name(s) of the image-classification class(es).
# A tuple value lists several class names registered under the same key.
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image Classification mapping
        ("beit", "BeitForImageClassification"),
        ("convnext", "ConvNextForImageClassification"),
        ("cvt", "CvtForImageClassification"),
        ("data2vec-vision", "Data2VecVisionForImageClassification"),
        ("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")),
        ("imagegpt", "ImageGPTForImageClassification"),
        ("levit", ("LevitForImageClassification", "LevitForImageClassificationWithTeacher")),
        ("mobilevit", "MobileViTForImageClassification"),
        (
            "perceiver",
            (
                "PerceiverForImageClassificationLearned",
                "PerceiverForImageClassificationFourier",
                "PerceiverForImageClassificationConvProcessing",
            ),
        ),
        ("poolformer", "PoolFormerForImageClassification"),
        ("regnet", "RegNetForImageClassification"),
        ("resnet", "ResNetForImageClassification"),
        ("segformer", "SegformerForImageClassification"),
        ("swin", "SwinForImageClassification"),
        ("swinv2", "Swinv2ForImageClassification"),
        ("van", "VanForImageClassification"),
        ("vit", "ViTForImageClassification"),
        ("vit_msn", "ViTMSNForImageClassification"),
    ]
)

384
385
# Ordered table: model-type key -> name (string) of the image-segmentation class.
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Do not add new models here, this class will be deprecated in the future.
        # Model for Image Segmentation mapping
        ("detr", "DetrForSegmentation"),
    ]
)

392
393
394
395
# Ordered table: model-type key -> name (string) of the semantic-segmentation class.
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Semantic Segmentation mapping
        ("beit", "BeitForSemanticSegmentation"),
        ("data2vec-vision", "Data2VecVisionForSemanticSegmentation"),
        ("dpt", "DPTForSemanticSegmentation"),
        ("mobilevit", "MobileViTForSemanticSegmentation"),
        ("segformer", "SegformerForSemanticSegmentation"),
    ]
)

403
404
405
406
407
408
409
# Ordered table: model-type key -> name (string) of the instance-segmentation class.
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Instance Segmentation mapping
        ("maskformer", "MaskFormerForInstanceSegmentation"),
    ]
)

NielsRogge's avatar
NielsRogge committed
410
411
412
413
414
415
# Ordered table: model-type key -> name (string) of the video-classification class.
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Video Classification mapping
        ("videomae", "VideoMAEForVideoClassification"),
    ]
)

416
417
418
419
420
421
# Ordered table: model-type key -> name (string) of the vision-to-sequence class.
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        # Model for Vision 2 Seq mapping
        ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
    ]
)

422
# Ordered table: model-type key -> name (string) of the masked language-modeling class.
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked LM mapping
        ("albert", "AlbertForMaskedLM"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForMaskedLM"),
        ("big_bird", "BigBirdForMaskedLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("convbert", "ConvBertForMaskedLM"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForMaskedLM"),
        ("ernie", "ErnieForMaskedLM"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("fnet", "FNetForMaskedLM"),
        ("funnel", "FunnelForMaskedLM"),
        ("ibert", "IBertForMaskedLM"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("longformer", "LongformerForMaskedLM"),
        ("luke", "LukeForMaskedLM"),
        ("mbart", "MBartForConditionalGeneration"),
        ("megatron-bert", "MegatronBertForMaskedLM"),
        ("mobilebert", "MobileBertForMaskedLM"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForMaskedLM"),
        ("nystromformer", "NystromformerForMaskedLM"),
        ("perceiver", "PerceiverForMaskedLM"),
        ("qdqbert", "QDQBertForMaskedLM"),
        ("reformer", "ReformerForMaskedLM"),
        ("rembert", "RemBertForMaskedLM"),
        ("roberta", "RobertaForMaskedLM"),
        ("roformer", "RoFormerForMaskedLM"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("tapas", "TapasForMaskedLM"),
        ("wav2vec2", "Wav2Vec2ForMaskedLM"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("yoso", "YosoForMaskedLM"),
    ]
)

467
# Ordered table: model-type key -> name (string) of the object-detection class.
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Object Detection mapping
        ("conditional_detr", "ConditionalDetrForObjectDetection"),
        ("deformable_detr", "DeformableDetrForObjectDetection"),
        ("detr", "DetrForObjectDetection"),
        ("table-transformer", "TableTransformerForObjectDetection"),
        ("yolos", "YolosForObjectDetection"),
    ]
)

478
479
480
481
482
483
484
# Ordered table: model-type key -> name (string) of the zero-shot object-detection class.
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Zero Shot Object Detection mapping
        ("owlvit", "OwlViTForObjectDetection"),
    ]
)

485
486
487
488
489
490
491
# Ordered table: model-type key -> name (string) of the depth-estimation class.
MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for depth estimation mapping
        ("dpt", "DPTForDepthEstimation"),
        ("glpn", "GLPNForDepthEstimation"),
    ]
)
492
# Ordered table: model-type key -> name (string) of the sequence-to-sequence LM class.
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        ("bart", "BartForConditionalGeneration"),
        ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
        ("blenderbot", "BlenderbotForConditionalGeneration"),
        ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
        ("encoder-decoder", "EncoderDecoderModel"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("led", "LEDForConditionalGeneration"),
        ("longt5", "LongT5ForConditionalGeneration"),
        ("m2m_100", "M2M100ForConditionalGeneration"),
        ("marian", "MarianMTModel"),
        ("mbart", "MBartForConditionalGeneration"),
        ("mt5", "MT5ForConditionalGeneration"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nllb", "M2M100ForConditionalGeneration"),
        ("pegasus", "PegasusForConditionalGeneration"),
        ("pegasus_x", "PegasusXForConditionalGeneration"),
        ("plbart", "PLBartForConditionalGeneration"),
        ("prophetnet", "ProphetNetForConditionalGeneration"),
        ("t5", "T5ForConditionalGeneration"),
        ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"),
    ]
)

518
519
520
521
# Ordered table: model-type key -> name (string) of the speech sequence-to-sequence class.
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        # Model for Speech Seq2Seq mapping
        ("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
        ("speech_to_text", "Speech2TextForConditionalGeneration"),
        ("whisper", "WhisperForConditionalGeneration"),
    ]
)

526
# Ordered table: model-type key -> name (string) of the sequence-classification class.
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Sequence Classification mapping
        ("albert", "AlbertForSequenceClassification"),
        ("bart", "BartForSequenceClassification"),
        ("bert", "BertForSequenceClassification"),
        ("big_bird", "BigBirdForSequenceClassification"),
        ("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
        ("bloom", "BloomForSequenceClassification"),
        ("camembert", "CamembertForSequenceClassification"),
        ("canine", "CanineForSequenceClassification"),
        ("convbert", "ConvBertForSequenceClassification"),
        ("ctrl", "CTRLForSequenceClassification"),
        ("data2vec-text", "Data2VecTextForSequenceClassification"),
        ("deberta", "DebertaForSequenceClassification"),
        ("deberta-v2", "DebertaV2ForSequenceClassification"),
        ("distilbert", "DistilBertForSequenceClassification"),
        ("electra", "ElectraForSequenceClassification"),
        ("ernie", "ErnieForSequenceClassification"),
        ("esm", "EsmForSequenceClassification"),
        ("flaubert", "FlaubertForSequenceClassification"),
        ("fnet", "FNetForSequenceClassification"),
        ("funnel", "FunnelForSequenceClassification"),
        ("gpt2", "GPT2ForSequenceClassification"),
        ("gpt_neo", "GPTNeoForSequenceClassification"),
        ("gptj", "GPTJForSequenceClassification"),
        ("ibert", "IBertForSequenceClassification"),
        ("layoutlm", "LayoutLMForSequenceClassification"),
        ("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
        ("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
        ("led", "LEDForSequenceClassification"),
        ("lilt", "LiltForSequenceClassification"),
        ("longformer", "LongformerForSequenceClassification"),
        ("luke", "LukeForSequenceClassification"),
        ("markuplm", "MarkupLMForSequenceClassification"),
        ("mbart", "MBartForSequenceClassification"),
        ("megatron-bert", "MegatronBertForSequenceClassification"),
        ("mobilebert", "MobileBertForSequenceClassification"),
        ("mpnet", "MPNetForSequenceClassification"),
        ("mvp", "MvpForSequenceClassification"),
        ("nezha", "NezhaForSequenceClassification"),
        ("nystromformer", "NystromformerForSequenceClassification"),
        ("openai-gpt", "OpenAIGPTForSequenceClassification"),
        ("opt", "OPTForSequenceClassification"),
        ("perceiver", "PerceiverForSequenceClassification"),
        ("plbart", "PLBartForSequenceClassification"),
        ("qdqbert", "QDQBertForSequenceClassification"),
        ("reformer", "ReformerForSequenceClassification"),
        ("rembert", "RemBertForSequenceClassification"),
        ("roberta", "RobertaForSequenceClassification"),
        ("roformer", "RoFormerForSequenceClassification"),
        ("squeezebert", "SqueezeBertForSequenceClassification"),
        ("tapas", "TapasForSequenceClassification"),
        ("transfo-xl", "TransfoXLForSequenceClassification"),
        ("xlm", "XLMForSequenceClassification"),
        ("xlm-roberta", "XLMRobertaForSequenceClassification"),
        ("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"),
        ("xlnet", "XLNetForSequenceClassification"),
        ("yoso", "YosoForSequenceClassification"),
    ]
)

588
# Ordered table: model-type key -> name (string) of the question-answering class.
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Question Answering mapping
        ("albert", "AlbertForQuestionAnswering"),
        ("bart", "BartForQuestionAnswering"),
        ("bert", "BertForQuestionAnswering"),
        ("big_bird", "BigBirdForQuestionAnswering"),
        ("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"),
        ("bloom", "BloomForQuestionAnswering"),
        ("camembert", "CamembertForQuestionAnswering"),
        ("canine", "CanineForQuestionAnswering"),
        ("convbert", "ConvBertForQuestionAnswering"),
        ("data2vec-text", "Data2VecTextForQuestionAnswering"),
        ("deberta", "DebertaForQuestionAnswering"),
        ("deberta-v2", "DebertaV2ForQuestionAnswering"),
        ("distilbert", "DistilBertForQuestionAnswering"),
        ("electra", "ElectraForQuestionAnswering"),
        ("ernie", "ErnieForQuestionAnswering"),
        ("flaubert", "FlaubertForQuestionAnsweringSimple"),
        ("fnet", "FNetForQuestionAnswering"),
        ("funnel", "FunnelForQuestionAnswering"),
        ("gptj", "GPTJForQuestionAnswering"),
        ("ibert", "IBertForQuestionAnswering"),
        ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
        ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
        ("led", "LEDForQuestionAnswering"),
        ("lilt", "LiltForQuestionAnswering"),
        ("longformer", "LongformerForQuestionAnswering"),
        ("luke", "LukeForQuestionAnswering"),
        ("lxmert", "LxmertForQuestionAnswering"),
        ("markuplm", "MarkupLMForQuestionAnswering"),
        ("mbart", "MBartForQuestionAnswering"),
        ("megatron-bert", "MegatronBertForQuestionAnswering"),
        ("mobilebert", "MobileBertForQuestionAnswering"),
        ("mpnet", "MPNetForQuestionAnswering"),
        ("mvp", "MvpForQuestionAnswering"),
        ("nezha", "NezhaForQuestionAnswering"),
        ("nystromformer", "NystromformerForQuestionAnswering"),
        ("opt", "OPTForQuestionAnswering"),
        ("qdqbert", "QDQBertForQuestionAnswering"),
        ("reformer", "ReformerForQuestionAnswering"),
        ("rembert", "RemBertForQuestionAnswering"),
        ("roberta", "RobertaForQuestionAnswering"),
        ("roformer", "RoFormerForQuestionAnswering"),
        ("splinter", "SplinterForQuestionAnswering"),
        ("squeezebert", "SqueezeBertForQuestionAnswering"),
        ("xlm", "XLMForQuestionAnsweringSimple"),
        ("xlm-roberta", "XLMRobertaForQuestionAnswering"),
        ("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"),
        ("xlnet", "XLNetForQuestionAnsweringSimple"),
        ("yoso", "YosoForQuestionAnswering"),
    ]
)

642
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Table Question Answering mapping
        ("tapas", "TapasForQuestionAnswering"),
    ]
)

649
650
651
652
653
654
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Visual Question Answering mapping
        ("vilt", "ViltForQuestionAnswering"),
    ]
)

655
656
657
658
659
660
661
662
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Document Question Answering mapping
        ("layoutlm", "LayoutLMForQuestionAnswering"),
        ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
        ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
    ]
)

663
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
        ("albert", "AlbertForTokenClassification"),
        ("bert", "BertForTokenClassification"),
        ("big_bird", "BigBirdForTokenClassification"),
        ("bloom", "BloomForTokenClassification"),
        ("camembert", "CamembertForTokenClassification"),
        ("canine", "CanineForTokenClassification"),
        ("convbert", "ConvBertForTokenClassification"),
        ("data2vec-text", "Data2VecTextForTokenClassification"),
        ("deberta", "DebertaForTokenClassification"),
        ("deberta-v2", "DebertaV2ForTokenClassification"),
        ("distilbert", "DistilBertForTokenClassification"),
        ("electra", "ElectraForTokenClassification"),
        ("ernie", "ErnieForTokenClassification"),
        ("esm", "EsmForTokenClassification"),
        ("flaubert", "FlaubertForTokenClassification"),
        ("fnet", "FNetForTokenClassification"),
        ("funnel", "FunnelForTokenClassification"),
        ("gpt2", "GPT2ForTokenClassification"),
        ("ibert", "IBertForTokenClassification"),
        ("layoutlm", "LayoutLMForTokenClassification"),
        ("layoutlmv2", "LayoutLMv2ForTokenClassification"),
        ("layoutlmv3", "LayoutLMv3ForTokenClassification"),
        ("lilt", "LiltForTokenClassification"),
        ("longformer", "LongformerForTokenClassification"),
        ("luke", "LukeForTokenClassification"),
        ("markuplm", "MarkupLMForTokenClassification"),
        ("megatron-bert", "MegatronBertForTokenClassification"),
        ("mobilebert", "MobileBertForTokenClassification"),
        ("mpnet", "MPNetForTokenClassification"),
        ("nezha", "NezhaForTokenClassification"),
        ("nystromformer", "NystromformerForTokenClassification"),
        ("qdqbert", "QDQBertForTokenClassification"),
        ("rembert", "RemBertForTokenClassification"),
        ("roberta", "RobertaForTokenClassification"),
        ("roformer", "RoFormerForTokenClassification"),
        ("squeezebert", "SqueezeBertForTokenClassification"),
        ("xlm", "XLMForTokenClassification"),
        ("xlm-roberta", "XLMRobertaForTokenClassification"),
        ("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"),
        ("xlnet", "XLNetForTokenClassification"),
        ("yoso", "YosoForTokenClassification"),
    ]
)

710
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Multiple Choice mapping
        ("albert", "AlbertForMultipleChoice"),
        ("bert", "BertForMultipleChoice"),
        ("big_bird", "BigBirdForMultipleChoice"),
        ("camembert", "CamembertForMultipleChoice"),
        ("canine", "CanineForMultipleChoice"),
        ("convbert", "ConvBertForMultipleChoice"),
        ("data2vec-text", "Data2VecTextForMultipleChoice"),
        ("deberta-v2", "DebertaV2ForMultipleChoice"),
        ("distilbert", "DistilBertForMultipleChoice"),
        ("electra", "ElectraForMultipleChoice"),
        ("ernie", "ErnieForMultipleChoice"),
        ("flaubert", "FlaubertForMultipleChoice"),
        ("fnet", "FNetForMultipleChoice"),
        ("funnel", "FunnelForMultipleChoice"),
        ("ibert", "IBertForMultipleChoice"),
        ("longformer", "LongformerForMultipleChoice"),
        ("luke", "LukeForMultipleChoice"),
        ("megatron-bert", "MegatronBertForMultipleChoice"),
        ("mobilebert", "MobileBertForMultipleChoice"),
        ("mpnet", "MPNetForMultipleChoice"),
        ("nezha", "NezhaForMultipleChoice"),
        ("nystromformer", "NystromformerForMultipleChoice"),
        ("qdqbert", "QDQBertForMultipleChoice"),
        ("rembert", "RemBertForMultipleChoice"),
        ("roberta", "RobertaForMultipleChoice"),
        ("roformer", "RoFormerForMultipleChoice"),
        ("squeezebert", "SqueezeBertForMultipleChoice"),
        ("xlm", "XLMForMultipleChoice"),
        ("xlm-roberta", "XLMRobertaForMultipleChoice"),
        ("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"),
        ("xlnet", "XLNetForMultipleChoice"),
        ("yoso", "YosoForMultipleChoice"),
    ]
)

748
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Next Sentence Prediction mapping
        ("bert", "BertForNextSentencePrediction"),
        ("ernie", "ErnieForNextSentencePrediction"),
        ("fnet", "FNetForNextSentencePrediction"),
        ("megatron-bert", "MegatronBertForNextSentencePrediction"),
        ("mobilebert", "MobileBertForNextSentencePrediction"),
        ("nezha", "NezhaForNextSentencePrediction"),
        ("qdqbert", "QDQBertForNextSentencePrediction"),
    ]
)

760
761
762
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio Classification mapping
        ("data2vec-audio", "Data2VecAudioForSequenceClassification"),
        ("hubert", "HubertForSequenceClassification"),
        ("sew", "SEWForSequenceClassification"),
        ("sew-d", "SEWDForSequenceClassification"),
        ("unispeech", "UniSpeechForSequenceClassification"),
        ("unispeech-sat", "UniSpeechSatForSequenceClassification"),
        ("wav2vec2", "Wav2Vec2ForSequenceClassification"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForSequenceClassification"),
        ("wavlm", "WavLMForSequenceClassification"),
    ]
)

775
776
777
MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
    [
        # Model for Connectionist temporal classification (CTC) mapping
        ("data2vec-audio", "Data2VecAudioForCTC"),
        ("hubert", "HubertForCTC"),
        ("mctct", "MCTCTForCTC"),
        ("sew", "SEWForCTC"),
        ("sew-d", "SEWDForCTC"),
        ("unispeech", "UniSpeechForCTC"),
        ("unispeech-sat", "UniSpeechSatForCTC"),
        ("wav2vec2", "Wav2Vec2ForCTC"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForCTC"),
        ("wavlm", "WavLMForCTC"),
    ]
)

791
792
793
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio Frame Classification mapping
        ("data2vec-audio", "Data2VecAudioForAudioFrameClassification"),
        ("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
        ("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForAudioFrameClassification"),
        ("wavlm", "WavLMForAudioFrameClassification"),
    ]
)

MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio X-Vector mapping
        ("data2vec-audio", "Data2VecAudioForXVector"),
        ("unispeech-sat", "UniSpeechSatForXVector"),
        ("wav2vec2", "Wav2Vec2ForXVector"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForXVector"),
        ("wavlm", "WavLMForXVector"),
    ]
)

813
814
815
816
_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Zero Shot Image Classification mapping
        ("clip", "CLIPModel"),
        ("clipseg", "CLIPSegModel"),
    ]
)

821
822
823
824
# One `_LazyAutoMapping` per task: each is built from the shared `CONFIG_MAPPING_NAMES` and that task's
# `*_MAPPING_NAMES` table. These objects are what the auto classes below assign to `_model_mapping`.
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES
)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
)
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
)
MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES)
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES)
887

888

Sylvain Gugger's avatar
Sylvain Gugger committed
889
890
891
892
893
894
895
896
897
898
899
900
# Auto class for base models, bound to MODEL_MAPPING through `_model_mapping`.
class AutoModel(_BaseAutoModelClass):
    _model_mapping = MODEL_MAPPING


# Rebind the name to the result of `auto_class_update`; no `head_doc` or example checkpoint is passed here.
AutoModel = auto_class_update(AutoModel)


# Auto class bound to MODEL_FOR_PRETRAINING_MAPPING through `_model_mapping`.
class AutoModelForPreTraining(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_PRETRAINING_MAPPING


# Rebind the name to the result of `auto_class_update`, passing "pretraining" as `head_doc`.
AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining")
901

thomwolf's avatar
thomwolf committed
902

903
# Private on purpose, the public class will add the deprecation warnings.
class _AutoModelWithLMHead(_BaseAutoModelClass):
    _model_mapping = MODEL_WITH_LM_HEAD_MAPPING


# Rebind the private name to the result of `auto_class_update`, passing "language modeling" as `head_doc`.
_AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling")
thomwolf's avatar
thomwolf committed
909
910


Sylvain Gugger's avatar
Sylvain Gugger committed
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
# Auto class bound to MODEL_FOR_CAUSAL_LM_MAPPING through `_model_mapping`.
class AutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING


# Rebind the name to the result of `auto_class_update`, passing "causal language modeling" as `head_doc`.
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")


# Auto class bound to MODEL_FOR_MASKED_LM_MAPPING through `_model_mapping`.
class AutoModelForMaskedLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING


# Rebind the name to the result of `auto_class_update`, passing "masked language modeling" as `head_doc`.
AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling")


# Auto class bound to MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING through `_model_mapping`.
class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# Unlike most siblings, this registration also overrides the example checkpoint ("t5-base").
AutoModelForSeq2SeqLM = auto_class_update(
    AutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
)
thomwolf's avatar
thomwolf committed
932

Sylvain Gugger's avatar
Sylvain Gugger committed
933
934
935
936
937
938
939

# Auto class bound to MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING through `_model_mapping`.
class AutoModelForSequenceClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


# Rebind the name to the result of `auto_class_update`, passing "sequence classification" as `head_doc`.
AutoModelForSequenceClassification = auto_class_update(
    AutoModelForSequenceClassification, head_doc="sequence classification"
)
thomwolf's avatar
thomwolf committed
941

Sylvain Gugger's avatar
Sylvain Gugger committed
942
943
944
945
946
947
948
949
950
951
952
953
954
955

# Auto class bound to MODEL_FOR_QUESTION_ANSWERING_MAPPING through `_model_mapping`.
class AutoModelForQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING


# Rebind the name to the result of `auto_class_update`, passing "question answering" as `head_doc`.
AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering")


# Auto class bound to MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING through `_model_mapping`.
class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


# This registration overrides the example checkpoint with a TAPAS one.
AutoModelForTableQuestionAnswering = auto_class_update(
    AutoModelForTableQuestionAnswering,
    head_doc="table question answering",
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
)
thomwolf's avatar
thomwolf committed
959
960


961
962
963
964
965
966
967
968
969
970
971
# Auto class bound to MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING through `_model_mapping`.
class AutoModelForVisualQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING


# This registration overrides the example checkpoint with a ViLT VQA one.
AutoModelForVisualQuestionAnswering = auto_class_update(
    AutoModelForVisualQuestionAnswering,
    head_doc="visual question answering",
    checkpoint_for_example="dandelin/vilt-b32-finetuned-vqa",
)


972
973
974
975
976
977
978
# Auto class bound to MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING through `_model_mapping`.
class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING


AutoModelForDocumentQuestionAnswering = auto_class_update(
    AutoModelForDocumentQuestionAnswering,
    head_doc="document question answering",
    # NOTE(review): the single-quoted literal embeds `", revision="` on purpose, it seems -- presumably the value
    # is substituted between double quotes in a generated example so that it reads
    # `"impira/layoutlm-document-qa", revision="52e01b3"`. Confirm against the template used by
    # `auto_class_update` before "fixing" the quoting.
    checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
)


Sylvain Gugger's avatar
Sylvain Gugger committed
983
984
# Auto class bound to MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING through `_model_mapping`.
class AutoModelForTokenClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


# Rebind the name to the result of `auto_class_update`, passing "token classification" as `head_doc`.
AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification")


# Auto class bound to MODEL_FOR_MULTIPLE_CHOICE_MAPPING through `_model_mapping`.
class AutoModelForMultipleChoice(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING


# Rebind the name to the result of `auto_class_update`, passing "multiple choice" as `head_doc`.
AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice")


# Auto class bound to MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING through `_model_mapping`.
class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


# Rebind the name to the result of `auto_class_update`, passing "next sentence prediction" as `head_doc`.
AutoModelForNextSentencePrediction = auto_class_update(
    AutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)
1004

1005

Sylvain Gugger's avatar
Sylvain Gugger committed
1006
1007
1008
1009
1010
1011
1012
# Auto class bound to MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING through `_model_mapping`.
class AutoModelForImageClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


# Rebind the name to the result of `auto_class_update`, passing "image classification" as `head_doc`.
AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")


1013
1014
1015
1016
1017
1018
1019
# Auto class bound to MODEL_FOR_IMAGE_SEGMENTATION_MAPPING through `_model_mapping`.
class AutoModelForImageSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING


# Rebind the name to the result of `auto_class_update`, passing "image segmentation" as `head_doc`.
AutoModelForImageSegmentation = auto_class_update(AutoModelForImageSegmentation, head_doc="image segmentation")


1020
1021
1022
1023
1024
1025
1026
1027
1028
# Auto class bound to MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING through `_model_mapping`.
class AutoModelForSemanticSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING


# Rebind the name to the result of `auto_class_update`, passing "semantic segmentation" as `head_doc`.
AutoModelForSemanticSegmentation = auto_class_update(
    AutoModelForSemanticSegmentation, head_doc="semantic segmentation"
)


1029
1030
1031
1032
1033
1034
1035
1036
1037
# Auto class bound to MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING through `_model_mapping`.
class AutoModelForInstanceSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING


# Rebind the name to the result of `auto_class_update`, passing "instance segmentation" as `head_doc`.
AutoModelForInstanceSegmentation = auto_class_update(
    AutoModelForInstanceSegmentation, head_doc="instance segmentation"
)


1038
1039
1040
1041
1042
1043
1044
# Auto class bound to MODEL_FOR_OBJECT_DETECTION_MAPPING through `_model_mapping`.
class AutoModelForObjectDetection(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING


# Rebind the name to the result of `auto_class_update`, passing "object detection" as `head_doc`.
AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")


1045
1046
1047
1048
1049
1050
1051
1052
1053
# Auto class bound to MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING through `_model_mapping`.
class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING


# Rebind the name to the result of `auto_class_update`, passing "zero-shot object detection" as `head_doc`.
AutoModelForZeroShotObjectDetection = auto_class_update(
    AutoModelForZeroShotObjectDetection, head_doc="zero-shot object detection"
)


1054
1055
1056
1057
1058
1059
1060
# Auto class bound to MODEL_FOR_DEPTH_ESTIMATION_MAPPING through `_model_mapping`.
class AutoModelForDepthEstimation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING


# Rebind the name to the result of `auto_class_update`, passing "depth estimation" as `head_doc`.
AutoModelForDepthEstimation = auto_class_update(AutoModelForDepthEstimation, head_doc="depth estimation")


NielsRogge's avatar
NielsRogge committed
1061
1062
1063
1064
1065
1066
1067
# Auto class bound to MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING through `_model_mapping`.
class AutoModelForVideoClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING


# Rebind the name to the result of `auto_class_update`, passing "video classification" as `head_doc`.
AutoModelForVideoClassification = auto_class_update(AutoModelForVideoClassification, head_doc="video classification")


1068
1069
1070
1071
1072
1073
1074
# Auto class bound to MODEL_FOR_VISION_2_SEQ_MAPPING through `_model_mapping`.
class AutoModelForVision2Seq(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING


# Rebind the name to the result of `auto_class_update`, passing "vision-to-text modeling" as `head_doc`.
AutoModelForVision2Seq = auto_class_update(AutoModelForVision2Seq, head_doc="vision-to-text modeling")


1075
1076
1077
1078
1079
1080
1081
# Auto class bound to MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING through `_model_mapping`.
class AutoModelForAudioClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


# Rebind the name to the result of `auto_class_update`, passing "audio classification" as `head_doc`.
AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification")


1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
# Auto class bound to MODEL_FOR_CTC_MAPPING through `_model_mapping`.
class AutoModelForCTC(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CTC_MAPPING


# Rebind the name to the result of `auto_class_update`, passing "connectionist temporal classification" as `head_doc`.
AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification")


# Auto class bound to MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING through `_model_mapping`.
class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


# Rebind the name to the result of `auto_class_update` with the speech-to-text `head_doc` below.
AutoModelForSpeechSeq2Seq = auto_class_update(
    AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)


1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
# Auto class bound to MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING through `_model_mapping`.
class AutoModelForAudioFrameClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING


# Rebind the name to the result of `auto_class_update`, passing "audio frame (token) classification" as `head_doc`.
AutoModelForAudioFrameClassification = auto_class_update(
    AutoModelForAudioFrameClassification, head_doc="audio frame (token) classification"
)


# Auto class bound to MODEL_FOR_AUDIO_XVECTOR_MAPPING through `_model_mapping`.
class AutoModelForAudioXVector(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING


# Rebind the name to the result of `auto_class_update`, passing "audio retrieval via x-vector" as `head_doc`.
AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector")


NielsRogge's avatar
NielsRogge committed
1114
1115
1116
1117
1118
1119
1120
# Auto class bound to MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING through `_model_mapping`.
class AutoModelForMaskedImageModeling(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING


# Rebind the name to the result of `auto_class_update`, passing "masked image modeling" as `head_doc`.
AutoModelForMaskedImageModeling = auto_class_update(AutoModelForMaskedImageModeling, head_doc="masked image modeling")


1121
class AutoModelWithLMHead(_AutoModelWithLMHead):
    # Deprecated public wrapper around `_AutoModelWithLMHead`: both entry points emit a `FutureWarning` and then
    # defer to the parent implementation with their arguments unchanged.

    # Single source for the deprecation text so the two entry points cannot drift apart.
    _DEPRECATION_MESSAGE = (
        "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
        "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
        "`AutoModelForSeq2SeqLM` for encoder-decoder models."
    )

    @classmethod
    def from_config(cls, config):
        # stacklevel=2 attributes the warning to the caller's line rather than to this module.
        warnings.warn(cls._DEPRECATION_MESSAGE, FutureWarning, stacklevel=2)
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # stacklevel=2 attributes the warning to the caller's line rather than to this module.
        warnings.warn(cls._DEPRECATION_MESSAGE, FutureWarning, stacklevel=2)
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)