"vscode:/vscode.git/clone" did not exist on "af37d183b30f0e4430e479aab509ab0e7cb553c9"
Unverified Commit 95b3ec3b authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Add FlaxVisionEncoderDecoderModel (#13359)



* Start the work on FlaxVisionEncoderDecoderModel

* Add FlaxVisionEncoderDecoderModel

* Add VisionEncoderDecoderConfig

* Make FlaxVisionEncoderDecoderModel visible to transformers

* Add test

* Fix wrong getattr usage

* Fix tests

* Add FlaxAutoModelForVision2Seq

* Expose FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING

* clean-up

* add integration test

* update expected logits

* update expected scores

* Add ViT2GPT2ModelIntegrationTest + some cleaning

* Add projection layer + PT/Flax equivalence tests

* Fix import

* minor changes

* make test slow again

* Apply suggestions

* Add modeling_flax_vision_encoder_decoder to _ignore_modules in get_model_modules()

* fix copies

* Apply suggestions from code review
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

* split long strings in multiple lines

* decoder_input_ids can't be None

* Add back test_configuration_tie

* Remove attention_mask parameter

* fix test - encoder_last_hidden_state should be encoder_outputs.last_hidden_state instead of the projected vector

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Remove more encoder_attention_mask

* remove encoder_attention_mask when calling self.decode (in FlaxVisionEncoderDecoderModule)

* Fix style + pass 1s instead of None as encoder_attention_mask

* fix init_weights

* pass None for encoder_attention_mask

* pass 1s instead of None as encoder_attention_mask

* Fix doc style
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent a5030122
......@@ -499,7 +499,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| UniSpeechSat | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Vision Encoder decoder | ❌ | ❌ | ✅ | ❌ | |
| Vision Encoder decoder | ❌ | ❌ | ✅ | ❌ | |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
......
......@@ -160,6 +160,13 @@ AutoModelForImageClassification
:members:
AutoModelForVision2Seq
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.AutoModelForVision2Seq
:members:
AutoModelForAudioClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......@@ -333,3 +340,10 @@ FlaxAutoModelForImageClassification
.. autoclass:: transformers.FlaxAutoModelForImageClassification
:members:
FlaxAutoModelForVision2Seq
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxAutoModelForVision2Seq
:members:
......@@ -39,3 +39,10 @@ VisionEncoderDecoderModel
.. autoclass:: transformers.VisionEncoderDecoderModel
:members: forward, from_encoder_decoder_pretrained
FlaxVisionEncoderDecoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxVisionEncoderDecoderModel
:members: __call__, from_encoder_decoder_pretrained
......@@ -603,6 +603,7 @@ if is_torch_available():
"MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"MODEL_FOR_VISION_2_SEQ_MAPPING",
"MODEL_MAPPING",
"MODEL_WITH_LM_HEAD_MAPPING",
"AutoModel",
......@@ -622,6 +623,7 @@ if is_torch_available():
"AutoModelForSpeechSeq2Seq",
"AutoModelForTableQuestionAnswering",
"AutoModelForTokenClassification",
"AutoModelForVision2Seq",
"AutoModelWithLMHead",
]
)
......@@ -1825,6 +1827,7 @@ if is_flax_available():
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING",
"FLAX_MODEL_MAPPING",
"FlaxAutoModel",
"FlaxAutoModelForCausalLM",
......@@ -1837,6 +1840,7 @@ if is_flax_available():
"FlaxAutoModelForSeq2SeqLM",
"FlaxAutoModelForSequenceClassification",
"FlaxAutoModelForTokenClassification",
"FlaxAutoModelForVision2Seq",
]
)
......@@ -1957,6 +1961,7 @@ if is_flax_available():
]
)
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
_import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel")
_import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"])
_import_structure["models.wav2vec2"].extend(
["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"]
......@@ -2457,6 +2462,7 @@ if TYPE_CHECKING:
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
AutoModel,
......@@ -2476,6 +2482,7 @@ if TYPE_CHECKING:
AutoModelForSpeechSeq2Seq,
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
AutoModelForVision2Seq,
AutoModelWithLMHead,
)
from .models.bart import (
......@@ -3482,6 +3489,7 @@ if TYPE_CHECKING:
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
FLAX_MODEL_MAPPING,
FlaxAutoModel,
FlaxAutoModelForCausalLM,
......@@ -3494,6 +3502,7 @@ if TYPE_CHECKING:
FlaxAutoModelForSeq2SeqLM,
FlaxAutoModelForSequenceClassification,
FlaxAutoModelForTokenClassification,
FlaxAutoModelForVision2Seq,
)
from .models.bart import (
FlaxBartForConditionalGeneration,
......@@ -3579,6 +3588,7 @@ if TYPE_CHECKING:
FlaxRobertaPreTrainedModel,
)
from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel
from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
from .models.wav2vec2 import (
FlaxWav2Vec2ForCTC,
......
......@@ -46,6 +46,7 @@ if is_torch_available():
"MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"MODEL_FOR_VISION_2_SEQ_MAPPING",
"MODEL_MAPPING",
"MODEL_WITH_LM_HEAD_MAPPING",
"AutoModel",
......@@ -65,6 +66,7 @@ if is_torch_available():
"AutoModelForSpeechSeq2Seq",
"AutoModelForTableQuestionAnswering",
"AutoModelForTokenClassification",
"AutoModelForVision2Seq",
"AutoModelWithLMHead",
]
......@@ -105,6 +107,7 @@ if is_flax_available():
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING",
"FLAX_MODEL_MAPPING",
"FlaxAutoModel",
"FlaxAutoModelForCausalLM",
......@@ -117,6 +120,7 @@ if is_flax_available():
"FlaxAutoModelForSeq2SeqLM",
"FlaxAutoModelForSequenceClassification",
"FlaxAutoModelForTokenClassification",
"FlaxAutoModelForVision2Seq",
]
......@@ -144,6 +148,7 @@ if TYPE_CHECKING:
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
AutoModel,
......@@ -163,6 +168,7 @@ if TYPE_CHECKING:
AutoModelForSpeechSeq2Seq,
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
AutoModelForVision2Seq,
AutoModelWithLMHead,
)
......@@ -203,6 +209,7 @@ if TYPE_CHECKING:
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
FLAX_MODEL_MAPPING,
FlaxAutoModel,
FlaxAutoModelForCausalLM,
......@@ -215,6 +222,7 @@ if TYPE_CHECKING:
FlaxAutoModelForSeq2SeqLM,
FlaxAutoModelForSequenceClassification,
FlaxAutoModelForTokenClassification,
FlaxAutoModelForVision2Seq,
)
else:
......
......@@ -239,6 +239,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("electra", "ELECTRA"),
("encoder-decoder", "Encoder decoder"),
("speech-encoder-decoder", "Speech Encoder decoder"),
("vision-encoder-decoder", "Vision Encoder decoder"),
("funnel", "Funnel Transformer"),
("lxmert", "LXMERT"),
("deberta-v2", "DeBERTa-v2"),
......
......@@ -244,6 +244,12 @@ MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
]
)
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("vision-encoder-decoder", "VisionEncoderDecoderModel"),
]
)
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
......@@ -511,6 +517,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
......@@ -655,6 +662,13 @@ class AutoModelForObjectDetection(_BaseAutoModelClass):
AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")
class AutoModelForVision2Seq(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING
AutoModelForVision2Seq = auto_class_update(AutoModelForVision2Seq, head_doc="vision-to-text modeling")
class AutoModelForAudioClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
......
......@@ -100,6 +100,12 @@ FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
]
)
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("vision-encoder-decoder", "FlaxVisionEncoderDecoderModel"),
]
)
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
......@@ -176,6 +182,7 @@ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
......@@ -279,3 +286,10 @@ class FlaxAutoModelForImageClassification(_BaseAutoModelClass):
FlaxAutoModelForImageClassification = auto_class_update(
FlaxAutoModelForImageClassification, head_doc="image classification"
)
class FlaxAutoModelForVision2Seq(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING
FlaxAutoModelForVision2Seq = auto_class_update(FlaxAutoModelForVision2Seq, head_doc="vision-to-text modeling")
......@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_torch_available
from ...file_utils import _LazyModule, is_flax_available, is_torch_available
_import_structure = {
......@@ -28,12 +28,18 @@ _import_structure = {
if is_torch_available():
_import_structure["modeling_vision_encoder_decoder"] = ["VisionEncoderDecoderModel"]
if is_flax_available():
_import_structure["modeling_flax_vision_encoder_decoder"] = ["FlaxVisionEncoderDecoderModel"]
if TYPE_CHECKING:
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig
if is_torch_available():
from .modeling_vision_encoder_decoder import VisionEncoderDecoderModel
if is_flax_available():
from .modeling_flax_vision_encoder_decoder import FlaxVisionEncoderDecoderModel
else:
import sys
......
......@@ -27,8 +27,8 @@ logger = logging.get_logger(__name__)
class VisionEncoderDecoderConfig(PretrainedConfig):
r"""
:class:`~transformers.VisionEncoderDecoderConfig` is the configuration class to store the configuration of a
:class:`~transformers.VisionEncoderDecoderModel`. It is used to instantiate an Encoder Decoder model according to
the specified arguments, defining the encoder and decoder configs.
:class:`~transformers.VisionEncoderDecoderModel`. It is used to instantiate a Vision-Encoder-Text-Decoder model
according to the specified arguments, defining the encoder and decoder configs.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
......
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Classes to support Vision-Encoder-Text-Decoder architectures """
import os
from typing import Optional, Tuple, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from jax import lax
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
from ...modeling_flax_utils import FlaxPreTrainedModel
from ...utils import logging
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "VisionEncoderDecoderConfig"
VISION_ENCODER_DECODER_START_DOCSTRING = r"""
This class can be used to initialize an image-to-text-sequence model with any pretrained vision autoencoding model
as the encoder and any pretrained text autoregressive model as the decoder. The encoder is loaded via
:meth:`~transformers.AutoModel.from_pretrained` function and the decoder is loaded via
:meth:`~transformers.AutoModelForCausalLM.from_pretrained` function. Cross-attention layers are automatically added
to the decoder and should be fine-tuned on a downstream generative task, like image captioning.
The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
tasks was shown in `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks
<https://arxiv.org/abs/1907.12461>`__ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi
Zhou, Wei Li, Peter J. Liu.
Additionally, in `TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models
<https://arxiv.org/abs/2109.10282>`__ it is shown how leveraging large pretrained vision models for optical
character recognition (OCR) yields a significant performance improvement.
After such a Vision-Encoder-Text-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any
other models (see the examples for more information).
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
generic methods the library implements for all its model (such as downloading or saving, resizing the input
embeddings, pruning heads etc.)
This model is also a Flax Linen `flax.nn.Module
<https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
Module and refer to the Flax documentation for all matter related to general usage and behavior.
Parameters:
config (:class:`~transformers.VisionEncoderDecoderConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
"""
VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
Args:
pixel_values (:obj:`jnp.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using the vision model's feature extractor. For example, using
:class:`~transformers.ViTFeatureExtractor`. See :meth:`transformers.ViTFeatureExtractor.__call__` for
details.
decoder_input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are decoder input IDs? <../glossary.html#decoder-input-ids>`__
decoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default.
decoder_position_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range ``[0, config.decoder.max_position_embeddings - 1]``.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
If set to ``True``, the model will return a :class:`~transformers.file_utils.FlaxSeq2SeqLMOutput` instead
of a plain tuple.
"""
VISION_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r"""
Args:
pixel_values (:obj:`jnp.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using the vision model's feature extractor. For example, using
:class:`~transformers.ViTFeatureExtractor`. See :meth:`transformers.ViTFeatureExtractor.__call__` for
details.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
If set to ``True``, the model will return a :class:`~transformers.file_utils.FlaxBaseModelOutput` instead
of a plain tuple.
"""
VISION_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
Args:
decoder_input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are decoder input IDs? <../glossary.html#decoder-input-ids>`__
If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see
:obj:`past_key_values`).
For sequence to sequence training, :obj:`decoder_input_ids` should be provided. If no
:obj:`decoder_input_ids` is provided, the model will create this tensor by shifting the :obj:`input_ids` to
the right for denoising pre-training.
encoder_outputs (:obj:`tuple(tuple(jnp.ndarray)`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
`optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
cross-attention of the decoder.
decoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default.
decoder_position_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range ``[0, config.decoder.max_position_embeddings - 1]``.
past_key_values (:obj:`Dict[str, jnp.ndarray]`, `optional`, returned by ``init_cache`` or when passing previous ``past_key_values``):
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
auto-regressive decoding. Pre-computed key and value hidden-states are of shape `[batch_size, max_length]`.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
If set to ``True``, the model will return a
:class:`~transformers.file_utils.FlaxCausalLMOutputWithCrossAttentions` instead of a plain tuple.
"""
class FlaxVisionEncoderDecoderModule(nn.Module):
config: VisionEncoderDecoderConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
encoder_config = self.config.encoder
decoder_config = self.config.decoder
# Copied from `modeling_hybrid_clip.py` with modifications.
from ...models.auto.modeling_flax_auto import FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_MAPPING
encoder_module = FLAX_MODEL_MAPPING[encoder_config.__class__].module_class
decoder_module = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING[decoder_config.__class__].module_class
self.encoder = encoder_module(encoder_config, dtype=self.dtype)
self.decoder = decoder_module(decoder_config, dtype=self.dtype)
# encoder outputs might need to be projected to different dimension for decoder
if (
self.encoder.config.hidden_size != self.decoder.config.hidden_size
and self.decoder.config.cross_attention_hidden_size is None
):
self.enc_to_dec_proj = nn.Dense(
self.decoder.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range, self.dtype),
dtype=self.dtype,
)
else:
self.enc_to_dec_proj = None
def _get_encoder_module(self):
return self.encoder
def _get_projection_module(self):
return self.enc_to_dec_proj
def _get_decoder_module(self):
return self.decoder
def __call__(
self,
pixel_values,
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
encoder_outputs = self.encoder(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
encoder_hidden_states = encoder_outputs[0]
# optionally project encoder_hidden_states
if self.enc_to_dec_proj is not None:
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
# The advantage of explicitly setting this is TPU XLA compiler knows as soon as possible what shape this
# variable has and can better optimize. Also passing `None` can lead to some problems when jitting the model.
# In Flax/JAX, we only want to pass `None` for non-tensor function inputs. For all tensor function inputs, we
# should always pass a tensor and not `None`.
batch_size, sequence_length = encoder_hidden_states.shape[:2]
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
if not return_dict:
return decoder_outputs + encoder_outputs
return FlaxSeq2SeqLMOutput(
logits=decoder_outputs.logits,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
@add_start_docstrings(VISION_ENCODER_DECODER_START_DOCSTRING)
class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
r"""
:class:`~transformers.FlaxVisionEncoderDecoderModel` is a generic model class that will be instantiated as a
transformer architecture with the module (flax.nn.Module) of one of the base vision model classes of the library as
encoder module and another one as decoder module when created with the
:meth`~transformers.FlaxAutoModel.from_pretrained` class method for the encoder and
:meth`~transformers.FlaxAutoModelForCausalLM.from_pretrained` class method for the decoder.
"""
config_class = VisionEncoderDecoderConfig
base_model_prefix = "vision_encoder_decoder"
module_class = FlaxVisionEncoderDecoderModule
def __init__(
self,
config: VisionEncoderDecoderConfig,
input_shape: Optional[Tuple] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
**kwargs
):
if input_shape is None:
num_channels = getattr(config.encoder, "num_channels", 3)
input_shape = (
(1, config.encoder.image_size, config.encoder.image_size, num_channels),
(1, 1),
)
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
f"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
f"it has to be equal to the encoder's `hidden_size`."
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
)
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
encoder_input_shape, decoder_input_shape = input_shape
# init input tensors
pixel_values = jnp.zeros(encoder_input_shape, dtype=self.dtype)
decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
batch_size, _, _, _ = pixel_values.shape
decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
if not decoder_batch_size == batch_size:
raise ValueError(
f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder "
f"and {decoder_batch_size} for decoder."
)
decoder_position_ids = jnp.broadcast_to(
jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(
rngs,
pixel_values,
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
)["params"]
def init_cache(self, batch_size, max_length, encoder_outputs):
r"""
Args:
batch_size (:obj:`int`):
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
max_length (:obj:`int`):
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
cache.
encoder_outputs (:obj:`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
``encoder_outputs`` consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`,
`optional`: :obj:`attentions`). :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length,
hidden_size)`, `optional`) is a sequence of hidden-states at the output of the last layer of the
encoder. Used in the cross-attention of the decoder.
"""
# init input variables to retrieve cache
decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
decoder_position_ids = jnp.broadcast_to(
jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
)
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
decoder_module = module._get_decoder_module()
return decoder_module(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
**kwargs,
)
init_variables = self.module.init(
jax.random.PRNGKey(0),
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
init_cache=True,
method=_decoder_forward, # we only need to call the decoder to init the cache
)
return unfreeze(init_variables["cache"])
@add_start_docstrings(VISION_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
def encode(
self,
pixel_values: jnp.ndarray,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
r"""
Returns:
Example::
>>> from transformers import FlaxVisionEncoderDecoderModel
>>> from PIL import Image
>>> import requests
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
>>> # initialize a vit-gpt2 from pretrained ViT and GPT2 models. Note that the cross-attention layers will be randomly initialized
>>> model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained('vit', 'gpt2')
>>> pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values
>>> encoder_outputs = model.encode(pixel_values)
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
# `FlaxViTModel` expects channel first format, but `FlaxViTModule` expects channel last format.
# Currently, we assume this holds for all Flax vision models, and perform a transpose here.
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
def _encoder_forward(module, pixel_values, **kwargs):
encode_module = module._get_encoder_module()
return encode_module(pixel_values, **kwargs)
outputs = self.module.apply(
{"params": params or self.params},
pixel_values=jnp.array(pixel_values, dtype=self.dtype),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
method=_encoder_forward,
)
if return_dict:
outputs = FlaxBaseModelOutput(
last_hidden_state=outputs.last_hidden_state,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
return outputs
@add_start_docstrings(VISION_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def decode(
self,
decoder_input_ids,
encoder_outputs,
decoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
past_key_values: dict = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
r"""
Returns:
Example::
>>> from transformers import FlaxVisionEncoderDecoderModel
>>> import jax.numpy as jnp
>>> from PIL import Image
>>> import requests
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
>>> # initialize a vit-gpt2 from pretrained ViT and GPT2 models. Note that the cross-attention layers will be randomly initialized
>>> model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained('vit', 'gpt2')
>>> pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values
>>> encoder_outputs = model.encode(pixel_values)
>>> decoder_start_token_id = model.config.decoder.bos_token_id
>>> decoder_input_ids = jnp.ones((pixel_values.shape[0], 1), dtype="i4") * decoder_start_token_id
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
>>> logits = outputs.logits
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
encoder_hidden_states = encoder_outputs[0]
batch_size, sequence_length = encoder_hidden_states.shape[:2]
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
batch_size, sequence_length = decoder_input_ids.shape
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones((batch_size, sequence_length))
if decoder_position_ids is None:
if past_key_values is not None:
raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
decoder_position_ids = jnp.broadcast_to(
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
inputs = {"params": params or self.params}
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be
# passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
# it can be changed by FlaxBartAttention module
if past_key_values:
inputs["cache"] = past_key_values
mutable = ["cache"]
else:
mutable = False
def _decoder_forward(
module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs
):
projection_module = module._get_projection_module()
decoder_module = module._get_decoder_module()
# optionally project encoder_hidden_states
if projection_module is not None:
encoder_hidden_states = projection_module(encoder_hidden_states)
return decoder_module(
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
encoder_hidden_states,
**kwargs,
)
outputs = self.module.apply(
inputs,
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
mutable=mutable,
method=_decoder_forward,
)
# add updated cache to model output
if past_key_values is not None and return_dict:
outputs, past = outputs
outputs["past_key_values"] = unfreeze(past["cache"])
return outputs
elif past_key_values is not None and not return_dict:
outputs, past = outputs
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
return outputs
@add_start_docstrings_to_model_forward(VISION_ENCODER_DECODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def __call__(
self,
pixel_values: jnp.ndarray,
decoder_input_ids: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
r"""
Returns:
Examples::
>>> from transformers import FlaxVisionEncoderDecoderModel, ViTFeatureExtractor, GPT2Tokenizer
>>> from PIL import Image
>>> import requests
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
>>> # load output tokenizer
>>> tokenizer_output = GPT2Tokenizer.from_pretrained('gpt2')
>>> # initialize a vit-gpt2 from pretrained ViT and GPT2 models. Note that the cross-attention layers will be randomly initialized
>>> model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained('vit', 'gpt2')
>>> pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values
>>> # use GPT2's eos_token as the pad as well as eos token
>>> model.config.eos_token_id = model.config.decoder.eos_token_id
>>> model.config.pad_token_id = model.config.eos_token_id
>>> # generation
>>> sequences = model.generate(pixel_values, num_beams=4, max_length=12).sequences
>>> captions = tokenizer_output.batch_decode(sequences, skip_special_tokens=True)
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
# prepare encoder inputs
# `FlaxViTModel` expects channel first format, but `FlaxViTModule` expects channel last format.
# Currently, we assume this holds for all Flax vision models, and perform a transpose here.
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
# prepare decoder inputs
if decoder_input_ids is None:
raise ValueError("`decoder_input_ids` can't be `None`.")
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
if decoder_position_ids is None:
batch_size, sequence_length = decoder_input_ids.shape
decoder_position_ids = jnp.broadcast_to(
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
)
# Handle any PRNG if needed
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
return self.module.apply(
{"params": params or self.params},
pixel_values=jnp.array(pixel_values, dtype=self.dtype),
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
)
def prepare_inputs_for_generation(
self,
decoder_input_ids,
max_length,
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
encoder_outputs=None,
**kwargs
):
# initializing the cache
batch_size, seq_length = decoder_input_ids.shape
past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
# But since the decoder uses a causal mask, those positions are masked anyways.
# Thus we can create a single static attention_mask here, which is more efficient for compilation
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if decoder_attention_mask is not None:
decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
else:
decoder_position_ids = jnp.broadcast_to(
jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
)
return {
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"decoder_attention_mask": extended_attention_mask,
"decoder_position_ids": decoder_position_ids,
}
def update_inputs_for_generation(self, model_outputs, model_kwargs):
model_kwargs["past_key_values"] = model_outputs.past_key_values
model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
return model_kwargs
@classmethod
def from_encoder_decoder_pretrained(
cls,
encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
*model_args,
**kwargs
) -> FlaxPreTrainedModel:
r"""
Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
checkpoints.
Params:
encoder_pretrained_model_name_or_path (:obj: `Union[str, os.PathLike]`, `optional`):
Information necessary to initiate the encoder. Can be either:
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. An
example is ``google/vit-base-patch16-224-in21k``.
- A path to a `directory` containing model weights saved using
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
decoder_pretrained_model_name_or_path (:obj: `Union[str, os.PathLike]`, `optional`, defaults to `None`):
Information necessary to initiate the decoder. Can be either:
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
a user or organization name, like ``dbmdz/bert-base-german-cased``.
- A path to a `directory` containing model weights saved using
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
model_args (remaining positional arguments, `optional`):
All remaning positional arguments will be passed to the underlying model's ``__init__`` method.
kwargs (remaining dictionary of keyword arguments, `optional`):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
:obj:`output_attentions=True`).
- To update the encoder configuration, use the prefix `encoder_` for each configuration parameter.
- To update the decoder configuration, use the prefix `decoder_` for each configuration parameter.
- To update the parent model configuration, do not use a prefix for each configuration parameter.
Behaves differently depending on whether a :obj:`config` is provided or automatically loaded.
Example::
>>> from transformers import FlaxVisionEncoderDecoderModel
>>> # initialize a vit-gpt2 from a pretrained ViT and a pretrained GPT2 model. Note that the cross-attention layers will be randomly initialized
>>> model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained('google/vit-base-patch16-224-in21k', 'gpt2')
>>> # saving model after fine-tuning
>>> model.save_pretrained("./vit-gpt2")
>>> # load fine-tuned model
>>> model = FlaxVisionEncoderDecoderModel.from_pretrained("./vit-gpt2")
"""
kwargs_encoder = {
argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
}
kwargs_decoder = {
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}
# remove encoder, decoder kwargs from kwargs
for key in kwargs_encoder.keys():
del kwargs["encoder_" + key]
for key in kwargs_decoder.keys():
del kwargs["decoder_" + key]
# Load and initialize the encoder and decoder
# The distinction between encoder and decoder at the model level is made
# by the value of the flag `is_decoder` that we need to set correctly.
encoder = kwargs_encoder.pop("model", None)
if encoder is None:
if encoder_pretrained_model_name_or_path is None:
raise ValueError(
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
"to be defined"
)
from ..auto.modeling_flax_auto import FlaxAutoModel
if "config" not in kwargs_encoder:
from ..auto.configuration_auto import AutoConfig
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
logger.info(
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model from a decoder "
"model. Cross-attention and casual mask are disabled."
)
encoder_config.is_decoder = False
encoder_config.add_cross_attention = False
kwargs_encoder["config"] = encoder_config
encoder = FlaxAutoModel.from_pretrained(
encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
)
decoder = kwargs_decoder.pop("model", None)
if decoder is None:
if decoder_pretrained_model_name_or_path is None:
raise ValueError(
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
"to be defined."
)
from ..auto.modeling_flax_auto import FlaxAutoModelForCausalLM
if "config" not in kwargs_decoder:
from ..auto.configuration_auto import AutoConfig
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention "
f"layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if "
f"{decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
kwargs_decoder["config"] = decoder_config
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
logger.warning(
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order "
f"to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the "
"attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to "
"`.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to "
"`.from_encoder_decoder_pretrained(...)`"
)
decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
# instantiate config with corresponding kwargs
dtype = kwargs.pop("dtype", jnp.float32)
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
# init model
model = cls(config, dtype=dtype)
model.params["encoder"] = encoder.params
model.params["decoder"] = decoder.params
return model
......@@ -70,8 +70,8 @@ VISION_ENCODER_DECODER_START_DOCSTRING = r"""
<https://arxiv.org/abs/2109.10282>`__ it is shown how leveraging large pretrained vision models for optical
character recognition (OCR) yields a significant performance improvement.
After such an Vision-Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other
models (see the examples for more information).
After such a Vision-Encoder-Text-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any
other models (see the examples for more information).
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
......@@ -94,13 +94,6 @@ VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
Pixel values. Pixel values can be obtained using a feature extractor (e.g. if you use ViT as the encoder,
you should use :class:`~transformers.ViTFeatureExtractor`). See
:meth:`transformers.ViTFeatureExtractor.__call__` for details.
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Indices of decoder input sequence tokens in the vocabulary.
......@@ -130,10 +123,6 @@ VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`decoder_input_ids`
......@@ -165,8 +154,8 @@ VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
class VisionEncoderDecoderModel(PreTrainedModel):
r"""
:class:`~transformers.VisionEncoderDecoderModel` is a generic model class that will be instantiated as a
transformer architecture with one of the base model classes of the library as encoder and another one as decoder
when created with the :meth`~transformers.AutoModel.from_pretrained` class method for the encoder and
transformer architecture with one of the base vision model classes of the library as encoder and another one as
decoder when created with the :meth`~transformers.AutoModel.from_pretrained` class method for the encoder and
:meth`~transformers.AutoModelForCausalLM.from_pretrained` class method for the decoder.
"""
config_class = VisionEncoderDecoderConfig
......@@ -186,6 +175,15 @@ class VisionEncoderDecoderModel(PreTrainedModel):
if not isinstance(config, self.config_class):
raise ValueError(f"Config: {config} has to be of type {self.config_class}")
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
f"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
f"it has to be equal to the encoder's `hidden_size`."
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
)
# initialize with config
# make sure input & output embeddings is not tied
config.tie_word_embeddings = False
......
......@@ -174,6 +174,9 @@ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING = None
FLAX_MODEL_MAPPING = None
......@@ -276,6 +279,15 @@ class FlaxAutoModelForTokenClassification:
requires_backends(cls, ["flax"])
class FlaxAutoModelForVision2Seq:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxBartForConditionalGeneration:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
......@@ -949,6 +961,15 @@ class FlaxT5PreTrainedModel:
requires_backends(cls, ["flax"])
class FlaxVisionEncoderDecoderModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxViTForImageClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
......
......@@ -355,6 +355,9 @@ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = None
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
MODEL_FOR_VISION_2_SEQ_MAPPING = None
MODEL_MAPPING = None
......@@ -514,6 +517,15 @@ class AutoModelForTokenClassification:
requires_backends(cls, ["torch"])
class AutoModelForVision2Seq:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoModelWithLMHead:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
......
......@@ -29,7 +29,6 @@ from .test_modeling_flax_gpt2 import FlaxGPT2ModelTester
if is_flax_available():
from transformers import (
AutoConfig,
AutoTokenizer,
EncoderDecoderConfig,
FlaxBertModel,
......@@ -350,12 +349,6 @@ class FlaxEncoderDecoderModelTest(unittest.TestCase):
def get_from_encoderdecoder_pretrained_model(self):
return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
def get_decoder_config(self):
config = AutoConfig.from_pretrained("gpt2")
config.is_decoder = True
config.add_cross_attention = True
return config
def _check_configuration_tie(self, model):
module = model.module.bind(model.params)
......
# coding=utf-8
# Copyright 2021 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import numpy as np
from transformers import is_flax_available, is_torch_available, is_vision_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, require_vision, slow, torch_device
from .test_modeling_flax_common import floats_tensor, ids_tensor
from .test_modeling_flax_gpt2 import FlaxGPT2ModelTester
from .test_modeling_flax_vit import FlaxViTModelTester
if is_flax_available():
from transformers import (
AutoTokenizer,
FlaxGPT2LMHeadModel,
FlaxVisionEncoderDecoderModel,
FlaxViTModel,
VisionEncoderDecoderConfig,
)
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
if is_torch_available():
import torch
from transformers import VisionEncoderDecoderModel
if is_vision_available():
from PIL import Image
from transformers import ViTFeatureExtractor
@require_flax
class FlaxEncoderDecoderMixin:
def get_encoder_decoder_model(self, config, decoder_config):
raise NotImplementedError
def prepare_config_and_inputs(self):
raise NotImplementedError
def get_pretrained_model(self):
raise NotImplementedError
def check_encoder_decoder_model_from_pretrained_configs(
self,
config,
pixel_values,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
self.assertTrue(encoder_decoder_config.decoder.is_decoder)
enc_dec_model = FlaxVisionEncoderDecoderModel(encoder_decoder_config)
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
outputs_encoder_decoder = enc_dec_model(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
)
self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[0], pixel_values.shape[0])
self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[-1], config.hidden_size)
def check_encoder_decoder_model_from_pretrained(
self,
config,
pixel_values,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
return_dict,
**kwargs
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model, "return_dict": return_dict}
enc_dec_model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
outputs_encoder_decoder = enc_dec_model(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
return_dict=True,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
)
self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[0], pixel_values.shape[0])
self.assertEqual(outputs_encoder_decoder["encoder_last_hidden_state"].shape[-1], config.hidden_size)
def check_save_and_load(
self,
config,
pixel_values,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
enc_dec_model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
outputs = enc_dec_model(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
out_2 = np.array(outputs[0])
out_2[np.isnan(out_2)] = 0
with tempfile.TemporaryDirectory() as tmpdirname:
enc_dec_model.save_pretrained(tmpdirname)
FlaxVisionEncoderDecoderModel.from_pretrained(tmpdirname)
after_outputs = enc_dec_model(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
out_1 = np.array(after_outputs[0])
out_1[np.isnan(out_1)] = 0
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
def check_encoder_decoder_model_output_attentions(
self,
config,
pixel_values,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
enc_dec_model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
outputs_encoder_decoder = enc_dec_model(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
)
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
self.assertEqual(encoder_attentions[0].shape[-3:-2], (config.num_attention_heads,))
decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
num_decoder_layers = (
decoder_config.num_decoder_layers
if hasattr(decoder_config, "num_decoder_layers")
else decoder_config.num_hidden_layers
)
self.assertEqual(len(decoder_attentions), num_decoder_layers)
self.assertEqual(
decoder_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
)
cross_attentions = outputs_encoder_decoder["cross_attentions"]
self.assertEqual(len(cross_attentions), num_decoder_layers)
cross_attention_input_seq_len = decoder_input_ids.shape[-1] * (
1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0)
)
self.assertEqual(
cross_attentions[0].shape[-3:-1],
(decoder_config.num_attention_heads, cross_attention_input_seq_len),
)
def check_encoder_decoder_model_generate(self, pixel_values, config, decoder_config, **kwargs):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
enc_dec_model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
pad_token_id = enc_dec_model.config.decoder.pad_token_id
eos_token_id = enc_dec_model.config.decoder.eos_token_id
decoder_start_token_id = enc_dec_model.config.decoder.decoder_start_token_id
# Copied from generation_utils (GPT2 doesn't have `pad_token_id`)
if pad_token_id is None and eos_token_id is not None:
pad_token_id = eos_token_id
if decoder_start_token_id is None:
decoder_start_token_id = enc_dec_model.config.decoder.bos_token_id
# Bert does not have a bos token id, so use pad_token_id instead
# Copied from `test_modeling_encoder_decoder.py`
if decoder_start_token_id is None:
decoder_start_token_id = pad_token_id
generated_output = enc_dec_model.generate(
pixel_values,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
)
generated_sequences = generated_output.sequences
self.assertEqual(generated_sequences.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
pt_model.to(torch_device)
pt_model.eval()
# prepare inputs
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = FlaxVisionEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = VisionEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
pt_model = VisionEncoderDecoderModel(encoder_decoder_config)
fx_model = FlaxVisionEncoderDecoderModel(encoder_decoder_config)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
def check_equivalence_flax_to_pt(self, config, decoder_config, inputs_dict):
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
pt_model = VisionEncoderDecoderModel(encoder_decoder_config)
fx_model = FlaxVisionEncoderDecoderModel(encoder_decoder_config)
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
def test_encoder_decoder_model_from_pretrained_configs(self):
config_inputs_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained_configs(**config_inputs_dict)
def test_encoder_decoder_model_from_pretrained(self):
config_inputs_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained(**config_inputs_dict, return_dict=False)
def test_encoder_decoder_model_from_pretrained_return_dict(self):
config_inputs_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained(**config_inputs_dict, return_dict=True)
def test_save_and_load_from_pretrained(self):
config_inputs_dict = self.prepare_config_and_inputs()
self.check_save_and_load(**config_inputs_dict)
def test_encoder_decoder_model_output_attentions(self):
config_inputs_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_output_attentions(**config_inputs_dict)
def test_encoder_decoder_model_generate(self):
config_inputs_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_generate(**config_inputs_dict)
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
@is_pt_flax_cross_test
def test_pt_flax_equivalence(self):
config_inputs_dict = self.prepare_config_and_inputs()
config = config_inputs_dict.pop("config")
decoder_config = config_inputs_dict.pop("decoder_config")
inputs_dict = config_inputs_dict
# `encoder_hidden_states` is not used in model call/forward
del inputs_dict["encoder_hidden_states"]
# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
inputs_dict["decoder_attention_mask"] = np.concatenate(
[np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1
)
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
decoder_config.use_cache = False
self.assertTrue(decoder_config.cross_attention_hidden_size is None)
# check without `enc_to_dec_proj` projection
self.assertTrue(config.hidden_size == decoder_config.hidden_size)
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
# check `enc_to_dec_proj` work as expected
decoder_config.hidden_size = decoder_config.hidden_size * 2
self.assertTrue(config.hidden_size != decoder_config.hidden_size)
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
@slow
def test_real_model_save_load_from_pretrained(self):
model_2 = self.get_pretrained_model()
pixel_values = floats_tensor(
[
13,
model_2.config.encoder.num_channels,
model_2.config.encoder.image_size,
model_2.config.encoder.image_size,
]
)
decoder_input_ids = ids_tensor([13, 1], model_2.config.decoder.vocab_size)
outputs = model_2(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
)
out_2 = np.array(outputs[0])
out_2[np.isnan(out_2)] = 0
with tempfile.TemporaryDirectory() as tmp_dirname:
model_2.save_pretrained(tmp_dirname)
model_1 = FlaxVisionEncoderDecoderModel.from_pretrained(tmp_dirname)
after_outputs = model_1(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
)
out_1 = np.array(after_outputs[0])
out_1[np.isnan(out_1)] = 0
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
@require_flax
class FlaxViT2GPT2EncoderDecoderModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = FlaxViTModel(config)
decoder_model = FlaxGPT2LMHeadModel(decoder_config)
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester_encoder = FlaxViTModelTester(self, batch_size=13)
model_tester_decoder = FlaxGPT2ModelTester(self, batch_size=13)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
(config, pixel_values) = encoder_config_and_inputs
(
decoder_config,
decoder_input_ids,
decoder_attention_mask,
encoder_hidden_states,
encoder_attention_mask,
) = decoder_config_and_inputs
# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
return {
"config": config,
"pixel_values": pixel_values,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"encoder_hidden_states": encoder_hidden_states, # This is not used in the tests.
}
def get_pretrained_model(self):
return FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
"google/vit-base-patch16-224-in21k", "gpt2"
)
@require_flax
class FlaxVisionEncoderDecoderModelTest(unittest.TestCase):
def get_from_encoderdecoder_pretrained_model(self):
return FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
"google/vit-base-patch16-224-in21k", "gpt2"
)
def _check_configuration_tie(self, model):
module = model.module.bind(model.params)
assert id(module.decoder.config) == id(model.config.decoder)
assert id(module.encoder.config) == id(model.config.encoder)
@slow
def test_configuration_tie(self):
model = self.get_from_encoderdecoder_pretrained_model()
self._check_configuration_tie(model)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_vision
@require_flax
class FlaxViT2GPT2ModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_coco_en(self):
loc = "ydshieh/vit-gpt2-coco-en"
feature_extractor = ViTFeatureExtractor.from_pretrained(loc)
tokenizer = AutoTokenizer.from_pretrained(loc)
model = FlaxVisionEncoderDecoderModel.from_pretrained(loc)
img = prepare_img()
pixel_values = feature_extractor(images=img, return_tensors="np").pixel_values
decoder_input_ids = np.array([[model.config.decoder_start_token_id]])
logits = model(pixel_values, decoder_input_ids)[0]
logits = np.array(logits)
# verify the logits
expected_shape = (1, 1, model.config.decoder.vocab_size)
self.assertEqual(logits.shape, expected_shape)
EXPECTED_LOGIT_SLICE = np.array(
[
-38.705837,
-30.639936,
-31.41905,
-39.01204,
-38.38698,
-34.887215,
-33.29087,
-35.684475,
-38.50852,
-36.124676,
]
)
max_diff = np.amax(np.abs(logits[0, 0, :10] - EXPECTED_LOGIT_SLICE))
self.assertLessEqual(max_diff, 1e-4)
def generate_step(pixel_values):
outputs = model.generate(pixel_values, max_length=16, num_beams=4)
output_ids = outputs.sequences
preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
preds = [pred.strip() for pred in preds]
return preds, outputs.scores
preds, scores = generate_step(pixel_values)
EXPECTED_SCORES = np.array([-0.59563464])
scores = np.array(scores)
max_diff = np.amax(np.abs(scores - EXPECTED_SCORES))
self.assertLessEqual(max_diff, 1e-4)
# should produce
# ["a cat laying on top of a couch next to another cat"]
self.assertEqual(preds, ["a cat laying on top of a couch next to another cat"])
......@@ -34,6 +34,7 @@ if is_torch_available():
import torch
from transformers import (
AutoTokenizer,
BertLMHeadModel,
DeiTModel,
TrOCRForCausalLM,
......@@ -48,7 +49,7 @@ if is_torch_available():
if is_vision_available():
from PIL import Image
from transformers import TrOCRProcessor
from transformers import TrOCRProcessor, ViTFeatureExtractor
@require_torch
......@@ -656,3 +657,69 @@ class TrOCRModelIntegrationTest(unittest.TestCase):
).to(torch_device)
self.assertTrue(torch.allclose(logits[0, 0, :10], expected_slice, atol=1e-4))
@require_vision
@require_torch
class ViT2GPT2ModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_coco_en(self):
loc = "ydshieh/vit-gpt2-coco-en"
feature_extractor = ViTFeatureExtractor.from_pretrained(loc)
tokenizer = AutoTokenizer.from_pretrained(loc)
model = VisionEncoderDecoderModel.from_pretrained(loc)
model.to(torch_device)
model.eval()
# We will verify our results on an image of cute cats
img = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
pixel_values = feature_extractor(images=img, return_tensors="pt").pixel_values.to(torch_device)
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]]).to(torch_device)
with torch.no_grad():
logits = model(pixel_values, decoder_input_ids)[0].detach().cpu().numpy()
# verify the logits
expected_shape = (1, 1, model.config.decoder.vocab_size)
self.assertEqual(logits.shape, expected_shape)
EXPECTED_LOGIT_SLICE = np.array(
[
-38.705807,
-30.639929,
-31.41903,
-39.012012,
-38.38696,
-34.887207,
-33.290855,
-35.68447,
-38.508484,
-36.124645,
]
)
max_diff = np.amax(np.abs(logits[0, 0, :10] - EXPECTED_LOGIT_SLICE))
self.assertLessEqual(max_diff, 1e-4)
def generate_step(pixel_values):
outputs = model.generate(
pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True, output_scores=True
)
output_ids = outputs.sequences
preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
preds = [pred.strip() for pred in preds]
return preds, outputs.sequences_scores.detach().cpu().numpy()
preds, scores = generate_step(pixel_values)
EXPECTED_SCORES = np.array([-0.59562886])
max_diff = np.amax(np.abs(scores - EXPECTED_SCORES))
self.assertLessEqual(max_diff, 1e-4)
# should produce
# ["a cat laying on top of a couch next to another cat"]
self.assertEqual(preds, ["a cat laying on top of a couch next to another cat"])
......@@ -187,6 +187,7 @@ def get_model_modules():
"modeling_flax_encoder_decoder",
"modeling_flax_utils",
"modeling_speech_encoder_decoder",
"modeling_flax_vision_encoder_decoder",
"modeling_transfo_xl_utilities",
"modeling_tf_auto",
"modeling_tf_encoder_decoder",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment