modeling_auto.py 43.8 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Sylvain Gugger's avatar
Sylvain Gugger committed
15
""" Auto Model class."""
thomwolf's avatar
thomwolf committed
16

17
import warnings
Julien Chaumond's avatar
Julien Chaumond committed
18
from collections import OrderedDict
thomwolf's avatar
thomwolf committed
19

Sylvain Gugger's avatar
Sylvain Gugger committed
20
from ...utils import logging
21
22
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES
thomwolf's avatar
thomwolf committed
23

thomwolf's avatar
thomwolf committed
24

Lysandre Debut's avatar
Lysandre Debut committed
25
# Module-level logger. Note: `logging` here is the package's own `...utils.logging`
# helper (see the relative import at the top of the file), not the stdlib module.
logger = logging.get_logger(__name__)
thomwolf's avatar
thomwolf committed
26
27


28
# Base model mapping: model-type key -> base model class name. A tuple value (e.g. "funnel")
# lists several class names for the same key. Entries are kept in alphabetical key order.
# NOTE(review): keys presumably match the keys of CONFIG_MAPPING_NAMES; the string names are
# resolved to classes outside this view — confirm against the rest of the module.
MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("albert", "AlbertModel"),
        ("bart", "BartModel"),
        ("beit", "BeitModel"),
        ("bert", "BertModel"),
        ("bert-generation", "BertGenerationEncoder"),
        ("big_bird", "BigBirdModel"),
        ("bigbird_pegasus", "BigBirdPegasusModel"),
        ("blenderbot", "BlenderbotModel"),
        ("blenderbot-small", "BlenderbotSmallModel"),
        ("bloom", "BloomModel"),
        ("camembert", "CamembertModel"),
        ("canine", "CanineModel"),
        ("clip", "CLIPModel"),
        ("codegen", "CodeGenModel"),
        ("conditional_detr", "ConditionalDetrModel"),
        ("convbert", "ConvBertModel"),
        ("convnext", "ConvNextModel"),
        ("ctrl", "CTRLModel"),
        ("cvt", "CvtModel"),
        ("data2vec-audio", "Data2VecAudioModel"),
        ("data2vec-text", "Data2VecTextModel"),
        ("data2vec-vision", "Data2VecVisionModel"),
        ("deberta", "DebertaModel"),
        ("deberta-v2", "DebertaV2Model"),
        ("decision_transformer", "DecisionTransformerModel"),
        ("decision_transformer_gpt2", "DecisionTransformerGPT2Model"),
        ("deformable_detr", "DeformableDetrModel"),
        ("deit", "DeiTModel"),
        ("detr", "DetrModel"),
        ("distilbert", "DistilBertModel"),
        ("donut-swin", "DonutSwinModel"),
        ("dpr", "DPRQuestionEncoder"),
        ("dpt", "DPTModel"),
        ("electra", "ElectraModel"),
        ("ernie", "ErnieModel"),
        ("esm", "EsmModel"),
        ("flaubert", "FlaubertModel"),
        ("flava", "FlavaModel"),
        ("fnet", "FNetModel"),
        ("fsmt", "FSMTModel"),
        ("funnel", ("FunnelModel", "FunnelBaseModel")),
        ("glpn", "GLPNModel"),
        ("gpt2", "GPT2Model"),
        ("gpt_neo", "GPTNeoModel"),
        ("gpt_neox", "GPTNeoXModel"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseModel"),
        ("gptj", "GPTJModel"),
        ("groupvit", "GroupViTModel"),
        ("hubert", "HubertModel"),
        ("ibert", "IBertModel"),
        ("imagegpt", "ImageGPTModel"),
        ("layoutlm", "LayoutLMModel"),
        ("layoutlmv2", "LayoutLMv2Model"),
        ("layoutlmv3", "LayoutLMv3Model"),
        ("led", "LEDModel"),
        ("levit", "LevitModel"),
        ("longformer", "LongformerModel"),
        ("longt5", "LongT5Model"),
        ("luke", "LukeModel"),
        ("lxmert", "LxmertModel"),
        ("m2m_100", "M2M100Model"),
        ("marian", "MarianModel"),
        ("markuplm", "MarkupLMModel"),
        ("maskformer", "MaskFormerModel"),
        ("mbart", "MBartModel"),
        ("mctct", "MCTCTModel"),
        ("megatron-bert", "MegatronBertModel"),
        ("mobilebert", "MobileBertModel"),
        ("mobilevit", "MobileViTModel"),
        ("mpnet", "MPNetModel"),
        ("mt5", "MT5Model"),
        ("mvp", "MvpModel"),
        ("nezha", "NezhaModel"),
        ("nllb", "M2M100Model"),
        ("nystromformer", "NystromformerModel"),
        ("openai-gpt", "OpenAIGPTModel"),
        ("opt", "OPTModel"),
        ("owlvit", "OwlViTModel"),
        ("pegasus", "PegasusModel"),
        ("pegasus_x", "PegasusXModel"),
        ("perceiver", "PerceiverModel"),
        ("plbart", "PLBartModel"),
        ("poolformer", "PoolFormerModel"),
        ("prophetnet", "ProphetNetModel"),
        ("qdqbert", "QDQBertModel"),
        ("reformer", "ReformerModel"),
        ("regnet", "RegNetModel"),
        ("rembert", "RemBertModel"),
        ("resnet", "ResNetModel"),
        ("retribert", "RetriBertModel"),
        ("roberta", "RobertaModel"),
        ("roformer", "RoFormerModel"),
        ("segformer", "SegformerModel"),
        ("sew", "SEWModel"),
        ("sew-d", "SEWDModel"),
        ("speech_to_text", "Speech2TextModel"),
        ("splinter", "SplinterModel"),
        ("squeezebert", "SqueezeBertModel"),
        ("swin", "SwinModel"),
        ("swinv2", "Swinv2Model"),
        ("t5", "T5Model"),
        ("tapas", "TapasModel"),
        ("time_series_transformer", "TimeSeriesTransformerModel"),
        ("trajectory_transformer", "TrajectoryTransformerModel"),
        ("transfo-xl", "TransfoXLModel"),
        ("unispeech", "UniSpeechModel"),
        ("unispeech-sat", "UniSpeechSatModel"),
        ("van", "VanModel"),
        ("videomae", "VideoMAEModel"),
        ("vilt", "ViltModel"),
        ("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
        ("visual_bert", "VisualBertModel"),
        ("vit", "ViTModel"),
        ("vit_mae", "ViTMAEModel"),
        ("vit_msn", "ViTMSNModel"),
        ("wav2vec2", "Wav2Vec2Model"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerModel"),
        ("wavlm", "WavLMModel"),
        ("whisper", "WhisperModel"),
        ("xclip", "XCLIPModel"),
        ("xglm", "XGLMModel"),
        ("xlm", "XLMModel"),
        ("xlm-prophetnet", "XLMProphetNetModel"),
        ("xlm-roberta", "XLMRobertaModel"),
        ("xlm-roberta-xl", "XLMRobertaXLModel"),
        ("xlnet", "XLNetModel"),
        ("yolos", "YolosModel"),
        ("yoso", "YosoModel"),
    ]
)

162
# Model for pre-training mapping: model-type key -> class name of the model used for
# pre-training. Several types reuse a masked-LM or LM-head class here rather than a
# dedicated *ForPreTraining class.
# NOTE(review): keys presumably match CONFIG_MAPPING_NAMES; names are resolved outside this view.
MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        # Model for pre-training mapping
        ("albert", "AlbertForPreTraining"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForPreTraining"),
        ("big_bird", "BigBirdForPreTraining"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForPreTraining"),
        ("ernie", "ErnieForPreTraining"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("flava", "FlavaForPreTraining"),
        ("fnet", "FNetForPreTraining"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("funnel", "FunnelForPreTraining"),
        ("gpt2", "GPT2LMHeadModel"),
        ("ibert", "IBertForMaskedLM"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("longformer", "LongformerForMaskedLM"),
        ("luke", "LukeForMaskedLM"),
        ("lxmert", "LxmertForPreTraining"),
        ("megatron-bert", "MegatronBertForPreTraining"),
        ("mobilebert", "MobileBertForPreTraining"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForPreTraining"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("retribert", "RetriBertModel"),
        ("roberta", "RobertaForMaskedLM"),
        ("splinter", "SplinterForPreTraining"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("t5", "T5ForConditionalGeneration"),
        ("tapas", "TapasForMaskedLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("unispeech", "UniSpeechForPreTraining"),
        ("unispeech-sat", "UniSpeechSatForPreTraining"),
        ("videomae", "VideoMAEForPreTraining"),
        ("visual_bert", "VisualBertForPreTraining"),
        ("vit_mae", "ViTMAEForPreTraining"),
        ("wav2vec2", "Wav2Vec2ForPreTraining"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForPreTraining"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("xlnet", "XLNetLMHeadModel"),
    ]
)

216
# Model with LM heads mapping: model-type key -> class name of a model carrying a language
# modeling head (masked-LM, causal-LM or conditional-generation classes are all listed here).
# NOTE(review): keys presumably match CONFIG_MAPPING_NAMES; names are resolved outside this view.
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    [
        # Model with LM heads mapping
        ("albert", "AlbertForMaskedLM"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForMaskedLM"),
        ("big_bird", "BigBirdForMaskedLM"),
        ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
        ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("codegen", "CodeGenForCausalLM"),
        ("convbert", "ConvBertForMaskedLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForMaskedLM"),
        ("encoder-decoder", "EncoderDecoderModel"),
        ("ernie", "ErnieForMaskedLM"),
        ("esm", "EsmForMaskedLM"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("fnet", "FNetForMaskedLM"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("funnel", "FunnelForMaskedLM"),
        ("gpt2", "GPT2LMHeadModel"),
        ("gpt_neo", "GPTNeoForCausalLM"),
        ("gpt_neox", "GPTNeoXForCausalLM"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
        ("gptj", "GPTJForCausalLM"),
        ("ibert", "IBertForMaskedLM"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("led", "LEDForConditionalGeneration"),
        ("longformer", "LongformerForMaskedLM"),
        ("longt5", "LongT5ForConditionalGeneration"),
        ("luke", "LukeForMaskedLM"),
        ("m2m_100", "M2M100ForConditionalGeneration"),
        ("marian", "MarianMTModel"),
        ("markuplm", "MarkupLMForMaskedLM"),
        ("megatron-bert", "MegatronBertForCausalLM"),
        ("mobilebert", "MobileBertForMaskedLM"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForMaskedLM"),
        ("nllb", "M2M100ForConditionalGeneration"),
        ("nystromformer", "NystromformerForMaskedLM"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("pegasus_x", "PegasusXForConditionalGeneration"),
        ("plbart", "PLBartForConditionalGeneration"),
        ("qdqbert", "QDQBertForMaskedLM"),
        ("reformer", "ReformerModelWithLMHead"),
        ("rembert", "RemBertForMaskedLM"),
        ("roberta", "RobertaForMaskedLM"),
        ("roformer", "RoFormerForMaskedLM"),
        ("speech_to_text", "Speech2TextForConditionalGeneration"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("t5", "T5ForConditionalGeneration"),
        ("tapas", "TapasForMaskedLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("wav2vec2", "Wav2Vec2ForMaskedLM"),
        ("whisper", "WhisperForConditionalGeneration"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("xlnet", "XLNetLMHeadModel"),
        ("yoso", "YosoForMaskedLM"),
    ]
)

286
# Model for Causal LM mapping: model-type key -> class name of its causal (left-to-right)
# language-modeling model.
# NOTE(review): keys presumably match CONFIG_MAPPING_NAMES; names are resolved outside this view.
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
        ("bart", "BartForCausalLM"),
        ("bert", "BertLMHeadModel"),
        ("bert-generation", "BertGenerationDecoder"),
        ("big_bird", "BigBirdForCausalLM"),
        ("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
        ("blenderbot", "BlenderbotForCausalLM"),
        ("blenderbot-small", "BlenderbotSmallForCausalLM"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForCausalLM"),
        ("codegen", "CodeGenForCausalLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForCausalLM"),
        ("electra", "ElectraForCausalLM"),
        ("ernie", "ErnieForCausalLM"),
        ("gpt2", "GPT2LMHeadModel"),
        ("gpt_neo", "GPTNeoForCausalLM"),
        ("gpt_neox", "GPTNeoXForCausalLM"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
        ("gptj", "GPTJForCausalLM"),
        ("marian", "MarianForCausalLM"),
        ("mbart", "MBartForCausalLM"),
        ("megatron-bert", "MegatronBertForCausalLM"),
        ("mvp", "MvpForCausalLM"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("opt", "OPTForCausalLM"),
        ("pegasus", "PegasusForCausalLM"),
        ("plbart", "PLBartForCausalLM"),
        ("prophetnet", "ProphetNetForCausalLM"),
        ("qdqbert", "QDQBertLMHeadModel"),
        ("reformer", "ReformerModelWithLMHead"),
        ("rembert", "RemBertForCausalLM"),
        ("roberta", "RobertaForCausalLM"),
        ("roformer", "RoFormerForCausalLM"),
        ("speech_to_text_2", "Speech2Text2ForCausalLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("trocr", "TrOCRForCausalLM"),
        ("xglm", "XGLMForCausalLM"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-prophetnet", "XLMProphetNetForCausalLM"),
        ("xlm-roberta", "XLMRobertaForCausalLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForCausalLM"),
        ("xlnet", "XLNetLMHeadModel"),
    ]
)

NielsRogge's avatar
NielsRogge committed
334
335
336
337
# Model for Masked Image Modeling mapping: model-type key -> class name of its
# masked-image-modeling model.
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    [
        ("deit", "DeiTForMaskedImageModeling"),
        ("swin", "SwinForMaskedImageModeling"),
        ("swinv2", "Swinv2ForMaskedImageModeling"),
        ("vit", "ViTForMaskedImageModeling"),
    ]
)


NielsRogge's avatar
NielsRogge committed
344
345
346
347
348
349
350
# Model-type key -> class name of its causal image-modeling model.
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    (
        # Model for Causal Image Modeling mapping
        ("imagegpt", "ImageGPTForCausalImageModeling"),
    )
)

351
# Model for Image Classification mapping: model-type key -> image-classification class name,
# or a tuple of class names when a type has several variants (e.g. "deit", "levit", "perceiver").
# NOTE(review): keys presumably match CONFIG_MAPPING_NAMES; names are resolved outside this view.
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image Classification mapping
        ("beit", "BeitForImageClassification"),
        ("convnext", "ConvNextForImageClassification"),
        ("cvt", "CvtForImageClassification"),
        ("data2vec-vision", "Data2VecVisionForImageClassification"),
        ("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")),
        ("imagegpt", "ImageGPTForImageClassification"),
        ("levit", ("LevitForImageClassification", "LevitForImageClassificationWithTeacher")),
        ("mobilevit", "MobileViTForImageClassification"),
        (
            "perceiver",
            (
                "PerceiverForImageClassificationLearned",
                "PerceiverForImageClassificationFourier",
                "PerceiverForImageClassificationConvProcessing",
            ),
        ),
        ("poolformer", "PoolFormerForImageClassification"),
        ("regnet", "RegNetForImageClassification"),
        ("resnet", "ResNetForImageClassification"),
        ("segformer", "SegformerForImageClassification"),
        ("swin", "SwinForImageClassification"),
        ("swinv2", "Swinv2ForImageClassification"),
        ("van", "VanForImageClassification"),
        ("vit", "ViTForImageClassification"),
        ("vit_msn", "ViTMSNForImageClassification"),
    ]
)

382
383
# Model for Image Segmentation mapping: model-type key -> segmentation class name.
# Frozen per the in-literal note below; task-specific segmentation mappings are defined separately.
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Do not add new models here, this class will be deprecated in the future.
        # Model for Image Segmentation mapping
        ("detr", "DetrForSegmentation"),
    ]
)

390
391
392
393
# Model for Semantic Segmentation mapping: model-type key -> semantic-segmentation class name.
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Semantic Segmentation mapping
        ("beit", "BeitForSemanticSegmentation"),
        ("data2vec-vision", "Data2VecVisionForSemanticSegmentation"),
        ("dpt", "DPTForSemanticSegmentation"),
        ("mobilevit", "MobileViTForSemanticSegmentation"),
        ("segformer", "SegformerForSemanticSegmentation"),
    ]
)

401
402
403
404
405
406
407
# Model-type key -> class name of its instance-segmentation model.
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    (
        # Model for Instance Segmentation mapping
        ("maskformer", "MaskFormerForInstanceSegmentation"),
    )
)

NielsRogge's avatar
NielsRogge committed
408
409
410
411
412
413
# Model-type key -> class name of its video-classification model.
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    (
        ("videomae", "VideoMAEForVideoClassification"),
    )
)

414
415
416
417
418
419
# Model-type key -> class name of its vision-to-sequence (image in, text out) model.
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    (
        ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
    )
)

420
# Model for Masked LM mapping: model-type key -> class name used for masked language modeling.
# Some seq2seq types (bart, mbart, mvp) map to their conditional-generation class here.
# NOTE(review): keys presumably match CONFIG_MAPPING_NAMES; names are resolved outside this view.
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked LM mapping
        ("albert", "AlbertForMaskedLM"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForMaskedLM"),
        ("big_bird", "BigBirdForMaskedLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("convbert", "ConvBertForMaskedLM"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForMaskedLM"),
        ("ernie", "ErnieForMaskedLM"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("fnet", "FNetForMaskedLM"),
        ("funnel", "FunnelForMaskedLM"),
        ("ibert", "IBertForMaskedLM"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("longformer", "LongformerForMaskedLM"),
        ("luke", "LukeForMaskedLM"),
        ("mbart", "MBartForConditionalGeneration"),
        ("megatron-bert", "MegatronBertForMaskedLM"),
        ("mobilebert", "MobileBertForMaskedLM"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForMaskedLM"),
        ("nystromformer", "NystromformerForMaskedLM"),
        ("perceiver", "PerceiverForMaskedLM"),
        ("qdqbert", "QDQBertForMaskedLM"),
        ("reformer", "ReformerForMaskedLM"),
        ("rembert", "RemBertForMaskedLM"),
        ("roberta", "RobertaForMaskedLM"),
        ("roformer", "RoFormerForMaskedLM"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("tapas", "TapasForMaskedLM"),
        ("wav2vec2", "Wav2Vec2ForMaskedLM"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("yoso", "YosoForMaskedLM"),
    ]
)

465
# Model for Object Detection mapping: model-type key -> object-detection class name.
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Object Detection mapping
        ("conditional_detr", "ConditionalDetrForObjectDetection"),
        ("deformable_detr", "DeformableDetrForObjectDetection"),
        ("detr", "DetrForObjectDetection"),
        ("yolos", "YolosForObjectDetection"),
    ]
)

475
476
477
478
479
480
481
# Model-type key -> class name of its zero-shot (text-queried) object-detection model.
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
    (
        # Model for Zero Shot Object Detection mapping
        ("owlvit", "OwlViTForObjectDetection"),
    )
)

482
# Model for Seq2Seq Causal LM mapping: model-type key -> encoder-decoder generation class name.
# NOTE(review): keys presumably match CONFIG_MAPPING_NAMES; names are resolved outside this view.
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        ("bart", "BartForConditionalGeneration"),
        ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
        ("blenderbot", "BlenderbotForConditionalGeneration"),
        ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
        ("encoder-decoder", "EncoderDecoderModel"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("led", "LEDForConditionalGeneration"),
        ("longt5", "LongT5ForConditionalGeneration"),
        ("m2m_100", "M2M100ForConditionalGeneration"),
        ("marian", "MarianMTModel"),
        ("mbart", "MBartForConditionalGeneration"),
        ("mt5", "MT5ForConditionalGeneration"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nllb", "M2M100ForConditionalGeneration"),
        ("pegasus", "PegasusForConditionalGeneration"),
        ("pegasus_x", "PegasusXForConditionalGeneration"),
        ("plbart", "PLBartForConditionalGeneration"),
        ("prophetnet", "ProphetNetForConditionalGeneration"),
        ("t5", "T5ForConditionalGeneration"),
        ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"),
    ]
)

508
509
510
511
# Model-type key -> class name of its speech-to-sequence (audio in, text out) model.
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
        ("speech_to_text", "Speech2TextForConditionalGeneration"),
        ("whisper", "WhisperForConditionalGeneration"),
    ]
)

516
# Model for Sequence Classification mapping: model-type key -> sequence-classification class name.
# NOTE(review): keys presumably match CONFIG_MAPPING_NAMES; names are resolved outside this view.
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Sequence Classification mapping
        ("albert", "AlbertForSequenceClassification"),
        ("bart", "BartForSequenceClassification"),
        ("bert", "BertForSequenceClassification"),
        ("big_bird", "BigBirdForSequenceClassification"),
        ("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
        ("bloom", "BloomForSequenceClassification"),
        ("camembert", "CamembertForSequenceClassification"),
        ("canine", "CanineForSequenceClassification"),
        ("convbert", "ConvBertForSequenceClassification"),
        ("ctrl", "CTRLForSequenceClassification"),
        ("data2vec-text", "Data2VecTextForSequenceClassification"),
        ("deberta", "DebertaForSequenceClassification"),
        ("deberta-v2", "DebertaV2ForSequenceClassification"),
        ("distilbert", "DistilBertForSequenceClassification"),
        ("electra", "ElectraForSequenceClassification"),
        ("ernie", "ErnieForSequenceClassification"),
        ("esm", "EsmForSequenceClassification"),
        ("flaubert", "FlaubertForSequenceClassification"),
        ("fnet", "FNetForSequenceClassification"),
        ("funnel", "FunnelForSequenceClassification"),
        ("gpt2", "GPT2ForSequenceClassification"),
        ("gpt_neo", "GPTNeoForSequenceClassification"),
        ("gptj", "GPTJForSequenceClassification"),
        ("ibert", "IBertForSequenceClassification"),
        ("layoutlm", "LayoutLMForSequenceClassification"),
        ("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
        ("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
        ("led", "LEDForSequenceClassification"),
        ("longformer", "LongformerForSequenceClassification"),
        ("luke", "LukeForSequenceClassification"),
        ("markuplm", "MarkupLMForSequenceClassification"),
        ("mbart", "MBartForSequenceClassification"),
        ("megatron-bert", "MegatronBertForSequenceClassification"),
        ("mobilebert", "MobileBertForSequenceClassification"),
        ("mpnet", "MPNetForSequenceClassification"),
        ("mvp", "MvpForSequenceClassification"),
        ("nezha", "NezhaForSequenceClassification"),
        ("nystromformer", "NystromformerForSequenceClassification"),
        ("openai-gpt", "OpenAIGPTForSequenceClassification"),
        ("opt", "OPTForSequenceClassification"),
        ("perceiver", "PerceiverForSequenceClassification"),
        ("plbart", "PLBartForSequenceClassification"),
        ("qdqbert", "QDQBertForSequenceClassification"),
        ("reformer", "ReformerForSequenceClassification"),
        ("rembert", "RemBertForSequenceClassification"),
        ("roberta", "RobertaForSequenceClassification"),
        ("roformer", "RoFormerForSequenceClassification"),
        ("squeezebert", "SqueezeBertForSequenceClassification"),
        ("tapas", "TapasForSequenceClassification"),
        ("transfo-xl", "TransfoXLForSequenceClassification"),
        ("xlm", "XLMForSequenceClassification"),
        ("xlm-roberta", "XLMRobertaForSequenceClassification"),
        ("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"),
        ("xlnet", "XLNetForSequenceClassification"),
        ("yoso", "YosoForSequenceClassification"),
    ]
)

577
# Model for Question Answering mapping: model-type key -> question-answering class name
# (flaubert, xlm and xlnet use their "*Simple" QA head variants here).
# NOTE(review): keys presumably match CONFIG_MAPPING_NAMES; names are resolved outside this view.
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Question Answering mapping
        ("albert", "AlbertForQuestionAnswering"),
        ("bart", "BartForQuestionAnswering"),
        ("bert", "BertForQuestionAnswering"),
        ("big_bird", "BigBirdForQuestionAnswering"),
        ("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"),
        ("bloom", "BloomForQuestionAnswering"),
        ("camembert", "CamembertForQuestionAnswering"),
        ("canine", "CanineForQuestionAnswering"),
        ("convbert", "ConvBertForQuestionAnswering"),
        ("data2vec-text", "Data2VecTextForQuestionAnswering"),
        ("deberta", "DebertaForQuestionAnswering"),
        ("deberta-v2", "DebertaV2ForQuestionAnswering"),
        ("distilbert", "DistilBertForQuestionAnswering"),
        ("electra", "ElectraForQuestionAnswering"),
        ("ernie", "ErnieForQuestionAnswering"),
        ("flaubert", "FlaubertForQuestionAnsweringSimple"),
        ("fnet", "FNetForQuestionAnswering"),
        ("funnel", "FunnelForQuestionAnswering"),
        ("gptj", "GPTJForQuestionAnswering"),
        ("ibert", "IBertForQuestionAnswering"),
        ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
        ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
        ("led", "LEDForQuestionAnswering"),
        ("longformer", "LongformerForQuestionAnswering"),
        ("luke", "LukeForQuestionAnswering"),
        ("lxmert", "LxmertForQuestionAnswering"),
        ("markuplm", "MarkupLMForQuestionAnswering"),
        ("mbart", "MBartForQuestionAnswering"),
        ("megatron-bert", "MegatronBertForQuestionAnswering"),
        ("mobilebert", "MobileBertForQuestionAnswering"),
        ("mpnet", "MPNetForQuestionAnswering"),
        ("mvp", "MvpForQuestionAnswering"),
        ("nezha", "NezhaForQuestionAnswering"),
        ("nystromformer", "NystromformerForQuestionAnswering"),
        ("opt", "OPTForQuestionAnswering"),
        ("qdqbert", "QDQBertForQuestionAnswering"),
        ("reformer", "ReformerForQuestionAnswering"),
        ("rembert", "RemBertForQuestionAnswering"),
        ("roberta", "RobertaForQuestionAnswering"),
        ("roformer", "RoFormerForQuestionAnswering"),
        ("splinter", "SplinterForQuestionAnswering"),
        ("squeezebert", "SqueezeBertForQuestionAnswering"),
        ("xlm", "XLMForQuestionAnsweringSimple"),
        ("xlm-roberta", "XLMRobertaForQuestionAnswering"),
        ("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"),
        ("xlnet", "XLNetForQuestionAnsweringSimple"),
        ("yoso", "YosoForQuestionAnswering"),
    ]
)

630
# Model for Table Question Answering mapping: model-type key -> table-QA class name.
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Table Question Answering mapping
        ("tapas", "TapasForQuestionAnswering"),
    ]
)

637
638
639
640
641
642
# Model-type key -> class name of its visual question-answering model.
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    (
        ("vilt", "ViltForQuestionAnswering"),
    )
)

643
644
645
646
647
648
649
650
# Model-type key -> class name of its document question-answering model
# (all current entries are LayoutLM-family QA classes).
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    (
        ("layoutlm", "LayoutLMForQuestionAnswering"),
        ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
        ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
    )
)

651
# Token classification (e.g. NER / tagging): model type -> name of the model class implementing that head.
# Entries are kept in alphabetical order of model type.
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
        ("albert", "AlbertForTokenClassification"),
        ("bert", "BertForTokenClassification"),
        ("big_bird", "BigBirdForTokenClassification"),
        ("bloom", "BloomForTokenClassification"),
        ("camembert", "CamembertForTokenClassification"),
        ("canine", "CanineForTokenClassification"),
        ("convbert", "ConvBertForTokenClassification"),
        ("data2vec-text", "Data2VecTextForTokenClassification"),
        ("deberta", "DebertaForTokenClassification"),
        ("deberta-v2", "DebertaV2ForTokenClassification"),
        ("distilbert", "DistilBertForTokenClassification"),
        ("electra", "ElectraForTokenClassification"),
        ("ernie", "ErnieForTokenClassification"),
        ("esm", "EsmForTokenClassification"),
        ("flaubert", "FlaubertForTokenClassification"),
        ("fnet", "FNetForTokenClassification"),
        ("funnel", "FunnelForTokenClassification"),
        ("gpt2", "GPT2ForTokenClassification"),
        ("ibert", "IBertForTokenClassification"),
        ("layoutlm", "LayoutLMForTokenClassification"),
        ("layoutlmv2", "LayoutLMv2ForTokenClassification"),
        ("layoutlmv3", "LayoutLMv3ForTokenClassification"),
        ("longformer", "LongformerForTokenClassification"),
        ("luke", "LukeForTokenClassification"),
        ("markuplm", "MarkupLMForTokenClassification"),
        ("megatron-bert", "MegatronBertForTokenClassification"),
        ("mobilebert", "MobileBertForTokenClassification"),
        ("mpnet", "MPNetForTokenClassification"),
        ("nezha", "NezhaForTokenClassification"),
        ("nystromformer", "NystromformerForTokenClassification"),
        ("qdqbert", "QDQBertForTokenClassification"),
        ("rembert", "RemBertForTokenClassification"),
        ("roberta", "RobertaForTokenClassification"),
        ("roformer", "RoFormerForTokenClassification"),
        ("squeezebert", "SqueezeBertForTokenClassification"),
        ("xlm", "XLMForTokenClassification"),
        ("xlm-roberta", "XLMRobertaForTokenClassification"),
        ("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"),
        ("xlnet", "XLNetForTokenClassification"),
        ("yoso", "YosoForTokenClassification"),
    ]
)

697
# Multiple choice: model type -> name of the model class implementing that head.
# Entries are kept in alphabetical order of model type.
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Multiple Choice mapping
        ("albert", "AlbertForMultipleChoice"),
        ("bert", "BertForMultipleChoice"),
        ("big_bird", "BigBirdForMultipleChoice"),
        ("camembert", "CamembertForMultipleChoice"),
        ("canine", "CanineForMultipleChoice"),
        ("convbert", "ConvBertForMultipleChoice"),
        ("data2vec-text", "Data2VecTextForMultipleChoice"),
        ("deberta-v2", "DebertaV2ForMultipleChoice"),
        ("distilbert", "DistilBertForMultipleChoice"),
        ("electra", "ElectraForMultipleChoice"),
        ("ernie", "ErnieForMultipleChoice"),
        ("flaubert", "FlaubertForMultipleChoice"),
        ("fnet", "FNetForMultipleChoice"),
        ("funnel", "FunnelForMultipleChoice"),
        ("ibert", "IBertForMultipleChoice"),
        ("longformer", "LongformerForMultipleChoice"),
        ("luke", "LukeForMultipleChoice"),
        ("megatron-bert", "MegatronBertForMultipleChoice"),
        ("mobilebert", "MobileBertForMultipleChoice"),
        ("mpnet", "MPNetForMultipleChoice"),
        ("nezha", "NezhaForMultipleChoice"),
        ("nystromformer", "NystromformerForMultipleChoice"),
        ("qdqbert", "QDQBertForMultipleChoice"),
        ("rembert", "RemBertForMultipleChoice"),
        ("roberta", "RobertaForMultipleChoice"),
        ("roformer", "RoFormerForMultipleChoice"),
        ("squeezebert", "SqueezeBertForMultipleChoice"),
        ("xlm", "XLMForMultipleChoice"),
        ("xlm-roberta", "XLMRobertaForMultipleChoice"),
        ("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"),
        ("xlnet", "XLNetForMultipleChoice"),
        ("yoso", "YosoForMultipleChoice"),
    ]
)

735
# Next sentence prediction: model type -> name of the model class implementing that head.
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Next Sentence Prediction mapping
        ("bert", "BertForNextSentencePrediction"),
        ("ernie", "ErnieForNextSentencePrediction"),
        ("fnet", "FNetForNextSentencePrediction"),
        ("megatron-bert", "MegatronBertForNextSentencePrediction"),
        ("mobilebert", "MobileBertForNextSentencePrediction"),
        ("nezha", "NezhaForNextSentencePrediction"),
        ("qdqbert", "QDQBertForNextSentencePrediction"),
    ]
)

747
748
749
# Audio (utterance-level) classification: model type -> name of the model class implementing that head.
# These models expose the head under a `...ForSequenceClassification` class name.
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio Classification mapping
        ("data2vec-audio", "Data2VecAudioForSequenceClassification"),
        ("hubert", "HubertForSequenceClassification"),
        ("sew", "SEWForSequenceClassification"),
        ("sew-d", "SEWDForSequenceClassification"),
        ("unispeech", "UniSpeechForSequenceClassification"),
        ("unispeech-sat", "UniSpeechSatForSequenceClassification"),
        ("wav2vec2", "Wav2Vec2ForSequenceClassification"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForSequenceClassification"),
        ("wavlm", "WavLMForSequenceClassification"),
    ]
)

762
763
764
# Connectionist temporal classification (CTC): model type -> name of the model class implementing that head.
MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
    [
        # Model for Connectionist temporal classification (CTC) mapping
        ("data2vec-audio", "Data2VecAudioForCTC"),
        ("hubert", "HubertForCTC"),
        ("mctct", "MCTCTForCTC"),
        ("sew", "SEWForCTC"),
        ("sew-d", "SEWDForCTC"),
        ("unispeech", "UniSpeechForCTC"),
        ("unispeech-sat", "UniSpeechSatForCTC"),
        ("wav2vec2", "Wav2Vec2ForCTC"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForCTC"),
        ("wavlm", "WavLMForCTC"),
    ]
)

778
779
780
# Audio frame classification: model type -> name of the model class that predicts a label per
# audio frame (rather than one label per utterance).
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio Frame Classification mapping
        ("data2vec-audio", "Data2VecAudioForAudioFrameClassification"),
        ("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
        ("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForAudioFrameClassification"),
        ("wavlm", "WavLMForAudioFrameClassification"),
    ]
)

# Audio x-vector heads: model type -> name of the model class producing x-vector speaker embeddings
# (used for audio retrieval / speaker verification).
MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio XVector mapping
        ("data2vec-audio", "Data2VecAudioForXVector"),
        ("unispeech-sat", "UniSpeechSatForXVector"),
        ("wav2vec2", "Wav2Vec2ForXVector"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForXVector"),
        ("wavlm", "WavLMForXVector"),
    ]
)

800
801
802
803
804
805
806
# Zero-shot image classification: model type -> model class name. Kept private (leading underscore);
# unlike the public tables it is not wrapped in a `_LazyAutoMapping` in this module.
_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Zero Shot Image Classification mapping (the base CLIP model is used directly)
        ("clip", "CLIPModel"),
    ]
)

807
808
809
810
# Each public `*_MAPPING` below pairs CONFIG_MAPPING_NAMES with the matching `*_MAPPING_NAMES` table through
# `_LazyAutoMapping` (presumably so model classes are only resolved/imported on first lookup).

# --- Base, pretraining and text-generation heads ---
MODEL_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_PRETRAINING_MAPPING_NAMES,
)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_WITH_LM_HEAD_MAPPING_NAMES,
)
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
)

# --- Vision heads ---
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES,
)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
)
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
)
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES,
)
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
)
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES,
)
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
)
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_MASKED_LM_MAPPING_NAMES,
)
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
)
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
)
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES,
)

# --- Text task heads ---
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
)
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
)

# --- Audio / speech heads ---
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_CTC_MAPPING_NAMES,
)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
)
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES,
)
872

873

Sylvain Gugger's avatar
Sylvain Gugger committed
874
875
876
877
878
879
880
881
882
883
884
885
class AutoModel(_BaseAutoModelClass):
    # Lookup of base model classes (MODEL_MAPPING_NAMES) keyed through CONFIG_MAPPING_NAMES.
    _model_mapping = MODEL_MAPPING


AutoModel = auto_class_update(
    AutoModel,
)


class AutoModelForPreTraining(_BaseAutoModelClass):
    # Lookup of model classes carrying a pretraining head.
    _model_mapping = MODEL_FOR_PRETRAINING_MAPPING


AutoModelForPreTraining = auto_class_update(
    AutoModelForPreTraining,
    head_doc="pretraining",
)
886

thomwolf's avatar
thomwolf committed
887

888
# Private on purpose, the public class will add the deprecation warnings.
Sylvain Gugger's avatar
Sylvain Gugger committed
889
890
class _AutoModelWithLMHead(_BaseAutoModelClass):
    # Lookup of model classes carrying a generic language-modeling head.
    _model_mapping = MODEL_WITH_LM_HEAD_MAPPING


_AutoModelWithLMHead = auto_class_update(
    _AutoModelWithLMHead,
    head_doc="language modeling",
)
thomwolf's avatar
thomwolf committed
894
895


Sylvain Gugger's avatar
Sylvain Gugger committed
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
class AutoModelForCausalLM(_BaseAutoModelClass):
    # Lookup of model classes carrying a causal language-modeling head.
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING


AutoModelForCausalLM = auto_class_update(
    AutoModelForCausalLM,
    head_doc="causal language modeling",
)


class AutoModelForMaskedLM(_BaseAutoModelClass):
    # Lookup of model classes carrying a masked language-modeling head.
    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING


AutoModelForMaskedLM = auto_class_update(
    AutoModelForMaskedLM,
    head_doc="masked language modeling",
)


class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
    # Lookup of encoder-decoder model classes carrying a language-modeling head.
    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# Example checkpoint differs from the default because the head is encoder-decoder specific.
AutoModelForSeq2SeqLM = auto_class_update(
    AutoModelForSeq2SeqLM,
    head_doc="sequence-to-sequence language modeling",
    checkpoint_for_example="t5-base",
)
thomwolf's avatar
thomwolf committed
917

Sylvain Gugger's avatar
Sylvain Gugger committed
918
919
920
921
922
923
924

class AutoModelForSequenceClassification(_BaseAutoModelClass):
    # Lookup of model classes carrying a sequence-classification head.
    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


AutoModelForSequenceClassification = auto_class_update(
    AutoModelForSequenceClassification,
    head_doc="sequence classification",
)
thomwolf's avatar
thomwolf committed
926

Sylvain Gugger's avatar
Sylvain Gugger committed
927
928
929
930
931
932
933
934
935
936
937
938
939
940

class AutoModelForQuestionAnswering(_BaseAutoModelClass):
    # Lookup of model classes carrying a (text) question-answering head.
    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING


AutoModelForQuestionAnswering = auto_class_update(
    AutoModelForQuestionAnswering,
    head_doc="question answering",
)


class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    # Lookup of model classes carrying a table question-answering head.
    _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


# Uses a TAPAS checkpoint as the example since TAPAS is the only model in this mapping.
AutoModelForTableQuestionAnswering = auto_class_update(
    AutoModelForTableQuestionAnswering,
    head_doc="table question answering",
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
)
thomwolf's avatar
thomwolf committed
944
945


946
947
948
949
950
951
952
953
954
955
956
class AutoModelForVisualQuestionAnswering(_BaseAutoModelClass):
    # Lookup of model classes carrying a visual question-answering head.
    _model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING


# Uses a ViLT checkpoint as the example since ViLT is the only model in this mapping.
AutoModelForVisualQuestionAnswering = auto_class_update(
    AutoModelForVisualQuestionAnswering,
    head_doc="visual question answering",
    checkpoint_for_example="dandelin/vilt-b32-finetuned-vqa",
)


957
958
959
960
961
962
963
class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
    # Lookup of model classes carrying a document question-answering head.
    _model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING


# NOTE(review): the mismatched quotes in `checkpoint_for_example` look deliberate. The value appears to be
# spliced into a double-quoted placeholder of the generated usage example, so that it renders as
# `"impira/layoutlm-document-qa", revision="52e01b3"` (pinning a revision). Confirm against
# `auto_class_update` before "fixing" the quoting.
AutoModelForDocumentQuestionAnswering = auto_class_update(
    AutoModelForDocumentQuestionAnswering,
    head_doc="document question answering",
    checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
)


Sylvain Gugger's avatar
Sylvain Gugger committed
968
969
class AutoModelForTokenClassification(_BaseAutoModelClass):
    # Lookup of model classes carrying a token-classification head.
    _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


AutoModelForTokenClassification = auto_class_update(
    AutoModelForTokenClassification,
    head_doc="token classification",
)


class AutoModelForMultipleChoice(_BaseAutoModelClass):
    # Lookup of model classes carrying a multiple-choice head.
    _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING


AutoModelForMultipleChoice = auto_class_update(
    AutoModelForMultipleChoice,
    head_doc="multiple choice",
)


class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
    # Lookup of model classes carrying a next-sentence-prediction head.
    _model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


AutoModelForNextSentencePrediction = auto_class_update(
    AutoModelForNextSentencePrediction,
    head_doc="next sentence prediction",
)
989

990

Sylvain Gugger's avatar
Sylvain Gugger committed
991
992
993
994
995
996
997
class AutoModelForImageClassification(_BaseAutoModelClass):
    # Lookup of model classes carrying an image-classification head.
    _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


AutoModelForImageClassification = auto_class_update(
    AutoModelForImageClassification,
    head_doc="image classification",
)


998
999
1000
1001
1002
1003
1004
class AutoModelForImageSegmentation(_BaseAutoModelClass):
    # Lookup of model classes carrying an image-segmentation head.
    _model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING


AutoModelForImageSegmentation = auto_class_update(
    AutoModelForImageSegmentation,
    head_doc="image segmentation",
)


1005
1006
1007
1008
1009
1010
1011
1012
1013
class AutoModelForSemanticSegmentation(_BaseAutoModelClass):
    # Lookup of model classes carrying a semantic-segmentation head.
    _model_mapping = MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING


AutoModelForSemanticSegmentation = auto_class_update(
    AutoModelForSemanticSegmentation,
    head_doc="semantic segmentation",
)


1014
1015
1016
1017
1018
1019
1020
1021
1022
class AutoModelForInstanceSegmentation(_BaseAutoModelClass):
    # Lookup of model classes carrying an instance-segmentation head.
    _model_mapping = MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING


AutoModelForInstanceSegmentation = auto_class_update(
    AutoModelForInstanceSegmentation,
    head_doc="instance segmentation",
)


1023
1024
1025
1026
1027
1028
1029
class AutoModelForObjectDetection(_BaseAutoModelClass):
    # Lookup of model classes carrying an object-detection head.
    _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING


AutoModelForObjectDetection = auto_class_update(
    AutoModelForObjectDetection,
    head_doc="object detection",
)


1030
1031
1032
1033
1034
1035
1036
1037
1038
class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
    # Lookup of model classes carrying a zero-shot object-detection head.
    _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING


AutoModelForZeroShotObjectDetection = auto_class_update(
    AutoModelForZeroShotObjectDetection,
    head_doc="zero-shot object detection",
)


NielsRogge's avatar
NielsRogge committed
1039
1040
1041
1042
1043
1044
1045
class AutoModelForVideoClassification(_BaseAutoModelClass):
    # Lookup of model classes carrying a video-classification head.
    _model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING


AutoModelForVideoClassification = auto_class_update(
    AutoModelForVideoClassification,
    head_doc="video classification",
)


1046
1047
1048
1049
1050
1051
1052
class AutoModelForVision2Seq(_BaseAutoModelClass):
    # Lookup of image-to-text (vision encoder, text decoder) model classes.
    _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING


AutoModelForVision2Seq = auto_class_update(
    AutoModelForVision2Seq,
    head_doc="vision-to-text modeling",
)


1053
1054
1055
1056
1057
1058
1059
class AutoModelForAudioClassification(_BaseAutoModelClass):
    # Lookup of model classes carrying an utterance-level audio-classification head.
    _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


AutoModelForAudioClassification = auto_class_update(
    AutoModelForAudioClassification,
    head_doc="audio classification",
)


1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
class AutoModelForCTC(_BaseAutoModelClass):
    # Lookup of model classes carrying a CTC head.
    _model_mapping = MODEL_FOR_CTC_MAPPING


AutoModelForCTC = auto_class_update(
    AutoModelForCTC,
    head_doc="connectionist temporal classification",
)


class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    # Lookup of encoder-decoder speech-to-text model classes.
    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


AutoModelForSpeechSeq2Seq = auto_class_update(
    AutoModelForSpeechSeq2Seq,
    head_doc="sequence-to-sequence speech-to-text modeling",
)


1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
class AutoModelForAudioFrameClassification(_BaseAutoModelClass):
    # Lookup of model classes that classify each audio frame.
    _model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING


AutoModelForAudioFrameClassification = auto_class_update(
    AutoModelForAudioFrameClassification,
    head_doc="audio frame (token) classification",
)


class AutoModelForAudioXVector(_BaseAutoModelClass):
    # Lookup of model classes carrying an x-vector (speaker embedding) head.
    _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING


AutoModelForAudioXVector = auto_class_update(
    AutoModelForAudioXVector,
    head_doc="audio retrieval via x-vector",
)


NielsRogge's avatar
NielsRogge committed
1092
1093
1094
1095
1096
1097
1098
class AutoModelForMaskedImageModeling(_BaseAutoModelClass):
    # Lookup of model classes carrying a masked image-modeling head.
    _model_mapping = MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING


AutoModelForMaskedImageModeling = auto_class_update(
    AutoModelForMaskedImageModeling,
    head_doc="masked image modeling",
)


1099
class AutoModelWithLMHead(_AutoModelWithLMHead):
    # Deprecated public entry point. Both factory methods emit a `FutureWarning`, then delegate unchanged to
    # `_AutoModelWithLMHead`. No docstrings are defined here on purpose: tools relying on `inspect.getdoc`
    # fall back to the parent's (generated) docstrings, and a local docstring would shadow them.

    @staticmethod
    def _warn_deprecated():
        # Single source for the deprecation message shared by `from_config` and `from_pretrained`.
        # stacklevel=3 skips this helper and the calling classmethod, so the warning points at the user's
        # call site rather than at this library file.
        warnings.warn(
            "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
            "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
            "`AutoModelForSeq2SeqLM` for encoder-decoder models.",
            FutureWarning,
            stacklevel=3,
        )

    @classmethod
    def from_config(cls, config):
        cls._warn_deprecated()
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        cls._warn_deprecated()
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)