Unverified Commit b67fd797 authored by Yih-Dar, committed by GitHub

Add TFVisionEncoderDecoderModel (#14148)



* Start the work on TFVisionEncoderDecoderModel

* Expose TFVisionEncoderDecoderModel

* fix import

* Add modeling_tf_vision_encoder_decoder to _ignore_modules in get_model_modules()

* reorder

* Apply the fix for checkpoint loading as in #14016

* remove attention_mask + fix VISION_DUMMY_INPUTS

* A minimal change to make TF generate() work for vision models as encoder in encoder-decoder setting

* fix wrong condition: shape_list(input_ids) == 2

* add tests

* use personal TFViTModel checkpoint (for now)

* Add equivalence tests + projection layer

* style

* make sure projection layer can run

* Add examples

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Clean comments (need to work on TODOs for PyTorch models)

* Remove TF -> PT in check_pt_tf_equivalence for TFVisionEncoderDecoderModel

* fixes

* Revert changes in PT code.

* Update tests/test_modeling_tf_vision_encoder_decoder.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Add test_inference_coco_en for TF test

* fix quality

* fix name

* build doc

* add main_input_name

* Fix ckpt name in test

* fix diff between master and this PR

* fix doc

* fix style and quality

* fix missing doc

* fix labels handling

* Delete auto.rst

* Add the changes done in #14016

* fix prefix

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* make style
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 37bc0b4e
@@ -261,7 +261,7 @@ Flax), PyTorch, and/or TensorFlow.
 | TrOCR | ❌ | ❌ | ✅ | ❌ | ❌ |
 | UniSpeech | ❌ | ❌ | ✅ | ❌ | ❌ |
 | UniSpeechSat | ❌ | ❌ | ✅ | ❌ | ❌ |
-| Vision Encoder decoder | ❌ | ❌ | ✅ | ❌ | ✅ |
+| Vision Encoder decoder | ❌ | ❌ | ✅ | ✅ | ✅ |
 | VisionTextDualEncoder | ❌ | ❌ | ✅ | ❌ | ✅ |
 | VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
 | ViT | ❌ | ❌ | ✅ | ✅ | ✅ |
......
@@ -194,6 +194,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its
 [[autodoc]] TFAutoModelForQuestionAnswering

+## TFAutoModelForVision2Seq
+
+[[autodoc]] TFAutoModelForVision2Seq
+
 ## FlaxAutoModel

 [[autodoc]] FlaxAutoModel
......
@@ -33,6 +33,12 @@ An example of how to use a [`VisionEncoderDecoderModel`] for inference can be se
     - forward
     - from_encoder_decoder_pretrained

+## TFVisionEncoderDecoderModel
+
+[[autodoc]] TFVisionEncoderDecoderModel
+    - call
+    - from_encoder_decoder_pretrained
+
 ## FlaxVisionEncoderDecoderModel

 [[autodoc]] FlaxVisionEncoderDecoderModel
......
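As a usage sketch for the class documented above, mirroring the PyTorch/Flax `from_encoder_decoder_pretrained` API. The encoder/decoder checkpoint pairing is illustrative, and the combined model still needs fine-tuning, since its cross-attention weights are freshly initialized:

```python
from transformers import GPT2Tokenizer, TFVisionEncoderDecoderModel, ViTFeatureExtractor

# Pair a pretrained ViT encoder with a pretrained GPT-2 decoder.
model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "gpt2"
)

# The encoder side consumes pixel values; the decoder side consumes token ids.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
```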
@@ -1487,6 +1487,7 @@ if is_tf_available():
         "TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
         "TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
         "TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
+        "TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
         "TF_MODEL_MAPPING",
         "TF_MODEL_WITH_LM_HEAD_MAPPING",
         "TFAutoModel",
@@ -1500,6 +1501,7 @@ if is_tf_available():
         "TFAutoModelForSequenceClassification",
         "TFAutoModelForTableQuestionAnswering",
         "TFAutoModelForTokenClassification",
+        "TFAutoModelForVision2Seq",
         "TFAutoModelWithLMHead",
     ]
 )
@@ -1838,6 +1840,7 @@ if is_tf_available():
             "TFTransfoXLPreTrainedModel",
         ]
     )
+    _import_structure["models.vision_encoder_decoder"].extend(["TFVisionEncoderDecoderModel"])
     _import_structure["models.vit"].extend(
         [
             "TFViTForImageClassification",
@@ -3354,6 +3357,7 @@ if TYPE_CHECKING:
         TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
         TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
         TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+        TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
         TF_MODEL_MAPPING,
         TF_MODEL_WITH_LM_HEAD_MAPPING,
         TFAutoModel,
@@ -3367,6 +3371,7 @@ if TYPE_CHECKING:
         TFAutoModelForSequenceClassification,
         TFAutoModelForTableQuestionAnswering,
         TFAutoModelForTokenClassification,
+        TFAutoModelForVision2Seq,
         TFAutoModelWithLMHead,
     )
     from .models.bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel
@@ -3636,6 +3641,7 @@ if TYPE_CHECKING:
         TFTransfoXLModel,
         TFTransfoXLPreTrainedModel,
     )
+    from .models.vision_encoder_decoder import TFVisionEncoderDecoderModel
    from .models.vit import TFViTForImageClassification, TFViTModel, TFViTPreTrainedModel
     from .models.wav2vec2 import (
         TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
......
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union
@@ -628,14 +629,18 @@ class TFGenerationMixin:
             bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
         ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"

+        # This block corresponds to the following line in `generation_utils`:
+        # "input_ids = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))"
+        # with the following differences:
+        #   1. In PT, `generate()`'s `model_kwargs` can accept `encoder_outputs`, but this is not the case in TF.
+        #   2. There is no shape checking in PT.
+        # In both PT/TF, if `input_ids` is `None`, we try to create it as is done for a text model.
         if input_ids is None:
             assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
                 "you should either supply a context to complete as `input_ids` input "
                 "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
             )
             input_ids = tf.fill((batch_size, 1), bos_token_id)
-        else:
-            assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."

         # not allow to duplicate outputs when greedy decoding
         if do_sample is False:
@@ -691,21 +696,29 @@ class TFGenerationMixin:
             # get encoder and store encoder outputs
             encoder = self.get_encoder()
-            encoder_outputs = encoder(
-                input_ids,
-                attention_mask=attention_mask,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict_in_generate,
-            )
+            encoder_kwargs = {
+                "attention_mask": attention_mask,
+                "output_attentions": output_attentions,
+                "output_hidden_states": output_hidden_states,
+                "return_dict": return_dict_in_generate,
+            }
+
+            # vision models don't use `attention_mask`.
+            signature = dict(inspect.signature(encoder.call).parameters)
+            if "attention_mask" not in signature:
+                encoder_kwargs.pop("attention_mask")
+
+            encoder_outputs = encoder(input_ids, **encoder_kwargs)
             if return_dict_in_generate:
                 if output_attentions:
                     model_kwargs["encoder_attentions"] = encoder_outputs.attentions
                 if output_hidden_states:
                     model_kwargs["encoder_hidden_states"] = encoder_outputs.hidden_states

+        # The condition `len(shape_list(input_ids)) == 2` makes this block treat only text inputs
+        # (vision inputs can occur when the model is an encoder-decoder model).
         # Expand input ids if num_beams > 1 or num_return_sequences > 1
-        if num_return_sequences > 1 or num_beams > 1:
+        if len(shape_list(input_ids)) == 2 and (num_return_sequences > 1 or num_beams > 1):
             input_ids_len = shape_list(input_ids)[-1]
             input_ids = tf.broadcast_to(
                 tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
......
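The kwargs-filtering trick above is small enough to demonstrate on its own. A standalone sketch with toy encoder classes (these stand-ins are not library code): drop `attention_mask` whenever the encoder's `call` signature does not accept it.

```python
import inspect


class TextEncoder:
    def call(self, input_ids, attention_mask=None):
        return ("text", input_ids, attention_mask)


class VisionEncoder:
    def call(self, pixel_values):  # vision encoders take no attention_mask
        return ("vision", pixel_values)


def run_encoder(encoder, inputs, **encoder_kwargs):
    # Keep only the kwargs that the concrete `call` signature accepts.
    signature = dict(inspect.signature(encoder.call).parameters)
    if "attention_mask" not in signature:
        encoder_kwargs.pop("attention_mask", None)
    return encoder.call(inputs, **encoder_kwargs)


print(run_encoder(TextEncoder(), [1, 2, 3], attention_mask=[1, 1, 1]))
# -> ('text', [1, 2, 3], [1, 1, 1])
print(run_encoder(VisionEncoder(), "pixels", attention_mask=[1, 1, 1]))
# -> ('vision', 'pixels')
```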
@@ -87,6 +87,7 @@ if is_tf_available():
         "TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
         "TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
         "TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
+        "TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
         "TF_MODEL_MAPPING",
         "TF_MODEL_WITH_LM_HEAD_MAPPING",
         "TFAutoModel",
@@ -100,6 +101,7 @@ if is_tf_available():
         "TFAutoModelForSequenceClassification",
         "TFAutoModelForTableQuestionAnswering",
         "TFAutoModelForTokenClassification",
+        "TFAutoModelForVision2Seq",
         "TFAutoModelWithLMHead",
     ]
@@ -197,6 +199,7 @@ if TYPE_CHECKING:
         TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
         TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
         TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+        TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
         TF_MODEL_MAPPING,
         TF_MODEL_WITH_LM_HEAD_MAPPING,
         TFAutoModel,
@@ -210,6 +213,7 @@ if TYPE_CHECKING:
         TFAutoModelForSequenceClassification,
         TFAutoModelForTableQuestionAnswering,
         TFAutoModelForTokenClassification,
+        TFAutoModelForVision2Seq,
         TFAutoModelWithLMHead,
     )
......
@@ -156,6 +156,12 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
     ]
 )

+TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
+    [
+        ("vision-encoder-decoder", "TFVisionEncoderDecoderModel"),
+    ]
+)
+
 TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
     [
         # Model for Masked LM mapping
@@ -182,7 +188,6 @@ TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
     ]
 )
-
 TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
     [
         # Model for Seq2Seq Causal LM mapping
@@ -327,6 +332,7 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL
 TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
     CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
 )
+TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
 TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
 TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
     CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
@@ -387,6 +393,7 @@ class TFAutoModelForImageClassification(_BaseAutoModelClass):

 TFAutoModelForImageClassification = auto_class_update(TFAutoModelForImageClassification, head_doc="image classification")

+class TFAutoModelForVision2Seq(_BaseAutoModelClass):
+    _model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING
+
+
+TFAutoModelForVision2Seq = auto_class_update(TFAutoModelForVision2Seq, head_doc="vision-to-text modeling")
+
 class TFAutoModelForMaskedLM(_BaseAutoModelClass):
     _model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING
......
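Once the mapping is registered, a `vision-encoder-decoder` config resolves to `TFVisionEncoderDecoderModel` through the new auto class. A hedged usage sketch: the checkpoint name below is illustrative only (the PR's test temporarily uses the author's personal checkpoint), and `generate()` accepting a 4D vision tensor relies on the `generation_tf_utils.py` changes above.

```python
import tensorflow as tf
from transformers import TFAutoModelForVision2Seq

# Illustrative checkpoint name, not one pinned by this PR.
model = TFAutoModelForVision2Seq.from_pretrained("ydshieh/vit-gpt2-coco-en")

# ViT-style pixel values: (batch, channels, height, width).
pixel_values = tf.random.uniform((1, 3, 224, 224))
generated_ids = model.generate(pixel_values)
```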
@@ -148,10 +148,10 @@ ENCODER_DECODER_INPUTS_DOCSTRING = r"""
 @add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
 class TFEncoderDecoderModel(TFPreTrainedModel):
     r"""
-    [`TFEncoderDecoder`] is a generic model class that will be instantiated as a transformer architecture with one of
-    the base model classes of the library as encoder and another one as decoder when created with the
-    :meth*~transformers.TFAutoModel.from_pretrained* class method for the encoder and
-    :meth*~transformers.TFAutoModelForCausalLM.from_pretrained* class method for the decoder.
+    [`TFEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one
+    of the base model classes of the library as encoder and another one as decoder when created with the
+    [`~TFAutoModel.from_pretrained`] class method for the encoder and [`~TFAutoModelForCausalLM.from_pretrained`] class
+    method for the decoder.
     """
     config_class = EncoderDecoderConfig
     base_model_prefix = "encoder_decoder"
@@ -233,13 +233,6 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
         # Add `decoder_input_ids` because `self.decoder` requires it.
         input_ids = tf.constant(DUMMY_INPUTS)
         dummy = {"input_ids": input_ids, "decoder_input_ids": input_ids}
-
-        # Add `encoder_hidden_states` to make the cross-attention layers' weights initialized
-        if self.config.add_cross_attention:
-            batch_size, seq_len = input_ids.shape
-            shape = (batch_size, seq_len) + (self.config.hidden_size,)
-            h = tf.random.uniform(shape=shape)
-            dummy["encoder_hidden_states"] = h
         return dummy

     def get_encoder(self):
......
@@ -18,7 +18,7 @@

 from typing import TYPE_CHECKING

-from ...file_utils import _LazyModule, is_flax_available, is_torch_available
+from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available

 _import_structure = {
@@ -28,6 +28,9 @@ _import_structure = {
 if is_torch_available():
     _import_structure["modeling_vision_encoder_decoder"] = ["VisionEncoderDecoderModel"]

+if is_tf_available():
+    _import_structure["modeling_tf_vision_encoder_decoder"] = ["TFVisionEncoderDecoderModel"]
+
 if is_flax_available():
     _import_structure["modeling_flax_vision_encoder_decoder"] = ["FlaxVisionEncoderDecoderModel"]
@@ -37,6 +40,9 @@ if TYPE_CHECKING:
     if is_torch_available():
         from .modeling_vision_encoder_decoder import VisionEncoderDecoderModel

+    if is_tf_available():
+        from .modeling_tf_vision_encoder_decoder import TFVisionEncoderDecoderModel
+
     if is_flax_available():
         from .modeling_flax_vision_encoder_decoder import FlaxVisionEncoderDecoderModel
......
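For readers unfamiliar with the `_LazyModule` machinery this `__init__.py` feeds, a simplified sketch of the idea (not the library's exact implementation): build a name-to-submodule map conditionally, then import a submodule only when one of its names is first accessed.

```python
import importlib
from types import ModuleType


class LazyModule(ModuleType):
    """Resolve exported names to their submodules on first access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported attribute to the submodule that defines it,
        # e.g. {"modeling_tf_vision_encoder_decoder": ["TFVisionEncoderDecoderModel"]}.
        self._attr_to_module = {
            attr: module for module, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr):
        # Import the defining submodule only when the attribute is first used.
        module = importlib.import_module("." + self._attr_to_module[attr], self.__name__)
        return getattr(module, attr)
```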
@@ -245,6 +245,9 @@ TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = None
 TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None

+TF_MODEL_FOR_VISION_2_SEQ_MAPPING = None
+
+
 TF_MODEL_MAPPING = None
@@ -383,6 +386,18 @@ class TFAutoModelForTokenClassification:
         requires_backends(self, ["tf"])

+class TFAutoModelForVision2Seq:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["tf"])
+
+    def call(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
 class TFAutoModelWithLMHead:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["tf"])
@@ -2678,6 +2693,18 @@ class TFTransfoXLPreTrainedModel:
         requires_backends(self, ["tf"])

+class TFVisionEncoderDecoderModel:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["tf"])
+
+    def call(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
 class TFViTForImageClassification:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["tf"])
......
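These dummies let `from transformers import TFVisionEncoderDecoderModel` succeed even without TensorFlow installed, failing only when the object is actually used. A simplified, self-contained sketch of the pattern (`_tf_available` is a stand-in flag; the real check lives in the library's backend utilities):

```python
_tf_available = False  # pretend TensorFlow is missing


def requires_backends(obj, backends):
    # Raise a helpful error naming the object and the missing backend.
    name = obj if isinstance(obj, str) else type(obj).__name__
    if "tf" in backends and not _tf_available:
        raise ImportError(f"{name} requires the TensorFlow library but it was not found.")


class TFVisionEncoderDecoderModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls.__name__, ["tf"])


try:
    TFVisionEncoderDecoderModel.from_pretrained("any-checkpoint")
except ImportError as err:
    print(err)  # -> TFVisionEncoderDecoderModel requires the TensorFlow library ...
```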
@@ -490,7 +490,7 @@ class TFEncoderDecoderMixin:
     def test_real_model_save_load_from_pretrained(self):
         model_2 = self.get_pretrained_model()
         input_ids = ids_tensor([13, 5], model_2.config.encoder.vocab_size)
-        decoder_input_ids = ids_tensor([13, 1], model_2.config.encoder.vocab_size)
+        decoder_input_ids = ids_tensor([13, 1], model_2.config.decoder.vocab_size)
         attention_mask = ids_tensor([13, 5], vocab_size=2)

         outputs = model_2(
......
@@ -203,6 +203,7 @@ def get_model_modules():
         "modeling_tf_pytorch_utils",
         "modeling_tf_utils",
         "modeling_tf_transfo_xl_utilities",
+        "modeling_tf_vision_encoder_decoder",
         "modeling_vision_encoder_decoder",
     ]
     modules = []
......