Unverified Commit b67fd797 authored by Yih-Dar, committed by GitHub

Add TFVisionEncoderDecoderModel (#14148)



* Start the work on TFVisionEncoderDecoderModel

* Expose TFVisionEncoderDecoderModel

* fix import

* Add modeling_tf_vision_encoder_decoder to _ignore_modules in get_model_modules()

* reorder

* Apply the fix for checkpoint loading as in #14016

* remove attention_mask + fix VISION_DUMMY_INPUTS

* A minimal change to make TF generate() work for vision models as encoder in encoder-decoder setting

* fix wrong condition: shape_list(input_ids) == 2

* add tests

* use personal TFViTModel checkpoint (for now)

* Add equivalence tests + projection layer

* style

* make sure projection layer can run

* Add examples

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Clean comments (need to work on TODOs for PyTorch models)

* Remove TF -> PT in check_pt_tf_equivalence for TFVisionEncoderDecoderModel

* fixes

* Revert changes in PT code.

* Update tests/test_modeling_tf_vision_encoder_decoder.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Add test_inference_coco_en for TF test

* fix quality

* fix name

* build doc

* add main_input_name

* Fix ckpt name in test

* fix diff between master and this PR

* fix doc

* fix style and quality

* fix missing doc

* fix labels handling

* Delete auto.rst

* Add the changes done in #14016

* fix prefix

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* make style
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 37bc0b4e
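
For orientation before the diff: a minimal sketch of how the new class is meant to be used. The checkpoint names are illustrative assumptions, not taken from this PR.

```python
from transformers import TFVisionEncoderDecoderModel

# Compose a vision encoder with a text decoder; the cross-attention weights
# are newly initialized, so the combined model needs fine-tuning before use.
# Checkpoint names are illustrative assumptions.
model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "gpt2"
)

# The composed model saves and reloads like any single checkpoint.
model.save_pretrained("vit-gpt2")
model = TFVisionEncoderDecoderModel.from_pretrained("vit-gpt2")
```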
@@ -261,7 +261,7 @@ Flax), PyTorch, and/or TensorFlow.
| TrOCR | ❌ | ❌ | ✅ | ❌ | ❌ |
| UniSpeech | ❌ | ❌ | ✅ | ❌ | ❌ |
| UniSpeechSat | ❌ | ❌ | ✅ | ❌ | ❌ |
| Vision Encoder decoder | ❌ | ❌ | ✅ | ❌ | ✅ |
| Vision Encoder decoder | ❌ | ❌ | ✅ | ✅ | ✅ |
| VisionTextDualEncoder | ❌ | ❌ | ✅ | ❌ | ✅ |
| VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
| ViT | ❌ | ❌ | ✅ | ✅ | ✅ |
@@ -194,6 +194,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its
[[autodoc]] TFAutoModelForQuestionAnswering
## TFAutoModelForVision2Seq
[[autodoc]] TFAutoModelForVision2Seq
## FlaxAutoModel
[[autodoc]] FlaxAutoModel
@@ -33,6 +33,12 @@ An example of how to use a [`VisionEncoderDecoderModel`] for inference can be se
- forward
- from_encoder_decoder_pretrained
## TFVisionEncoderDecoderModel
[[autodoc]] TFVisionEncoderDecoderModel
- call
- from_encoder_decoder_pretrained
## FlaxVisionEncoderDecoderModel
[[autodoc]] FlaxVisionEncoderDecoderModel
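
A hedged inference sketch for the TF class documented above. The checkpoint name is an assumption based on the `test_inference_coco_en` test mentioned in the commit log; substitute any compatible checkpoint.

```python
import requests
from PIL import Image
from transformers import AutoTokenizer, TFVisionEncoderDecoderModel, ViTFeatureExtractor

# Assumed checkpoint: a ViT+GPT-2 captioning model matching the PR's
# COCO-English integration test; treat the name as an assumption.
ckpt = "ydshieh/vit-gpt2-coco-en"
feature_extractor = ViTFeatureExtractor.from_pretrained(ckpt)
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = TFVisionEncoderDecoderModel.from_pretrained(ckpt)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
pixel_values = feature_extractor(images=image, return_tensors="tf").pixel_values

# generate() accepts rank-4 pixel values thanks to the generation changes below.
output_ids = model.generate(pixel_values, max_length=16, num_beams=4)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```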
@@ -1487,6 +1487,7 @@ if is_tf_available():
"TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
"TF_MODEL_MAPPING",
"TF_MODEL_WITH_LM_HEAD_MAPPING",
"TFAutoModel",
@@ -1500,6 +1501,7 @@ if is_tf_available():
"TFAutoModelForSequenceClassification",
"TFAutoModelForTableQuestionAnswering",
"TFAutoModelForTokenClassification",
"TFAutoModelForVision2Seq",
"TFAutoModelWithLMHead",
]
)
@@ -1838,6 +1840,7 @@ if is_tf_available():
"TFTransfoXLPreTrainedModel",
]
)
_import_structure["models.vision_encoder_decoder"].extend(["TFVisionEncoderDecoderModel"])
_import_structure["models.vit"].extend(
[
"TFViTForImageClassification",
@@ -3354,6 +3357,7 @@ if TYPE_CHECKING:
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
TF_MODEL_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel,
@@ -3367,6 +3371,7 @@ if TYPE_CHECKING:
TFAutoModelForSequenceClassification,
TFAutoModelForTableQuestionAnswering,
TFAutoModelForTokenClassification,
TFAutoModelForVision2Seq,
TFAutoModelWithLMHead,
)
from .models.bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel
@@ -3636,6 +3641,7 @@ if TYPE_CHECKING:
TFTransfoXLModel,
TFTransfoXLPreTrainedModel,
)
from .models.vision_encoder_decoder import TFVisionEncoderDecoderModel
from .models.vit import TFViTForImageClassification, TFViTModel, TFViTPreTrainedModel
from .models.wav2vec2 import (
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from dataclasses import dataclass
from typing import Optional, Tuple, Union
@@ -628,14 +629,18 @@ class TFGenerationMixin:
bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
# This block corresponds to the following line in `generation_utils`:
# "input_ids = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))"
# with the following differences:
# 1. In PT, `generate()`'s `model_kwargs` can accept `encoder_outputs`, but this is not the case in TF.
# 2. There is no shape checking in PT.
# In both PT/TF, if `input_ids` is `None`, we try to create it as we would for a text model.
if input_ids is None:
assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
"you should either supply a context to complete as `input_ids` input "
"or a `bos_token_id` (integer >= 0) as a first token to start the generation."
)
input_ids = tf.fill((batch_size, 1), bos_token_id)
else:
assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."
# not allow to duplicate outputs when greedy decoding
if do_sample is False:
@@ -691,21 +696,29 @@ class TFGenerationMixin:
# get encoder and store encoder outputs
encoder = self.get_encoder()
encoder_outputs = encoder(
input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict_in_generate,
)
encoder_kwargs = {
"attention_mask": attention_mask,
"output_attentions": output_attentions,
"output_hidden_states": output_hidden_states,
"return_dict": return_dict_in_generate,
}
# vision models don't use `attention_mask`.
signature = dict(inspect.signature(encoder.call).parameters)
if "attention_mask" not in signature:
encoder_kwargs.pop("attention_mask")
encoder_outputs = encoder(input_ids, **encoder_kwargs)
if return_dict_in_generate:
if output_attentions:
model_kwargs["encoder_attentions"] = encoder_outputs.attentions
if output_hidden_states:
model_kwargs["encoder_hidden_states"] = encoder_outputs.hidden_states
# The condition `len(shape_list(input_ids)) == 2` is to make this block treat only text inputs.
# (vision inputs might occur when the model is an encoder-decoder model)
# Expand input ids if num_beams > 1 or num_return_sequences > 1
if num_return_sequences > 1 or num_beams > 1:
if len(shape_list(input_ids)) == 2 and (num_return_sequences > 1 or num_beams > 1):
input_ids_len = shape_list(input_ids)[-1]
input_ids = tf.broadcast_to(
tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
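
The generation hunk above drops `attention_mask` when the encoder's `call()` does not declare it, and only broadcasts `input_ids` when they are rank-2 text ids. A self-contained sketch of the signature-filtering technique (the names here are illustrative, not the library's):

```python
import inspect

def filter_kwargs_by_signature(fn, kwargs):
    # Keep only keyword arguments that `fn` declares; this mirrors how the
    # generation code avoids passing `attention_mask` to vision encoders.
    accepted = inspect.signature(fn).parameters
    return {name: value for name, value in kwargs.items() if name in accepted}

def vit_like_call(pixel_values, output_attentions=False):
    # Stand-in for a vision encoder's call(): it takes no attention_mask.
    return pixel_values

kwargs = {"attention_mask": None, "output_attentions": True}
print(filter_kwargs_by_signature(vit_like_call, kwargs))
# {'output_attentions': True}
```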
@@ -87,6 +87,7 @@ if is_tf_available():
"TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
"TF_MODEL_MAPPING",
"TF_MODEL_WITH_LM_HEAD_MAPPING",
"TFAutoModel",
@@ -100,6 +101,7 @@ if is_tf_available():
"TFAutoModelForSequenceClassification",
"TFAutoModelForTableQuestionAnswering",
"TFAutoModelForTokenClassification",
"TFAutoModelForVision2Seq",
"TFAutoModelWithLMHead",
]
@@ -197,6 +199,7 @@ if TYPE_CHECKING:
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
TF_MODEL_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel,
@@ -210,6 +213,7 @@ if TYPE_CHECKING:
TFAutoModelForSequenceClassification,
TFAutoModelForTableQuestionAnswering,
TFAutoModelForTokenClassification,
TFAutoModelForVision2Seq,
TFAutoModelWithLMHead,
)
@@ -156,6 +156,12 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
]
)
TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("vision-encoder-decoder", "TFVisionEncoderDecoderModel"),
]
)
TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
@@ -182,7 +188,6 @@ TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
]
)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
@@ -327,6 +332,7 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
@@ -387,6 +393,13 @@ class TFAutoModelForImageClassification(_BaseAutoModelClass):
AutoModelForImageClassification = auto_class_update(TFAutoModelForImageClassification, head_doc="image classification")
class TFAutoModelForVision2Seq(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING
TFAutoModelForVision2Seq = auto_class_update(TFAutoModelForVision2Seq, head_doc="vision-to-text modeling")
class TFAutoModelForMaskedLM(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING
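
For context on the mapping added above: `TFAutoModelForVision2Seq` picks a model class from the config's model type. A simplified sketch of that dispatch; the library's `_LazyAutoMapping` additionally resolves config classes and imports lazily.

```python
# Simplified dispatch mirroring TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES;
# a sketch, not the library's implementation.
VISION_2_SEQ_MAPPING = {"vision-encoder-decoder": "TFVisionEncoderDecoderModel"}

def resolve_vision2seq_class(model_type: str) -> str:
    try:
        return VISION_2_SEQ_MAPPING[model_type]
    except KeyError:
        raise ValueError(
            f"Unrecognized model type {model_type!r} for vision-to-text modeling; "
            f"supported: {sorted(VISION_2_SEQ_MAPPING)}"
        )

print(resolve_vision2seq_class("vision-encoder-decoder"))
```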
@@ -148,10 +148,10 @@ ENCODER_DECODER_INPUTS_DOCSTRING = r"""
@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
class TFEncoderDecoderModel(TFPreTrainedModel):
r"""
[`TFEncoderDecoder`] is a generic model class that will be instantiated as a transformer architecture with one of
the base model classes of the library as encoder and another one as decoder when created with the
:meth*~transformers.TFAutoModel.from_pretrained* class method for the encoder and
:meth*~transformers.TFAutoModelForCausalLM.from_pretrained* class method for the decoder.
[`TFEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one
of the base model classes of the library as encoder and another one as decoder when created with the
[`~TFAutoModel.from_pretrained`] class method for the encoder and [`~TFAutoModelForCausalLM.from_pretrained`] class
method for the decoder.
"""
config_class = EncoderDecoderConfig
base_model_prefix = "encoder_decoder"
@@ -233,13 +233,6 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
# Add `decoder_input_ids` because `self.decoder` requires it.
input_ids = tf.constant(DUMMY_INPUTS)
dummy = {"input_ids": input_ids, "decoder_input_ids": input_ids}
# Add `encoder_hidden_states` to make the cross-attention layers' weights initialized
if self.config.add_cross_attention:
batch_size, seq_len = input_ids.shape
shape = (batch_size, seq_len) + (self.config.hidden_size,)
h = tf.random.uniform(shape=shape)
dummy["encoder_hidden_states"] = h
return dummy
def get_encoder(self):
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_flax_available, is_torch_available
from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
_import_structure = {
@@ -28,6 +28,9 @@ _import_structure = {
if is_torch_available():
_import_structure["modeling_vision_encoder_decoder"] = ["VisionEncoderDecoderModel"]
if is_tf_available():
_import_structure["modeling_tf_vision_encoder_decoder"] = ["TFVisionEncoderDecoderModel"]
if is_flax_available():
_import_structure["modeling_flax_vision_encoder_decoder"] = ["FlaxVisionEncoderDecoderModel"]
@@ -37,6 +40,9 @@ if TYPE_CHECKING:
if is_torch_available():
from .modeling_vision_encoder_decoder import VisionEncoderDecoderModel
if is_tf_available():
from .modeling_tf_vision_encoder_decoder import TFVisionEncoderDecoderModel
if is_flax_available():
from .modeling_flax_vision_encoder_decoder import FlaxVisionEncoderDecoderModel
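
The `__init__.py` hunk above registers the TF class behind `is_tf_available()`, so TensorFlow code is only imported on demand. A minimal sketch of the underlying lazy-attribute pattern (PEP 562 style, simplified from the library's `_LazyModule`):

```python
import importlib

# Public attribute -> submodule, in the spirit of the _import_structure
# entries above (a simplified sketch, not the library's implementation).
_LAZY_ATTRIBUTES = {
    "TFVisionEncoderDecoderModel": ".modeling_tf_vision_encoder_decoder",
}

def __getattr__(name):
    # Module-level __getattr__ defers the import until first access.
    if name in _LAZY_ATTRIBUTES:
        module = importlib.import_module(_LAZY_ATTRIBUTES[name], __package__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```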
@@ -245,6 +245,9 @@ TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = None
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
TF_MODEL_FOR_VISION_2_SEQ_MAPPING = None
TF_MODEL_MAPPING = None
@@ -383,6 +386,18 @@ class TFAutoModelForTokenClassification:
requires_backends(self, ["tf"])
class TFAutoModelForVision2Seq:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
def call(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFAutoModelWithLMHead:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@@ -2678,6 +2693,18 @@ class TFTransfoXLPreTrainedModel:
requires_backends(self, ["tf"])
class TFVisionEncoderDecoderModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
def call(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFViTForImageClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
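
The dummy objects above keep `from transformers import TFVisionEncoderDecoderModel` working without TensorFlow installed, deferring the failure to first use. A self-contained sketch of the pattern; `requires_backends` here is a simplified stand-in for the library helper.

```python
import importlib.util

def requires_backends(obj, backends):
    # Simplified stand-in: raise a helpful error when a backend is missing.
    for backend in backends:
        if importlib.util.find_spec(backend) is None:
            name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
            raise ImportError(f"{name} requires {backend}; try `pip install {backend}`")

class TFVisionEncoderDecoderModel:
    # Placeholder with the same surface as the real class.
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tensorflow"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tensorflow"])
```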
@@ -490,7 +490,7 @@ class TFEncoderDecoderMixin:
def test_real_model_save_load_from_pretrained(self):
model_2 = self.get_pretrained_model()
input_ids = ids_tensor([13, 5], model_2.config.encoder.vocab_size)
decoder_input_ids = ids_tensor([13, 1], model_2.config.encoder.vocab_size)
decoder_input_ids = ids_tensor([13, 1], model_2.config.decoder.vocab_size)
attention_mask = ids_tensor([13, 5], vocab_size=2)
outputs = model_2(
@@ -650,7 +650,7 @@ class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
# disable cache for now
decoder_config.use_cache = False
return {
"config": config,
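
The test fix above samples `decoder_input_ids` from the decoder's vocabulary instead of the encoder's. A sketch of an `ids_tensor`-style helper to show why the bound matters (the test suite ships its own implementation):

```python
import tensorflow as tf

def ids_tensor(shape, vocab_size):
    # Random token ids in [0, vocab_size); drawing from the wrong vocabulary
    # can produce ids that index out of range in the decoder's embeddings.
    return tf.random.uniform(shape, minval=0, maxval=vocab_size, dtype=tf.int32)

decoder_input_ids = ids_tensor([13, 1], vocab_size=50257)  # e.g. GPT-2's vocab size
```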
This diff is collapsed.
@@ -203,6 +203,7 @@ def get_model_modules():
"modeling_tf_pytorch_utils",
"modeling_tf_utils",
"modeling_tf_transfo_xl_utilities",
"modeling_tf_vision_encoder_decoder",
"modeling_vision_encoder_decoder",
]
modules = []