Unverified Commit f74655cd authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[Flax] FlaxAutoModelForSeq2SeqLM (#12228)

* add FlaxAutoModelForSeq2SeqLM
parent e43e1126
...@@ -226,6 +226,13 @@ FlaxAutoModelForMaskedLM ...@@ -226,6 +226,13 @@ FlaxAutoModelForMaskedLM
:members: :members:
FlaxAutoModelForSeq2SeqLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxAutoModelForSeq2SeqLM
:members:
FlaxAutoModelForSequenceClassification FlaxAutoModelForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
...@@ -1514,6 +1514,7 @@ if is_flax_available(): ...@@ -1514,6 +1514,7 @@ if is_flax_available():
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
"FLAX_MODEL_FOR_PRETRAINING_MAPPING", "FLAX_MODEL_FOR_PRETRAINING_MAPPING",
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING", "FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"FLAX_MODEL_MAPPING", "FLAX_MODEL_MAPPING",
...@@ -1524,6 +1525,7 @@ if is_flax_available(): ...@@ -1524,6 +1525,7 @@ if is_flax_available():
"FlaxAutoModelForNextSentencePrediction", "FlaxAutoModelForNextSentencePrediction",
"FlaxAutoModelForPreTraining", "FlaxAutoModelForPreTraining",
"FlaxAutoModelForQuestionAnswering", "FlaxAutoModelForQuestionAnswering",
"FlaxAutoModelForSeq2SeqLM",
"FlaxAutoModelForSequenceClassification", "FlaxAutoModelForSequenceClassification",
"FlaxAutoModelForTokenClassification", "FlaxAutoModelForTokenClassification",
] ]
...@@ -2851,6 +2853,7 @@ if TYPE_CHECKING: ...@@ -2851,6 +2853,7 @@ if TYPE_CHECKING:
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
FLAX_MODEL_FOR_PRETRAINING_MAPPING, FLAX_MODEL_FOR_PRETRAINING_MAPPING,
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
FLAX_MODEL_MAPPING, FLAX_MODEL_MAPPING,
...@@ -2861,6 +2864,7 @@ if TYPE_CHECKING: ...@@ -2861,6 +2864,7 @@ if TYPE_CHECKING:
FlaxAutoModelForNextSentencePrediction, FlaxAutoModelForNextSentencePrediction,
FlaxAutoModelForPreTraining, FlaxAutoModelForPreTraining,
FlaxAutoModelForQuestionAnswering, FlaxAutoModelForQuestionAnswering,
FlaxAutoModelForSeq2SeqLM,
FlaxAutoModelForSequenceClassification, FlaxAutoModelForSequenceClassification,
FlaxAutoModelForTokenClassification, FlaxAutoModelForTokenClassification,
) )
......
...@@ -92,6 +92,7 @@ if is_flax_available(): ...@@ -92,6 +92,7 @@ if is_flax_available():
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
"FLAX_MODEL_FOR_PRETRAINING_MAPPING", "FLAX_MODEL_FOR_PRETRAINING_MAPPING",
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING", "FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"FLAX_MODEL_MAPPING", "FLAX_MODEL_MAPPING",
...@@ -103,6 +104,7 @@ if is_flax_available(): ...@@ -103,6 +104,7 @@ if is_flax_available():
"FlaxAutoModelForNextSentencePrediction", "FlaxAutoModelForNextSentencePrediction",
"FlaxAutoModelForPreTraining", "FlaxAutoModelForPreTraining",
"FlaxAutoModelForQuestionAnswering", "FlaxAutoModelForQuestionAnswering",
"FlaxAutoModelForSeq2SeqLM",
"FlaxAutoModelForSequenceClassification", "FlaxAutoModelForSequenceClassification",
"FlaxAutoModelForTokenClassification", "FlaxAutoModelForTokenClassification",
] ]
...@@ -178,6 +180,7 @@ if TYPE_CHECKING: ...@@ -178,6 +180,7 @@ if TYPE_CHECKING:
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
FLAX_MODEL_FOR_PRETRAINING_MAPPING, FLAX_MODEL_FOR_PRETRAINING_MAPPING,
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
FLAX_MODEL_MAPPING, FLAX_MODEL_MAPPING,
...@@ -189,6 +192,7 @@ if TYPE_CHECKING: ...@@ -189,6 +192,7 @@ if TYPE_CHECKING:
FlaxAutoModelForNextSentencePrediction, FlaxAutoModelForNextSentencePrediction,
FlaxAutoModelForPreTraining, FlaxAutoModelForPreTraining,
FlaxAutoModelForQuestionAnswering, FlaxAutoModelForQuestionAnswering,
FlaxAutoModelForSeq2SeqLM,
FlaxAutoModelForSequenceClassification, FlaxAutoModelForSequenceClassification,
FlaxAutoModelForTokenClassification, FlaxAutoModelForTokenClassification,
) )
......
...@@ -129,6 +129,13 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict( ...@@ -129,6 +129,13 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
] ]
) )
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
(BartConfig, FlaxBartForConditionalGeneration)
]
)
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
[ [
# Model for Sequence Classification mapping # Model for Sequence Classification mapping
...@@ -197,6 +204,13 @@ FlaxAutoModelForMaskedLM = auto_class_factory( ...@@ -197,6 +204,13 @@ FlaxAutoModelForMaskedLM = auto_class_factory(
"FlaxAutoModelForMaskedLM", FLAX_MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling" "FlaxAutoModelForMaskedLM", FLAX_MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling"
) )
FlaxAutoModelForSeq2SeqLM = auto_class_factory(
"FlaxAutoModelForSeq2SeqLM",
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
head_doc="sequence-to-sequence language modeling",
)
FlaxAutoModelForSequenceClassification = auto_class_factory( FlaxAutoModelForSequenceClassification = auto_class_factory(
"FlaxAutoModelForSequenceClassification", "FlaxAutoModelForSequenceClassification",
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
......
...@@ -94,6 +94,9 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = None ...@@ -94,6 +94,9 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = None
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = None FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = None
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = None
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None
...@@ -166,6 +169,15 @@ class FlaxAutoModelForQuestionAnswering: ...@@ -166,6 +169,15 @@ class FlaxAutoModelForQuestionAnswering:
requires_backends(cls, ["flax"]) requires_backends(cls, ["flax"])
class FlaxAutoModelForSeq2SeqLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxAutoModelForSequenceClassification: class FlaxAutoModelForSequenceClassification:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment