Unverified Commit 51eb6d34 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Flax] Fix mt5 auto (#12612)

* fix_torch_device_generate_test

* remove @

* fix mt5 auto
parent e7f33e8c
......@@ -84,6 +84,7 @@ from .configuration_auto import (
GPTNeoConfig,
MarianConfig,
MBartConfig,
MT5Config,
RobertaConfig,
T5Config,
ViTConfig,
......@@ -108,6 +109,7 @@ FLAX_MODEL_MAPPING = OrderedDict(
(ViTConfig, FlaxViTModel),
(MBartConfig, FlaxMBartModel),
(T5Config, FlaxT5Model),
(MT5Config, FlaxT5Model),
(Wav2Vec2Config, FlaxWav2Vec2Model),
(MarianConfig, FlaxMarianModel),
]
......@@ -123,6 +125,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
(ElectraConfig, FlaxElectraForPreTraining),
(MBartConfig, FlaxMBartForConditionalGeneration),
(T5Config, FlaxT5ForConditionalGeneration),
(MT5Config, FlaxT5ForConditionalGeneration),
(Wav2Vec2Config, FlaxWav2Vec2ForPreTraining),
]
)
......@@ -144,6 +147,7 @@ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
# Model for Seq2Seq Causal LM mapping
(BartConfig, FlaxBartForConditionalGeneration),
(T5Config, FlaxT5ForConditionalGeneration),
(MT5Config, FlaxT5ForConditionalGeneration),
(MarianConfig, FlaxMarianMTModel),
]
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment