Unverified Commit ea91c2ad authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[AutoModel] Add AutoModelForTextEncoding (#24305)

* [AutoModel] Add AutoModelForTextEncoding

* add mt5

* add other models

* add to docs

* fix tf imports

* add tf to docs / init

* up

* fix inits

* add to dummy objects
parent feb83521
......@@ -214,6 +214,14 @@ The following auto classes are available for the following natural language proc
[[autodoc]] FlaxAutoModelForQuestionAnswering
### AutoModelForTextEncoding
[[autodoc]] AutoModelForTextEncoding
### TFAutoModelForTextEncoding
[[autodoc]] TFAutoModelForTextEncoding
## Computer vision
The following auto classes are available for the following computer vision tasks.
......
......@@ -1053,6 +1053,7 @@ else:
"MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_TEXT_ENCODING_MAPPING",
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
"MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
......@@ -1087,6 +1088,7 @@ else:
"AutoModelForSequenceClassification",
"AutoModelForSpeechSeq2Seq",
"AutoModelForTableQuestionAnswering",
"AutoModelForTextEncoding",
"AutoModelForTokenClassification",
"AutoModelForUniversalSegmentation",
"AutoModelForVideoClassification",
......@@ -2984,6 +2986,7 @@ else:
"TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_TEXT_ENCODING_MAPPING",
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
"TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
......@@ -3003,6 +3006,7 @@ else:
"TFAutoModelForSequenceClassification",
"TFAutoModelForSpeechSeq2Seq",
"TFAutoModelForTableQuestionAnswering",
"TFAutoModelForTextEncoding",
"TFAutoModelForTokenClassification",
"TFAutoModelForVision2Seq",
"TFAutoModelForZeroShotImageClassification",
......@@ -4807,6 +4811,7 @@ if TYPE_CHECKING:
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TEXT_ENCODING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING,
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
......@@ -4841,6 +4846,7 @@ if TYPE_CHECKING:
AutoModelForSequenceClassification,
AutoModelForSpeechSeq2Seq,
AutoModelForTableQuestionAnswering,
AutoModelForTextEncoding,
AutoModelForTokenClassification,
AutoModelForUniversalSegmentation,
AutoModelForVideoClassification,
......@@ -6374,6 +6380,7 @@ if TYPE_CHECKING:
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_TEXT_ENCODING_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING,
......@@ -6393,6 +6400,7 @@ if TYPE_CHECKING:
TFAutoModelForSequenceClassification,
TFAutoModelForSpeechSeq2Seq,
TFAutoModelForTableQuestionAnswering,
TFAutoModelForTextEncoding,
TFAutoModelForTokenClassification,
TFAutoModelForVision2Seq,
TFAutoModelForZeroShotImageClassification,
......
......@@ -64,6 +64,7 @@ else:
"MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_TEXT_ENCODING_MAPPING",
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
"MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
......@@ -85,6 +86,7 @@ else:
"AutoModelForImageSegmentation",
"AutoModelForInstanceSegmentation",
"AutoModelForMaskGeneration",
"AutoModelForTextEncoding",
"AutoModelForMaskedImageModeling",
"AutoModelForMaskedLM",
"AutoModelForMultipleChoice",
......@@ -131,6 +133,7 @@ else:
"TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_TEXT_ENCODING_MAPPING",
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
"TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
......@@ -150,6 +153,7 @@ else:
"TFAutoModelForSequenceClassification",
"TFAutoModelForSpeechSeq2Seq",
"TFAutoModelForTableQuestionAnswering",
"TFAutoModelForTextEncoding",
"TFAutoModelForTokenClassification",
"TFAutoModelForVision2Seq",
"TFAutoModelForZeroShotImageClassification",
......@@ -233,6 +237,7 @@ if TYPE_CHECKING:
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TEXT_ENCODING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING,
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
......@@ -267,6 +272,7 @@ if TYPE_CHECKING:
AutoModelForSequenceClassification,
AutoModelForSpeechSeq2Seq,
AutoModelForTableQuestionAnswering,
AutoModelForTextEncoding,
AutoModelForTokenClassification,
AutoModelForUniversalSegmentation,
AutoModelForVideoClassification,
......@@ -300,6 +306,7 @@ if TYPE_CHECKING:
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_TEXT_ENCODING_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING,
......@@ -319,6 +326,7 @@ if TYPE_CHECKING:
TFAutoModelForSequenceClassification,
TFAutoModelForSpeechSeq2Seq,
TFAutoModelForTableQuestionAnswering,
TFAutoModelForTextEncoding,
TFAutoModelForTokenClassification,
TFAutoModelForVision2Seq,
TFAutoModelForZeroShotImageClassification,
......
......@@ -1011,6 +1011,36 @@ MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
]
)
MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
[
("albert", "AlbertModel"),
("bert", "BertModel"),
("big_bird", "BigBirdModel"),
("data2vec-text", "Data2VecTextModel"),
("deberta", "DebertaModel"),
("deberta-v2", "DebertaV2Model"),
("distilbert", "DistilBertModel"),
("electra", "ElectraModel"),
("flaubert", "FlaubertModel"),
("ibert", "IBertModel"),
("longformer", "LongformerModel"),
("mobilebert", "MobileBertModel"),
("mt5", "MT5EncoderModel"),
("nystromformer", "NystromformerModel"),
("reformer", "ReformerModel"),
("rembert", "RemBertModel"),
("roberta", "RobertaModel"),
("roberta-prelayernorm", "RobertaPreLayerNormModel"),
("roc_bert", "RoCBertModel"),
("roformer", "RoFormerModel"),
("squeezebert", "SqueezeBertModel"),
("t5", "T5EncoderModel"),
("xlm", "XLMModel"),
("xlm-roberta", "XLMRobertaModel"),
("xlm-roberta-xl", "XLMRobertaXLModel"),
]
)
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
......@@ -1088,11 +1118,17 @@ MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BA
MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)
class AutoModelForMaskGeneration(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING
class AutoModelForTextEncoding(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
class AutoModel(_BaseAutoModelClass):
_model_mapping = MODEL_MAPPING
......
......@@ -437,6 +437,28 @@ TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
("sam", "TFSamModel"),
]
)
TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
[
("albert", "TFAlbertModel"),
("bert", "TFBertModel"),
("convbert", "TFConvBertModel"),
("deberta", "TFDebertaModel"),
("deberta-v2", "TFDebertaV2Model"),
("distilbert", "TFDistilBertModel"),
("electra", "TFElectraModel"),
("flaubert", "TFFlaubertModel"),
("longformer", "TFLongformerModel"),
("mobilebert", "TFMobileBertModel"),
("mt5", "TFMT5EncoderModel"),
("rembert", "TFRemBertModel"),
("roberta", "TFRobertaModel"),
("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
("roformer", "TFRoFormerModel"),
("t5", "TFT5EncoderModel"),
("xlm", "TFXLMModel"),
("xlm-roberta", "TFXLMRobertaModel"),
]
)
TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
......@@ -491,11 +513,17 @@ TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
)
TF_MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)
class TFAutoModelForMaskGeneration(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING
class TFAutoModelForTextEncoding(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_TEXT_ENCODING_MAPPING
class TFAutoModel(_BaseAutoModelClass):
_model_mapping = TF_MODEL_MAPPING
......
......@@ -524,6 +524,9 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = None
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = None
MODEL_FOR_TEXT_ENCODING_MAPPING = None
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
......@@ -726,6 +729,13 @@ class AutoModelForTableQuestionAnswering(metaclass=DummyObject):
requires_backends(self, ["torch"])
class AutoModelForTextEncoding(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class AutoModelForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
......
......@@ -264,6 +264,9 @@ TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = None
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = None
TF_MODEL_FOR_TEXT_ENCODING_MAPPING = None
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
......@@ -377,6 +380,13 @@ class TFAutoModelForTableQuestionAnswering(metaclass=DummyObject):
requires_backends(self, ["tf"])
class TFAutoModelForTextEncoding(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFAutoModelForTokenClassification(metaclass=DummyObject):
_backends = ["tf"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment