Unverified Commit b67fd797 authored by Yih-Dar, committed by GitHub

Add TFVisionEncoderDecoderModel (#14148)



* Start the work on TFVisionEncoderDecoderModel

* Expose TFVisionEncoderDecoderModel

* fix import

* Add modeling_tf_vision_encoder_decoder to _ignore_modules in get_model_modules()

* reorder

* Apply the fix for checkpoint loading as in #14016

* remove attention_mask + fix VISION_DUMMY_INPUTS

* A minimal change to make TF generate() work for vision models as encoder in encoder-decoder setting

* fix wrong condition: shape_list(input_ids) == 2

* add tests

* use personal TFViTModel checkpoint (for now)

* Add equivalence tests + projection layer

* style

* make sure projection layer can run

* Add examples

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Clean comments (need to work on TODOs for PyTorch models)

* Remove TF -> PT in check_pt_tf_equivalence for TFVisionEncoderDecoderModel

* fixes

* Revert changes in PT code.

* Update tests/test_modeling_tf_vision_encoder_decoder.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Add test_inference_coco_en for TF test

* fix quality

* fix name

* build doc

* add main_input_name

* Fix ckpt name in test

* fix diff between master and this PR

* fix doc

* fix style and quality

* fix missing doc

* fix labels handling

* Delete auto.rst

* Add the changes done in #14016

* fix prefix

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* make style
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 37bc0b4e
......@@ -261,7 +261,7 @@ Flax), PyTorch, and/or TensorFlow.
| TrOCR | ❌ | ❌ | ✅ | ❌ | ❌ |
| UniSpeech | ❌ | ❌ | ✅ | ❌ | ❌ |
| UniSpeechSat | ❌ | ❌ | ✅ | ❌ | ❌ |
| Vision Encoder decoder | ❌ | ❌ | ✅ | ❌ | ✅ |
| Vision Encoder decoder | ❌ | ❌ | ✅ | ✅ | ✅ |
| VisionTextDualEncoder | ❌ | ❌ | ✅ | ❌ | ✅ |
| VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
| ViT | ❌ | ❌ | ✅ | ✅ | ✅ |
......
......@@ -194,6 +194,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its
[[autodoc]] TFAutoModelForQuestionAnswering
## TFAutoModelForVision2Seq
[[autodoc]] TFAutoModelForVision2Seq
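As a brief illustration (a sketch only, not autodoc output): the new auto class dispatches `vision-encoder-decoder` checkpoints to [`TFVisionEncoderDecoderModel`]. The checkpoint below is the one used in this PR's tests.

```python
from transformers import TFAutoModelForVision2Seq

# Resolves to TFVisionEncoderDecoderModel via TF_MODEL_FOR_VISION_2_SEQ_MAPPING.
model = TFAutoModelForVision2Seq.from_pretrained("ydshieh/vit-gpt2-coco-en")
```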
## FlaxAutoModel
[[autodoc]] FlaxAutoModel
......
......@@ -33,6 +33,12 @@ An example of how to use a [`VisionEncoderDecoderModel`] for inference can be se
- forward
- from_encoder_decoder_pretrained
## TFVisionEncoderDecoderModel
[[autodoc]] TFVisionEncoderDecoderModel
- call
- from_encoder_decoder_pretrained
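A short TensorFlow inference sketch for reference; it mirrors the example in [`TFVisionEncoderDecoderModel.from_pretrained`] and uses the `ydshieh/vit-gpt2-coco-en` checkpoint from this PR's tests (treat it as illustrative):

```python
from PIL import Image
import requests

from transformers import GPT2Tokenizer, TFVisionEncoderDecoderModel, ViTFeatureExtractor

feature_extractor = ViTFeatureExtractor.from_pretrained("ydshieh/vit-gpt2-coco-en")
tokenizer = GPT2Tokenizer.from_pretrained("ydshieh/vit-gpt2-coco-en")
model = TFVisionEncoderDecoderModel.from_pretrained("ydshieh/vit-gpt2-coco-en")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# image -> pixel values -> autoregressively generated caption
pixel_values = feature_extractor(images=image, return_tensors="tf").pixel_values
output_ids = model.generate(
    pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True
).sequences
caption = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
```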
## FlaxVisionEncoderDecoderModel
[[autodoc]] FlaxVisionEncoderDecoderModel
......
......@@ -1487,6 +1487,7 @@ if is_tf_available():
"TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
"TF_MODEL_MAPPING",
"TF_MODEL_WITH_LM_HEAD_MAPPING",
"TFAutoModel",
......@@ -1500,6 +1501,7 @@ if is_tf_available():
"TFAutoModelForSequenceClassification",
"TFAutoModelForTableQuestionAnswering",
"TFAutoModelForTokenClassification",
"TFAutoModelForVision2Seq",
"TFAutoModelWithLMHead",
]
)
......@@ -1838,6 +1840,7 @@ if is_tf_available():
"TFTransfoXLPreTrainedModel",
]
)
_import_structure["models.vision_encoder_decoder"].extend(["TFVisionEncoderDecoderModel"])
_import_structure["models.vit"].extend(
[
"TFViTForImageClassification",
......@@ -3354,6 +3357,7 @@ if TYPE_CHECKING:
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
TF_MODEL_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel,
......@@ -3367,6 +3371,7 @@ if TYPE_CHECKING:
TFAutoModelForSequenceClassification,
TFAutoModelForTableQuestionAnswering,
TFAutoModelForTokenClassification,
TFAutoModelForVision2Seq,
TFAutoModelWithLMHead,
)
from .models.bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel
......@@ -3636,6 +3641,7 @@ if TYPE_CHECKING:
TFTransfoXLModel,
TFTransfoXLPreTrainedModel,
)
from .models.vision_encoder_decoder import TFVisionEncoderDecoderModel
from .models.vit import TFViTForImageClassification, TFViTModel, TFViTPreTrainedModel
from .models.wav2vec2 import (
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
......
......@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from dataclasses import dataclass
from typing import Optional, Tuple, Union
......@@ -628,14 +629,18 @@ class TFGenerationMixin:
bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
# This block corresponds to the following line in `generation_utils`:
# "input_ids = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))"
# with the following differences:
# 1. In PT, `generate()`'s `model_kwargs` can accept `encoder_outputs`, which is not the case in TF.
# 2. There is no shape checking in PT.
# In both PT/TF, if `input_ids` is `None`, we try to create it as is done for a text model.
if input_ids is None:
assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
"you should either supply a context to complete as `input_ids` input "
"or a `bos_token_id` (integer >= 0) as a first token to start the generation."
)
input_ids = tf.fill((batch_size, 1), bos_token_id)
else:
assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."
# not allow to duplicate outputs when greedy decoding
if do_sample is False:
......@@ -691,21 +696,29 @@ class TFGenerationMixin:
# get encoder and store encoder outputs
encoder = self.get_encoder()
encoder_outputs = encoder(
input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict_in_generate,
)
encoder_kwargs = {
"attention_mask": attention_mask,
"output_attentions": output_attentions,
"output_hidden_states": output_hidden_states,
"return_dict": return_dict_in_generate,
}
# vision models don't use `attention_mask`.
signature = dict(inspect.signature(encoder.call).parameters)
if "attention_mask" not in signature:
encoder_kwargs.pop("attention_mask")
encoder_outputs = encoder(input_ids, **encoder_kwargs)
if return_dict_in_generate:
if output_attentions:
model_kwargs["encoder_attentions"] = encoder_outputs.attentions
if output_hidden_states:
model_kwargs["encoder_hidden_states"] = encoder_outputs.hidden_states
# The condition `len(shape_list(input_ids)) == 2` makes this block handle only text inputs.
# (vision inputs might occur when the model is an encoder-decoder model)
# Expand input ids if num_beams > 1 or num_return_sequences > 1
if num_return_sequences > 1 or num_beams > 1:
if len(shape_list(input_ids)) == 2 and (num_return_sequences > 1 or num_beams > 1):
input_ids_len = shape_list(input_ids)[-1]
input_ids = tf.broadcast_to(
tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
......
......@@ -87,6 +87,7 @@ if is_tf_available():
"TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
"TF_MODEL_MAPPING",
"TF_MODEL_WITH_LM_HEAD_MAPPING",
"TFAutoModel",
......@@ -100,6 +101,7 @@ if is_tf_available():
"TFAutoModelForSequenceClassification",
"TFAutoModelForTableQuestionAnswering",
"TFAutoModelForTokenClassification",
"TFAutoModelForVision2Seq",
"TFAutoModelWithLMHead",
]
......@@ -197,6 +199,7 @@ if TYPE_CHECKING:
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
TF_MODEL_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel,
......@@ -210,6 +213,7 @@ if TYPE_CHECKING:
TFAutoModelForSequenceClassification,
TFAutoModelForTableQuestionAnswering,
TFAutoModelForTokenClassification,
TFAutoModelForVision2Seq,
TFAutoModelWithLMHead,
)
......
......@@ -156,6 +156,12 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
]
)
TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("vision-encoder-decoder", "TFVisionEncoderDecoderModel"),
]
)
TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
......@@ -182,7 +188,6 @@ TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
]
)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
......@@ -327,6 +332,7 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
......@@ -387,6 +393,13 @@ class TFAutoModelForImageClassification(_BaseAutoModelClass):
AutoModelForImageClassification = auto_class_update(TFAutoModelForImageClassification, head_doc="image classification")
class TFAutoModelForVision2Seq(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING
TFAutoModelForVision2Seq = auto_class_update(TFAutoModelForVision2Seq, head_doc="vision-to-text modeling")
class TFAutoModelForMaskedLM(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING
......
......@@ -148,10 +148,10 @@ ENCODER_DECODER_INPUTS_DOCSTRING = r"""
@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
class TFEncoderDecoderModel(TFPreTrainedModel):
r"""
[`TFEncoderDecoder`] is a generic model class that will be instantiated as a transformer architecture with one of
the base model classes of the library as encoder and another one as decoder when created with the
:meth*~transformers.TFAutoModel.from_pretrained* class method for the encoder and
:meth*~transformers.TFAutoModelForCausalLM.from_pretrained* class method for the decoder.
[`TFEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one
of the base model classes of the library as encoder and another one as decoder when created with the
[`~TFAutoModel.from_pretrained`] class method for the encoder and [`~TFAutoModelForCausalLM.from_pretrained`] class
method for the decoder.
"""
config_class = EncoderDecoderConfig
base_model_prefix = "encoder_decoder"
......@@ -233,13 +233,6 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
# Add `decoder_input_ids` because `self.decoder` requires it.
input_ids = tf.constant(DUMMY_INPUTS)
dummy = {"input_ids": input_ids, "decoder_input_ids": input_ids}
# Add `encoder_hidden_states` to make the cross-attention layers' weights initialized
if self.config.add_cross_attention:
batch_size, seq_len = input_ids.shape
shape = (batch_size, seq_len) + (self.config.hidden_size,)
h = tf.random.uniform(shape=shape)
dummy["encoder_hidden_states"] = h
return dummy
def get_encoder(self):
......
......@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_flax_available, is_torch_available
from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
_import_structure = {
......@@ -28,6 +28,9 @@ _import_structure = {
if is_torch_available():
_import_structure["modeling_vision_encoder_decoder"] = ["VisionEncoderDecoderModel"]
if is_tf_available():
_import_structure["modeling_tf_vision_encoder_decoder"] = ["TFVisionEncoderDecoderModel"]
if is_flax_available():
_import_structure["modeling_flax_vision_encoder_decoder"] = ["FlaxVisionEncoderDecoderModel"]
......@@ -37,6 +40,9 @@ if TYPE_CHECKING:
if is_torch_available():
from .modeling_vision_encoder_decoder import VisionEncoderDecoderModel
if is_tf_available():
from .modeling_tf_vision_encoder_decoder import TFVisionEncoderDecoderModel
if is_flax_available():
from .modeling_flax_vision_encoder_decoder import FlaxVisionEncoderDecoderModel
......
# coding=utf-8
# Copyright 2022 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Classes to support TF Vision-Encoder-Text-Decoder architectures"""
import tempfile
from typing import Optional
import tensorflow as tf
from ...configuration_utils import PretrainedConfig
from ...file_utils import (
DUMMY_INPUTS,
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing, shape_list
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "VisionEncoderDecoderConfig"
VISION_ENCODER_DECODER_START_DOCSTRING = r"""
This class can be used to initialize an image-to-text-sequence model with any pretrained vision autoencoding model
as the encoder and any pretrained text autoregressive model as the decoder. The encoder is loaded via
[`~TFAutoModel.from_pretrained`] function and the decoder is loaded via [`~TFAutoModelForCausalLM.from_pretrained`]
function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream
generative task, like image captioning.
The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi
Zhou, Wei Li, Peter J. Liu.
Additionally, in [TrOCR: Transformer-based Optical Character Recognition with Pre-trained
Models](https://arxiv.org/abs/2109.10282) it is shown how leveraging large pretrained vision models for optical
character recognition (OCR) yields a significant performance improvement.
After such a Vision-Encoder-Text-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any
other models (see the examples for more information).
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
behavior.
Parameters:
config ([`VisionEncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]`, and each example must have the shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using the vision model's feature extractor. For example, using
[`ViTFeatureExtractor`]. See [`ViTFeatureExtractor.__call__`] for details.
decoder_input_ids (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
For sequence-to-sequence training, provide `decoder_input_ids` to the decoder. Indices can be obtained using
[`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for
details.
decoder_attention_mask (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default.
encoder_outputs (`tuple(tf.Tensor)`, *optional*):
This tuple must consist of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
`last_hidden_state` (`tf.Tensor` of shape `({0}, hidden_size)`) is a tensor of hidden-states at the output
of the last layer of the encoder. Used in the cross-attention of the decoder.
past_key_values (`tuple(tuple(tf.Tensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `({0})`.
decoder_inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert `decoder_input_ids` indices
into associated vectors than the model's internal embedding lookup matrix.
labels (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
Labels for computing the masked language modeling loss for the decoder. Indices should be in `[-100, 0,
..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
If set to `True`, the model will return a [`~modeling_tf_outputs.TFSeq2SeqLMOutput`] instead of a plain tuple.
training (`bool`, *optional*, defaults to `False`):
Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation).
kwargs: (*optional*) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
- Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function.
- With a *decoder_* prefix which will be input as `**decoder_kwargs` for the decoder forward function.
"""
@add_start_docstrings(VISION_ENCODER_DECODER_START_DOCSTRING)
class TFVisionEncoderDecoderModel(TFPreTrainedModel):
r"""
[`TFVisionEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture
with one of the base vision model classes of the library as encoder and another one of the base model classes as
decoder when created with the [`~TFAutoModel.from_pretrained`] class method for the encoder and
[`~TFAutoModelForCausalLM.from_pretrained`] class method for the decoder.
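A minimal construction sketch (illustrative only): the model can also be built directly from a
[`VisionEncoderDecoderConfig`], in which case all weights are randomly initialized.
```python
>>> from transformers import AutoConfig, TFVisionEncoderDecoderModel, VisionEncoderDecoderConfig

>>> # combine a ViT encoder config with a GPT-2 decoder config
>>> config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(
...     AutoConfig.from_pretrained("google/vit-base-patch16-224-in21k"),
...     AutoConfig.from_pretrained("gpt2"),
... )
>>> model = TFVisionEncoderDecoderModel(config)
```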
"""
config_class = VisionEncoderDecoderConfig
base_model_prefix = "vision_encoder_decoder"
load_weight_prefix = "tf_vision_encoder_decoder_model"
main_input_name = "pixel_values"
def __init__(
self,
config: Optional[PretrainedConfig] = None,
encoder: Optional[TFPreTrainedModel] = None,
decoder: Optional[TFPreTrainedModel] = None,
):
if config is None and (encoder is None or decoder is None):
raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
if config is None:
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
else:
if not isinstance(config, self.config_class):
raise ValueError(f"config: {config} has to be of type {self.config_class}")
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
"it has to be equal to the encoder's `hidden_size`. "
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
)
# initialize with config
super().__init__(config)
if encoder is None:
encoder = TFAutoModel.from_config(config.encoder, name="encoder")
if decoder is None:
decoder = TFAutoModelForCausalLM.from_config(config.decoder, name="decoder")
self.encoder = encoder
self.decoder = decoder
if self.encoder.config.to_dict() != self.config.encoder.to_dict():
logger.warning(
f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config: {self.config.encoder}"
)
if self.decoder.config.to_dict() != self.config.decoder.to_dict():
logger.warning(
f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config: {self.config.decoder}"
)
# make sure that the individual model's config refers to the shared config
# so that the updates to the config will be synced
self.encoder.config = self.config.encoder
self.decoder.config = self.config.decoder
# encoder outputs might need to be projected to different dimension for decoder
if (
self.encoder.config.hidden_size != self.decoder.config.hidden_size
and self.decoder.config.cross_attention_hidden_size is None
):
self.enc_to_dec_proj = tf.keras.layers.Dense(
units=self.decoder.config.hidden_size,
kernel_initializer=get_initializer(config.encoder.initializer_range),
name="enc_to_dec_proj",
)
if self.encoder.get_output_embeddings() is not None:
raise ValueError(
f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
)
@property
def dummy_inputs(self):
"""
Dummy inputs to build the network.
Returns:
`Dict[str, tf.Tensor]`: The dummy inputs.
"""
decoder_input_ids = tf.constant(DUMMY_INPUTS)
batch_size, seq_len = decoder_input_ids.shape
VISION_DUMMY_INPUTS = tf.random.uniform(
shape=(
batch_size,
self.config.encoder.num_channels,
self.config.encoder.image_size,
self.config.encoder.image_size,
),
dtype=tf.float32,
)
pixel_values = tf.constant(VISION_DUMMY_INPUTS)
# Add `decoder_input_ids` because `self.decoder` requires it.
dummy = {"pixel_values": pixel_values, "decoder_input_ids": decoder_input_ids}
return dummy
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
def get_input_embeddings(self):
return self.encoder.get_input_embeddings()
def get_output_embeddings(self):
return self.decoder.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
return self.decoder.set_output_embeddings(new_embeddings)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r"""
Initializing `TFVisionEncoderDecoderModel` from a pytorch checkpoint is not supported currently.
If there are only pytorch checkpoints for a particular encoder-decoder model, a workaround is:
```python
>>> # a workaround to load from pytorch checkpoint
>>> _model = VisionEncoderDecoderModel.from_pretrained("ydshieh/vit-gpt2-coco-en")
>>> _model.encoder.save_pretrained("./encoder")
>>> _model.decoder.save_pretrained("./decoder")
>>> model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
... "./encoder", "./decoder", encoder_from_pt=True, decoder_from_pt=True
... )
>>> # This is only for copying some specific attributes of this particular model.
>>> model.config = _model.config
```
Example:
```python
>>> from transformers import TFVisionEncoderDecoderModel, ViTFeatureExtractor, GPT2Tokenizer
>>> from PIL import Image
>>> import requests
>>> feature_extractor = ViTFeatureExtractor.from_pretrained("ydshieh/vit-gpt2-coco-en")
>>> decoder_tokenizer = GPT2Tokenizer.from_pretrained("ydshieh/vit-gpt2-coco-en")
>>> model = TFVisionEncoderDecoderModel.from_pretrained("ydshieh/vit-gpt2-coco-en")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> img = Image.open(requests.get(url, stream=True).raw)
>>> pixel_values = feature_extractor(images=img, return_tensors="tf").pixel_values # Batch size 1
>>> output_ids = model.generate(
... pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True
... ).sequences
>>> preds = decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
>>> preds = [pred.strip() for pred in preds]
>>> assert preds == ["a cat laying on top of a couch next to another cat"]
```"""
from_pt = kwargs.pop("from_pt", False)
if from_pt:
raise ValueError(
"Initializing `TFVisionEncoderDecoderModel` from a pytorch checkpoint is not supported currently. "
"Use a tensorflow checkpoint instead. If only the pytorch checkpoints are available, "
"create the encoder and decoder models separately, and use them to initialize `TFVisionEncoderDecoderModel`. "
"Check `TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained()` for more details."
)
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
@classmethod
def from_encoder_decoder_pretrained(
cls,
encoder_pretrained_model_name_or_path: str = None,
decoder_pretrained_model_name_or_path: str = None,
*model_args,
**kwargs
) -> TFPreTrainedModel:
r"""
Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
checkpoints.
Params:
encoder_pretrained_model_name_or_path (`str`, *optional*):
Information necessary to initiate the encoder. Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. An
example is `google/vit-base-patch16-224-in21k`.
- A path to a *directory* containing model weights saved using
[`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
- A path or url to a *pytorch index checkpoint file* (e.g., `./pt_model/`). In this case,
`encoder_from_pt` should be set to `True`.
decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to *None*):
Information necessary to initiate the decoder. Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
user or organization name, like `dbmdz/bert-base-german-cased`.
- A path to a *directory* containing model weights saved using
[`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
- A path or url to a *pytorch checkpoint file* (e.g., `./pt_model/`). In this case,
`decoder_from_pt` should be set to `True`.
model_args (remaining positional arguments, *optional*):
All remaining positional arguments will be passed to the underlying model's `__init__` method.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g.,
`output_attentions=True`).
- To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
- To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
- To update the parent model configuration, do not use a prefix for each configuration parameter.
Behaves differently depending on whether a `config` is provided or automatically loaded.
Example:
```python
>>> from transformers import TFVisionEncoderDecoderModel
>>> # initialize a vit-bert from a pretrained ViT and a pretrained BERT model. Note that the cross-attention layers will be randomly initialized
>>> model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
... "google/vit-base-patch16-224-in21k", "bert-base-uncased"
... )
>>> # saving model after fine-tuning
>>> model.save_pretrained("./vit-bert")
>>> # load fine-tuned model
>>> model = TFVisionEncoderDecoderModel.from_pretrained("./vit-bert")
```"""
kwargs_encoder = {
argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
}
kwargs_decoder = {
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}
# remove encoder, decoder kwargs from kwargs
for key in kwargs_encoder.keys():
del kwargs["encoder_" + key]
for key in kwargs_decoder.keys():
del kwargs["decoder_" + key]
# Load and initialize the encoder and decoder
# The distinction between encoder and decoder at the model level is made
# by the value of the flag `is_decoder` that we need to set correctly.
encoder = kwargs_encoder.pop("model", None)
if encoder is None:
if encoder_pretrained_model_name_or_path is None:
raise ValueError(
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
"to be defined."
)
if "config" not in kwargs_encoder:
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
logger.info(
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
"from a decoder model. Cross-attention and casual mask are disabled."
)
encoder_config.is_decoder = False
encoder_config.add_cross_attention = False
kwargs_encoder["config"] = encoder_config
kwargs_encoder["name"] = "encoder"
kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix
encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
# This is necessary to make `from_pretrained` following `save_pretrained` work correctly
if kwargs_encoder.get("from_pt", None):
del kwargs_encoder["from_pt"]
with tempfile.TemporaryDirectory() as tmp_dirname:
encoder.save_pretrained(tmp_dirname)
del encoder
encoder = TFAutoModel.from_pretrained(tmp_dirname, *model_args, **kwargs_encoder)
decoder = kwargs_decoder.pop("model", None)
if decoder is None:
if decoder_pretrained_model_name_or_path is None:
raise ValueError(
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
"to be defined."
)
if "config" not in kwargs_decoder:
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
"cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
kwargs_decoder["config"] = decoder_config
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
logger.warning(
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
)
kwargs_decoder["name"] = "decoder"
kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix
decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
# This is necessary to make `from_pretrained` following `save_pretrained` work correctly
if kwargs_decoder.get("from_pt", None):
del kwargs_decoder["from_pt"]
with tempfile.TemporaryDirectory() as tmp_dirname:
decoder.save_pretrained(tmp_dirname)
del decoder
decoder = TFAutoModelForCausalLM.from_pretrained(tmp_dirname, **kwargs_decoder)
# Make sure these 2 `tf.keras.Model` have fixed names so `from_pretrained` could load model weights correctly.
if encoder.name != "encoder":
raise ValueError("encoder model must be created with the name `encoder`.")
if decoder.name != "decoder":
raise ValueError("decoder model must be created with the name `decoder`.")
# instantiate config with corresponding kwargs
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
return cls(encoder=encoder, decoder=decoder, config=config)
@add_start_docstrings_to_model_forward(
VISION_ENCODER_DECODER_INPUTS_DOCSTRING.format("batch_size, sequence_length")
)
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
pixel_values=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
past_key_values=None,
decoder_inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
r"""
Returns:
Examples:
```python
>>> from transformers import AutoFeatureExtractor, AutoTokenizer, TFVisionEncoderDecoderModel
>>> from PIL import Image
>>> import requests
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
>>> decoder_tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> # initialize a vit-gpt2 from a pretrained ViT and a pretrained GPT2 model. Note that the cross-attention layers will be randomly initialized
>>> model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
... "google/vit-base-patch16-224-in21k", "gpt2"
... )
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> img = Image.open(requests.get(url, stream=True).raw)
>>> # forward
>>> pixel_values = feature_extractor(images=img, return_tensors="tf").pixel_values # Batch size 1
>>> decoder_input_ids = decoder_tokenizer("Linda Davis", return_tensors="tf").input_ids # Batch size 1
>>> outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
>>> # training
>>> outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids, labels=decoder_input_ids)
>>> loss, logits = outputs.loss, outputs.logits
>>> # save and load from pretrained
>>> model.save_pretrained("vit-gpt2")
>>> model = TFVisionEncoderDecoderModel.from_pretrained("vit-gpt2")
>>> # generation
>>> generated = model.generate(pixel_values, decoder_start_token_id=model.config.decoder.bos_token_id)
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
kwargs_decoder = {
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}
# Let the user be responsible for the expected format.
if encoder_outputs is not None:
if return_dict and not isinstance(encoder_outputs, ModelOutput):
raise ValueError(
"If `return_dict=True` and `encoder_outputs` is provided, it should be an instance of "
f"`ModelOutput`. Got an instance {type(encoder_outputs)} for `encoder_outputs`."
)
if encoder_outputs is None:
encoder_processing_inputs = {
"func": self.encoder.call,
"config": self.encoder.config,
"input_ids": pixel_values,
"output_attentions": output_attentions,
"output_hidden_states": output_hidden_states,
"return_dict": return_dict,
"training": training,
"kwargs_call": kwargs_encoder,
}
# Add arguments to encoder from `kwargs_encoder`
encoder_processing_inputs.update(kwargs_encoder)
kwargs_encoder = {}
encoder_inputs = input_processing(**encoder_processing_inputs)
if "input_ids" in encoder_inputs:
encoder_inputs["pixel_values"] = encoder_inputs.pop("input_ids")
if encoder_inputs["pixel_values"] is None:
raise ValueError("You have to specify pixel_values")
# Handle the case where the inputs are passed as a single dict which contains `labels`.
# The `labels` shouldn't be passed to `self.encoder` below, because it is a base model without this
# parameter (otherwise, an error occurs when `input_processing` is called inside `self.encoder.call()`).
if "labels" in encoder_inputs:
labels = encoder_inputs.pop("labels")
# handle the init case where `dummy_inputs` returns a dict containing `decoder_input_ids`.
if "decoder_input_ids" in encoder_inputs:
decoder_input_ids = encoder_inputs.pop("decoder_input_ids")
# handle the init case where `dummy_inputs` returns a dict containing `decoder_attention_mask`.
if "decoder_attention_mask" in encoder_inputs:
decoder_attention_mask = encoder_inputs.pop("decoder_attention_mask")
encoder_outputs = self.encoder(**encoder_inputs)
encoder_hidden_states = encoder_outputs[0]
# optionally project encoder_hidden_states
if (
self.encoder.config.hidden_size != self.decoder.config.hidden_size
and self.decoder.config.cross_attention_hidden_size is None
):
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
batch_size, sequence_length = shape_list(encoder_hidden_states)[:2]
encoder_attention_mask = tf.ones(shape=(batch_size, sequence_length), dtype=tf.int32)
decoder_processing_inputs = {
"func": self.decoder.call,
"config": self.decoder.config,
"input_ids": decoder_input_ids,
"attention_mask": decoder_attention_mask,
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask,
"inputs_embeds": decoder_inputs_embeds,
"labels": labels,
"output_attentions": output_attentions,
"output_hidden_states": output_hidden_states,
"use_cache": use_cache,
"past_key_values": past_key_values,
"return_dict": return_dict,
"training": training,
"kwargs_call": kwargs_decoder,
}
# Add arguments to decoder from `kwargs_decoder`
decoder_processing_inputs.update(kwargs_decoder)
kwargs_decoder = {}
decoder_inputs = input_processing(**decoder_processing_inputs)
decoder_outputs = self.decoder(**decoder_inputs)
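# When `labels` is passed, the decoder's tuple output starts with the loss followed by the logits;
# otherwise the logits come first. The indexing below handles both cases.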
loss = None if decoder_inputs["labels"] is None else decoder_outputs[0]
logits = decoder_outputs[0] if decoder_inputs["labels"] is None else decoder_outputs[1]
past_key_values = None
if decoder_inputs["use_cache"]:
past_key_values = decoder_outputs[1] if decoder_inputs["labels"] is None else decoder_outputs[2]
# The starting index of the remaining elements in `decoder_outputs`
start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
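# Pack the encoder's last hidden state together with the decoder cache; `prepare_inputs_for_generation`
# unpacks this pair again at generation time.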
past = (encoder_outputs[0], past_key_values) if past_key_values else None
if not decoder_inputs["return_dict"]:
if not isinstance(encoder_outputs, tuple):
encoder_outputs = encoder_outputs.to_tuple()
output = (loss, logits, past) + decoder_outputs[start_index:] + encoder_outputs
output = tuple([x for x in output if x is not None])
return output
return TFSeq2SeqLMOutput(
loss=decoder_outputs.loss,
logits=decoder_outputs.logits,
past_key_values=past,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
def serving_output(self, output):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
cross_attns = (
tf.convert_to_tensor(output.cross_attentions)
if self.config.output_attentions and output.cross_attentions is not None
else None
)
return TFSeq2SeqLMOutput(
logits=output.logits,
past_key_values=pkv,
decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns,
cross_attentions=cross_attns,
)
def prepare_inputs_for_generation(self, decoder_input_ids, past, use_cache=None, **kwargs):
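# During generation, `past` holds the encoder outputs and, after the first step, the decoder cache:
# either `(encoder_outputs,)` or `(encoder_outputs, past_key_values)`.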
if past is None or len(past) not in {1, 2}:
raise ValueError(f"past has to be an iterable of length 1,2 got {past}")
if len(past) == 1:
if not isinstance(past[0], tf.Tensor):
raise ValueError(f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}")
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
past_key_values = None
else:
if len(past) != 2:
raise ValueError(
"`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
)
encoder_outputs, past_key_values = past
if isinstance(encoder_outputs, tuple):
if not isinstance(encoder_outputs[0], tf.Tensor):
raise ValueError(
f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
)
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
elif isinstance(encoder_outputs, tf.Tensor):
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
if not past_key_values:
raise ValueError(
f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
)
decoder_input_ids = decoder_input_ids[:, -1:]
if not isinstance(encoder_outputs, TFBaseModelOutput):
raise ValueError(f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}.")
return {
"pixel_values": None, # encoder_outputs is defined. pixel_values not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
def resize_token_embeddings(self, *args, **kwargs):
raise NotImplementedError(
"Resizing the embedding layers via the TFVisionEncoderDecoderModel directly is not supported."
"Please use the respective methods of the wrapped objects (model.decoder.resize_token_embeddings(...))"
)
def _reorder_cache(self, past, beam_idx):
# apply decoder cache reordering here
if len(past) == 1:
return past
encoder_outputs, past_key_values = past
return (encoder_outputs, self.decoder._reorder_cache(past_key_values, beam_idx))
......@@ -245,6 +245,9 @@ TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = None
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
TF_MODEL_FOR_VISION_2_SEQ_MAPPING = None
TF_MODEL_MAPPING = None
......@@ -383,6 +386,18 @@ class TFAutoModelForTokenClassification:
requires_backends(self, ["tf"])
class TFAutoModelForVision2Seq:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
def call(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFAutoModelWithLMHead:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
......@@ -2678,6 +2693,18 @@ class TFTransfoXLPreTrainedModel:
requires_backends(self, ["tf"])
class TFVisionEncoderDecoderModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
def call(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFViTForImageClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
......
......@@ -490,7 +490,7 @@ class TFEncoderDecoderMixin:
def test_real_model_save_load_from_pretrained(self):
model_2 = self.get_pretrained_model()
input_ids = ids_tensor([13, 5], model_2.config.encoder.vocab_size)
decoder_input_ids = ids_tensor([13, 1], model_2.config.encoder.vocab_size)
decoder_input_ids = ids_tensor([13, 1], model_2.config.decoder.vocab_size)
attention_mask = ids_tensor([13, 5], vocab_size=2)
outputs = model_2(
......
# coding=utf-8
# Copyright 2022 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow VisionEncoderDecoder model. """
import os
import tempfile
import unittest
import numpy as np
from transformers import is_tf_available, is_torch_available, is_vision_available
from transformers.testing_utils import (
is_pt_tf_cross_test,
require_tf,
require_torch,
require_vision,
slow,
torch_device,
)
from .test_modeling_tf_common import floats_tensor, ids_tensor
from .test_modeling_tf_gpt2 import TFGPT2ModelTester
from .test_modeling_tf_vit import TFViTModelTester
if is_tf_available():
import tensorflow as tf
from transformers import (
AutoConfig,
AutoFeatureExtractor,
AutoTokenizer,
TFAutoModel,
TFAutoModelForCausalLM,
TFGPT2LMHeadModel,
TFVisionEncoderDecoderModel,
TFViTModel,
VisionEncoderDecoderConfig,
)
from transformers.modeling_tf_outputs import TFBaseModelOutput
if is_torch_available():
import torch
from transformers import GPT2LMHeadModel, VisionEncoderDecoderModel, ViTModel
if is_vision_available():
from PIL import Image
from transformers import ViTFeatureExtractor
@require_tf
class TFVisionEncoderDecoderMixin:
def get_encoder_decoder_model(self, config, decoder_config):
raise NotImplementedError
def prepare_config_and_inputs(self):
raise NotImplementedError
def get_pretrained_model(self):
raise NotImplementedError
def check_encoder_decoder_model_from_pretrained_configs(
self,
config,
pixel_values,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
self.assertTrue(encoder_decoder_config.decoder.is_decoder)
enc_dec_model = TFVisionEncoderDecoderModel(encoder_decoder_config)
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
outputs_encoder_decoder = enc_dec_model(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
)
self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[0], pixel_values.shape[0])
self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[-1], config.hidden_size)
def check_encoder_decoder_model(
self,
config,
pixel_values,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = TFVisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
self.assertTrue(enc_dec_model.config.decoder.is_decoder)
self.assertTrue(enc_dec_model.config.decoder.add_cross_attention)
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
outputs_encoder_decoder = enc_dec_model(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
)
self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[0], pixel_values.shape[0])
self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[-1], config.hidden_size)
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_hidden_states)
outputs_encoder_decoder = enc_dec_model(
pixel_values=None,
encoder_outputs=encoder_outputs,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
)
self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[0], pixel_values.shape[0])
self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[-1], config.hidden_size)
def check_encoder_decoder_model_from_pretrained(
self,
config,
pixel_values,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
return_dict,
**kwargs
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model, "return_dict": return_dict}
enc_dec_model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
outputs_encoder_decoder = enc_dec_model(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
return_dict=True,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
)
self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[0], pixel_values.shape[0])
self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[-1], config.hidden_size)
def check_save_and_load(
self,
config,
pixel_values,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = TFVisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
outputs = enc_dec_model(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
out_2 = np.array(outputs[0])
out_2[np.isnan(out_2)] = 0
with tempfile.TemporaryDirectory() as tmpdirname:
enc_dec_model.save_pretrained(tmpdirname)
enc_dec_model = TFVisionEncoderDecoderModel.from_pretrained(tmpdirname)
after_outputs = enc_dec_model(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
out_1 = np.array(after_outputs[0])
out_1[np.isnan(out_1)] = 0
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
def check_encoder_decoder_model_labels(
self,
config,
pixel_values,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
labels,
**kwargs
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = TFVisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
outputs_encoder_decoder = enc_dec_model(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
labels=labels,
)
# Make sure `loss` exist
self.assertIn("loss", outputs_encoder_decoder)
batch_size, seq_len = decoder_input_ids.shape
expected_shape = (batch_size, seq_len - 1, decoder_config.vocab_size)
self.assertEqual(outputs_encoder_decoder["logits"].shape, expected_shape)
self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[0], pixel_values.shape[0])
self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[-1], config.hidden_size)
def check_encoder_decoder_model_output_attentions(
self,
config,
pixel_values,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = TFVisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
outputs_encoder_decoder = enc_dec_model(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
)
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
self.assertEqual(encoder_attentions[0].shape[-3:-2], (config.num_attention_heads,))
decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
num_decoder_layers = (
decoder_config.num_decoder_layers
if hasattr(decoder_config, "num_decoder_layers")
else decoder_config.num_hidden_layers
)
self.assertEqual(len(decoder_attentions), num_decoder_layers)
self.assertEqual(
decoder_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
)
cross_attentions = outputs_encoder_decoder["cross_attentions"]
self.assertEqual(len(cross_attentions), num_decoder_layers)
cross_attention_input_seq_len = decoder_input_ids.shape[-1] * (
1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0)
)
self.assertEqual(
cross_attentions[0].shape[-3:-1],
(decoder_config.num_attention_heads, cross_attention_input_seq_len),
)
def check_encoder_decoder_model_generate(self, pixel_values, config, decoder_config, **kwargs):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = TFVisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
# Not every decoder has a bos token id (Bert, for example, does not), so use pad_token_id as the decoder start token instead
generated_output = enc_dec_model.generate(
pixel_values, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
)
self.assertEqual(
tuple(generated_output.shape.as_list()), (pixel_values.shape[0],) + (decoder_config.max_length,)
)
def check_pt_tf_equivalence(self, pt_model, tf_model, inputs_dict):
pt_model.to(torch_device)
pt_model.eval()
# prepare inputs
tf_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.numpy()) for k, v in tf_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
tf_outputs = tf_model(**inputs_dict).to_tuple()
self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
for tf_output, pt_output in zip(tf_outputs, pt_outputs):
self.assert_almost_equals(tf_output.numpy(), pt_output.numpy(), 1e-3)
# PT -> TF
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
pt_model.encoder.save_pretrained(encoder_tmp_dirname)
pt_model.decoder.save_pretrained(decoder_tmp_dirname)
tf_model_loaded = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_pt=True, decoder_from_pt=True
)
# This is only for copying some specific attributes of this particular model.
tf_model_loaded.config = pt_model.config
tf_outputs_loaded = tf_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch")
for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs):
self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.numpy(), 1e-3)
def check_equivalence_pt_to_tf(self, config, decoder_config, inputs_dict):
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
pt_model = VisionEncoderDecoderModel(encoder_decoder_config)
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
pt_model.encoder.save_pretrained(encoder_tmp_dirname)
pt_model.decoder.save_pretrained(decoder_tmp_dirname)
tf_model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_pt=True, decoder_from_pt=True
)
# This is only for copying some specific attributes of this particular model.
tf_model.config = pt_model.config
self.check_pt_tf_equivalence(pt_model, tf_model, inputs_dict)
def check_equivalence_tf_to_pt(self, config, decoder_config, inputs_dict):
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
# If `_tf_model` were used here, the test would fail, because the weights of `_tf_model` get extended
# before the encoder/decoder models are saved.
# There was a (very) ugly potential fix, which wasn't integrated into `transformers`: see
# https://github.com/huggingface/transformers/pull/13222/commits/dbb3c9de76eee235791d2064094654637c99f36d#r697304245
# (the change in `src/transformers/modeling_tf_utils.py`)
_tf_model = TFVisionEncoderDecoderModel(encoder_decoder_config)
# Make sure model is built
_tf_model(**inputs_dict)
# Use `tf_model` instead so the test passes.
encoder = _tf_model.encoder.__class__(encoder_decoder_config.encoder)
decoder = _tf_model.decoder.__class__(encoder_decoder_config.decoder)
# Make sure models are built
encoder(encoder.dummy_inputs)
decoder(decoder.dummy_inputs)
tf_model = TFVisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
tf_model.encoder.save_pretrained(encoder_tmp_dirname)
tf_model.decoder.save_pretrained(decoder_tmp_dirname)
pt_model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_tf=True, decoder_from_tf=True
)
# This is only for copying some specific attributes of this particular model.
pt_model.config = tf_model.config
self.check_pt_tf_equivalence(pt_model, tf_model, inputs_dict)
def test_encoder_decoder_model(self):
config_inputs_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model(**config_inputs_dict)
def test_encoder_decoder_model_from_pretrained_configs(self):
config_inputs_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained_configs(**config_inputs_dict)
def test_encoder_decoder_model_from_pretrained(self):
config_inputs_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained(**config_inputs_dict, return_dict=False)
def test_encoder_decoder_model_from_pretrained_return_dict(self):
config_inputs_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained(**config_inputs_dict, return_dict=True)
def test_save_and_load_from_pretrained(self):
config_inputs_dict = self.prepare_config_and_inputs()
self.check_save_and_load(**config_inputs_dict)
def test_encoder_decoder_model_labels(self):
config_inputs_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_labels(**config_inputs_dict)
def test_encoder_decoder_model_output_attentions(self):
config_inputs_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_output_attentions(**config_inputs_dict)
def test_encoder_decoder_model_generate(self):
config_inputs_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_generate(**config_inputs_dict)
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and tf is {diff} (>= {tol}).")
@is_pt_tf_cross_test
def test_pt_tf_equivalence(self):
config_inputs_dict = self.prepare_config_and_inputs()
# Keep only common arguments
arg_names = [
"config",
"pixel_values",
"decoder_config",
"decoder_input_ids",
"decoder_attention_mask",
"encoder_hidden_states",
]
config_inputs_dict = {k: v for k, v in config_inputs_dict.items() if k in arg_names}
config = config_inputs_dict.pop("config")
decoder_config = config_inputs_dict.pop("decoder_config")
inputs_dict = config_inputs_dict
# `encoder_hidden_states` is not used in model call/forward
del inputs_dict["encoder_hidden_states"]
# Avoid the case where a sequence has no place to attend (after being combined with the causal attention mask)
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
inputs_dict["decoder_attention_mask"] = tf.constant(
np.concatenate([np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1)
)
# TF models don't use the `use_cache` option, and the cache is not returned by default.
# So we disable `use_cache` for the PyTorch model here.
decoder_config.use_cache = False
self.assertTrue(decoder_config.cross_attention_hidden_size is None)
# check without `enc_to_dec_proj` projection
self.assertTrue(config.hidden_size == decoder_config.hidden_size)
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
# This is not working, because the PT/TF equivalence tests for encoder-decoder models use `from_encoder_decoder_pretrained`,
# which randomly initializes `enc_to_dec_proj`.
# # check `enc_to_dec_proj` works as expected
# decoder_config.hidden_size = decoder_config.hidden_size * 2
# self.assertTrue(config.hidden_size != decoder_config.hidden_size)
# self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
# self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
# Let's just check `enc_to_dec_proj` can run for now
decoder_config.hidden_size = decoder_config.hidden_size * 2
self.assertTrue(config.hidden_size != decoder_config.hidden_size)
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
model = TFVisionEncoderDecoderModel(encoder_decoder_config)
model(**inputs_dict)
@slow
def test_real_model_save_load_from_pretrained(self):
model_2 = self.get_pretrained_model()
pixel_values = floats_tensor(
[
13,
model_2.config.encoder.num_channels,
model_2.config.encoder.image_size,
model_2.config.encoder.image_size,
]
)
decoder_input_ids = ids_tensor([13, 1], model_2.config.decoder.vocab_size)
outputs = model_2(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
)
out_2 = np.array(outputs[0])
out_2[np.isnan(out_2)] = 0
with tempfile.TemporaryDirectory() as tmp_dirname:
model_2.save_pretrained(tmp_dirname)
model_1 = TFVisionEncoderDecoderModel.from_pretrained(tmp_dirname)
after_outputs = model_1(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
out_1 = np.array(after_outputs[0])
out_1[np.isnan(out_1)] = 0
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
@require_tf
class TFViT2GPT2EncoderDecoderModelTest(TFVisionEncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model(self):
return TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained("google/vit-base-patch16-224-in21k", "gpt2")
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = TFViTModel(config, name="encoder")
decoder_model = TFGPT2LMHeadModel(decoder_config, name="decoder")
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester_encoder = TFViTModelTester(self, batch_size=13)
model_tester_decoder = TFGPT2ModelTester(self)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
(config, pixel_values, labels) = encoder_config_and_inputs
(
decoder_config,
decoder_input_ids,
decoder_attention_mask,
decoder_head_mask,
decoder_token_type_ids,
decoder_sequence_labels,
decoder_token_labels,
decoder_choice_labels,
encoder_hidden_states,
encoder_attention_mask,
) = decoder_config_and_inputs
# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
# disable cache for now
decoder_config.use_cache = False
return {
"config": config,
"pixel_values": pixel_values,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"encoder_hidden_states": encoder_hidden_states, # This is not used in the tests.
"labels": decoder_token_labels,
}
@require_tf
class TFVisionEncoderDecoderModelTest(unittest.TestCase):
def get_from_encoderdecoder_pretrained_model(self):
return TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained("google/vit-base-patch16-224-in21k", "gpt2")
def get_decoder_config(self):
config = AutoConfig.from_pretrained("gpt2")
config.is_decoder = True
config.add_cross_attention = True
return config
def get_encoderdecoder_model(self):
return TFVisionEncoderDecoderModel.from_pretrained("ydshieh/vit-gpt2-coco-en")
def get_encoder_decoder_models(self):
encoder_model = TFViTModel.from_pretrained("google/vit-base-patch16-224-in21k", name="encoder")
decoder_model = TFGPT2LMHeadModel.from_pretrained("gpt2", config=self.get_decoder_config(), name="decoder")
return {"encoder": encoder_model, "decoder": decoder_model}
def _check_configuration_tie(self, model):
assert id(model.decoder.config) == id(model.config.decoder)
assert id(model.encoder.config) == id(model.config.encoder)
@slow
def test_configuration_tie(self):
model = self.get_from_encoderdecoder_pretrained_model()
self._check_configuration_tie(model)
model = TFVisionEncoderDecoderModel(**self.get_encoder_decoder_models())
self._check_configuration_tie(model)
model = self.get_encoderdecoder_model()
self._check_configuration_tie(model)
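# Illustrative sketch (not part of the original diff): the composition pattern exercised by
# `test_configuration_tie` above, written as standalone usage. It assumes access to the two
# public checkpoints named below; the cross-attention weights added to the decoder are freshly
# initialized, so such a composite model would still need fine-tuning before use.
from transformers import TFVisionEncoderDecoderModel

model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "gpt2"
)
# The sub-model configs are shared with the composite config (same objects),
# which is exactly what `_check_configuration_tie` asserts.
assert model.config.encoder is model.encoder.config
assert model.config.decoder is model.decoder.config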
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_tf
class TFVisionEncoderDecoderModelSaveLoadTests(unittest.TestCase):
def get_encoder_decoder_config(self):
encoder_config = AutoConfig.from_pretrained("google/vit-base-patch16-224-in21k")
decoder_config = AutoConfig.from_pretrained("gpt2", is_decoder=True, add_cross_attention=True)
return VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
def get_encoder_decoder_config_small(self):
encoder_config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-vit")
decoder_config = AutoConfig.from_pretrained(
"hf-internal-testing/tiny-random-gpt2", is_decoder=True, add_cross_attention=True
)
return VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
def test_encoder_decoder_save_load_from_encoder_decoder(self):
config = self.get_encoder_decoder_config_small()
# create two random ViT/GPT2 models for vit-gpt2 & initialize weights (+cross_attention weights)
encoder = TFViTModel(config.encoder)
encoder(encoder.dummy_inputs)
decoder = TFGPT2LMHeadModel(config.decoder)
decoder(decoder.dummy_inputs)
encoder_decoder_orig = TFVisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
pixel_values = floats_tensor(
[
13,
encoder.config.num_channels,
encoder.config.image_size,
encoder.config.image_size,
]
)
decoder_input_ids = ids_tensor([13, 1], decoder.config.vocab_size)
logits_orig = encoder_decoder_orig(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids).logits
with tempfile.TemporaryDirectory() as tmp_dirname:
encoder_path = os.path.join(tmp_dirname, "encoder")
decoder_path = os.path.join(tmp_dirname, "decoder")
encoder.save_pretrained(encoder_path)
decoder.save_pretrained(decoder_path)
encoder_decoder = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_path, decoder_path)
logits_1 = encoder_decoder(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids).logits
self.assertTrue(abs(logits_orig.numpy().sum() - logits_1.numpy().sum()) < 1e-3)
max_diff = np.max(np.abs(logits_1.numpy() - logits_orig.numpy()))
self.assertAlmostEqual(max_diff, 0.0, places=4)
with tempfile.TemporaryDirectory() as tmp_dirname:
encoder_decoder.save_pretrained(tmp_dirname)
encoder_decoder = TFVisionEncoderDecoderModel.from_pretrained(tmp_dirname)
logits_2 = encoder_decoder(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids).logits
max_diff = np.max(np.abs(logits_2.numpy() - logits_orig.numpy()))
self.assertAlmostEqual(max_diff, 0.0, places=4)
@require_torch
@is_pt_tf_cross_test
def test_encoder_decoder_save_load_from_encoder_decoder_from_pt(self):
config = self.get_encoder_decoder_config_small()
# create two random ViT/GPT2 models for vit-gpt2 & initialize weights (+cross_attention weights)
encoder_pt = ViTModel(config.encoder).to(torch_device).eval()
decoder_pt = GPT2LMHeadModel(config.decoder).to(torch_device).eval()
encoder_decoder_pt = VisionEncoderDecoderModel(encoder=encoder_pt, decoder=decoder_pt).to(torch_device).eval()
pixel_values = floats_tensor(
[
13,
encoder_pt.config.num_channels,
encoder_pt.config.image_size,
encoder_pt.config.image_size,
]
)
decoder_input_ids = ids_tensor([13, 1], decoder_pt.config.vocab_size)
pt_pixel_values = torch.tensor(pixel_values.numpy(), device=torch_device, dtype=torch.float)
pt_decoder_input_ids = torch.tensor(decoder_input_ids.numpy(), device=torch_device, dtype=torch.long)
logits_pt = encoder_decoder_pt(pixel_values=pt_pixel_values, decoder_input_ids=pt_decoder_input_ids).logits
# PyTorch => TensorFlow
with tempfile.TemporaryDirectory() as tmp_dirname_1, tempfile.TemporaryDirectory() as tmp_dirname_2:
encoder_decoder_pt.encoder.save_pretrained(tmp_dirname_1)
encoder_decoder_pt.decoder.save_pretrained(tmp_dirname_2)
encoder_decoder_tf = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
tmp_dirname_1, tmp_dirname_2, encoder_from_pt=True, decoder_from_pt=True
)
logits_tf = encoder_decoder_tf(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids).logits
max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))
self.assertAlmostEqual(max_diff, 0.0, places=3)
# Make sure `from_pretrained` following `save_pretrained` works and gives the same result
# (See https://github.com/huggingface/transformers/pull/14016)
with tempfile.TemporaryDirectory() as tmp_dirname:
encoder_decoder_tf.save_pretrained(tmp_dirname)
encoder_decoder_tf = TFVisionEncoderDecoderModel.from_pretrained(tmp_dirname)
logits_tf_2 = encoder_decoder_tf(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids).logits
max_diff = np.max(np.abs(logits_tf_2.numpy() - logits_tf.numpy()))
self.assertAlmostEqual(max_diff, 0.0, places=3)
@require_vision
@slow
def test_encoder_decoder_from_pretrained(self):
load_weight_prefix = TFVisionEncoderDecoderModel.load_weight_prefix
config = self.get_encoder_decoder_config()
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
decoder_tokenizer = AutoTokenizer.from_pretrained("gpt2")
img = prepare_img()
pixel_values = feature_extractor(images=img, return_tensors="tf").pixel_values
decoder_input_ids = decoder_tokenizer("Linda Davis", return_tensors="tf").input_ids
with tempfile.TemporaryDirectory() as tmp_dirname:
# Since most of HF's models don't have pretrained cross-attention layers, those layers are randomly
# initialized even when models are created with the `from_pretrained` method.
# For this test, the decoder needs to be a model with pretrained cross-attention layers.
# So we create the pretrained models (without `load_weight_prefix`), save them, and later
# load them using `from_pretrained`.
# (We don't need to do this for the encoder, but it keeps the code symmetric between encoder and decoder.)
encoder = TFAutoModel.from_pretrained("google/vit-base-patch16-224-in21k", name="encoder")
# It's necessary to specify `add_cross_attention=True` here.
decoder = TFAutoModelForCausalLM.from_pretrained(
"gpt2", is_decoder=True, add_cross_attention=True, name="decoder"
)
pretrained_encoder_dir = os.path.join(tmp_dirname, "pretrained_encoder")
pretrained_decoder_dir = os.path.join(tmp_dirname, "pretrained_decoder")
encoder.save_pretrained(pretrained_encoder_dir)
decoder.save_pretrained(pretrained_decoder_dir)
del encoder
del decoder
enc_dec_model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
pretrained_encoder_dir,
pretrained_decoder_dir,
)
# check that the `from_pretrained` methods work
enc_dec_model.save_pretrained(tmp_dirname)
enc_dec_model = TFVisionEncoderDecoderModel.from_pretrained(tmp_dirname)
output = enc_dec_model(pixel_values, decoder_input_ids=decoder_input_ids, labels=decoder_input_ids)
loss_pretrained = output.loss
del enc_dec_model
# Create the model using `__init__` with loaded ``pretrained`` encoder / decoder
encoder = TFAutoModel.from_pretrained(
pretrained_encoder_dir, load_weight_prefix=load_weight_prefix, name="encoder"
)
decoder = TFAutoModelForCausalLM.from_pretrained(
pretrained_decoder_dir, load_weight_prefix=load_weight_prefix, name="decoder"
)
enc_dec_model = TFVisionEncoderDecoderModel(config=config, encoder=encoder, decoder=decoder)
output = enc_dec_model(pixel_values, decoder_input_ids=decoder_input_ids, labels=decoder_input_ids)
loss_init = output.loss
max_diff = np.max(np.abs(loss_pretrained - loss_init))
expected_diff = 0.0
self.assertAlmostEqual(max_diff, expected_diff, places=4)
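# Illustrative sketch (not part of the original diff): cross-framework loading as exercised in
# `test_encoder_decoder_save_load_from_encoder_decoder_from_pt` above. The directory names are
# placeholders for local folders produced by `save_pretrained` on PyTorch encoder/decoder models.
from transformers import TFVisionEncoderDecoderModel

tf_model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "./pt_encoder", "./pt_decoder", encoder_from_pt=True, decoder_from_pt=True
)
# Saving the composite TF model and reloading it with `from_pretrained` should give
# identical outputs (see #14016 referenced above).
tf_model.save_pretrained("./tf_vision_encoder_decoder")
tf_model = TFVisionEncoderDecoderModel.from_pretrained("./tf_vision_encoder_decoder")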
@require_vision
@require_tf
class TFViT2GPT2ModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_coco_en(self):
loc = "ydshieh/vit-gpt2-coco-en"
feature_extractor = ViTFeatureExtractor.from_pretrained(loc)
tokenizer = AutoTokenizer.from_pretrained(loc)
model = TFVisionEncoderDecoderModel.from_pretrained(loc)
# We will verify our results on an image of cute cats
img = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
pixel_values = feature_extractor(images=img, return_tensors="tf").pixel_values
decoder_input_ids = tf.constant([[model.config.decoder_start_token_id]])
logits = model(pixel_values, decoder_input_ids)[0].numpy()
# verify the logits
expected_shape = (1, 1, model.config.decoder.vocab_size)
self.assertEqual(logits.shape, expected_shape)
EXPECTED_LOGIT_SLICE = np.array(
[
-38.705807,
-30.639929,
-31.41903,
-39.012012,
-38.38696,
-34.887207,
-33.290855,
-35.68447,
-38.508484,
-36.124645,
]
)
max_diff = np.amax(np.abs(logits[0, 0, :10] - EXPECTED_LOGIT_SLICE))
self.assertLessEqual(max_diff, 1e-4)
def generate_step(pixel_values):
outputs = model.generate(
pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True, output_scores=True
)
output_ids = outputs.sequences
preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
preds = [pred.strip() for pred in preds]
return preds, outputs.scores.numpy()
preds, scores = generate_step(pixel_values)
# should produce
# ["a cat laying on top of a couch next to another cat"]
self.assertEqual(preds, ["a cat laying on top of a couch next to another cat"])
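# Illustrative sketch (not part of the original diff): end-to-end captioning inference with the
# new TF class, mirroring `test_inference_coco_en` above. It assumes the public
# "ydshieh/vit-gpt2-coco-en" checkpoint used by that test; the image path is a placeholder.
from PIL import Image
from transformers import AutoTokenizer, TFVisionEncoderDecoderModel, ViTFeatureExtractor

loc = "ydshieh/vit-gpt2-coco-en"
feature_extractor = ViTFeatureExtractor.from_pretrained(loc)
tokenizer = AutoTokenizer.from_pretrained(loc)
model = TFVisionEncoderDecoderModel.from_pretrained(loc)

image = Image.open("path/to/image.png")  # placeholder path
pixel_values = feature_extractor(images=image, return_tensors="tf").pixel_values
output_ids = model.generate(pixel_values, max_length=16, num_beams=4)
caption = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
print(caption)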
......@@ -203,6 +203,7 @@ def get_model_modules():
"modeling_tf_pytorch_utils",
"modeling_tf_utils",
"modeling_tf_transfo_xl_utilities",
"modeling_tf_vision_encoder_decoder",
"modeling_vision_encoder_decoder",
]
modules = []
......