Unverified Commit 95b3ec3b authored by Yih-Dar, committed by GitHub

Add FlaxVisionEncoderDecoderModel (#13359)



* Start the work on FlaxVisionEncoderDecoderModel

* Add FlaxVisionEncoderDecoderModel

* Add VisionEncoderDecoderConfig

* Make FlaxVisionEncoderDecoderModel visible to transformers

* Add test

* Fix wrong getattr usage

* Fix tests

* Add FlaxAutoModelForVision2Seq

* Expose FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING

* clean-up

* add integration test

* update expected logits

* update expected scores

* Add ViT2GPT2ModelIntegrationTest + some cleaning

* Add projection layer + PT/Flax equivalence tests

* Fix import

* minor changes

* make test slow again

* Apply suggestions

* Add modeling_flax_vision_encoder_decoder to _ignore_modules in get_model_modules()

* fix copies

* Apply suggestions from code review
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* split long strings in multiple lines

* decoder_input_ids can't be None

* Add back test_configuration_tie

* Remove attention_mask parameter

* fix test - encoder_last_hidden_state should be encoder_outputs.last_hidden_state instead of the projected vector

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Remove more encoder_attention_mask

* remove encoder_attention_mask when calling self.decode (in FlaxVisionEncoderDecoderModule)

* Fix style + pass 1s instead of None as encoder_attention_mask

* fix init_weights

* pass None for encoder_attention_mask

* pass 1s instead of None as encoder_attention_mask

* Fix doc style
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent a5030122
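
The headline addition is the Flax counterpart of the vision encoder-decoder architecture. A minimal usage sketch (assuming the ViT and GPT-2 checkpoints named below also ship Flax weights; otherwise the relevant `from_pt` flags would be needed):

```python
import jax.numpy as jnp
import requests
from PIL import Image

from transformers import FlaxVisionEncoderDecoderModel, GPT2Tokenizer, ViTFeatureExtractor

# Combine a pretrained ViT encoder with a pretrained GPT-2 decoder.
# The decoder's cross-attention weights are newly (randomly) initialized.
model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "gpt2"
)

feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values

# decoder_input_ids can't be None for this model (see the commit notes above),
# so decoding is started explicitly from the GPT-2 BOS token.
decoder_input_ids = jnp.array([[tokenizer.bos_token_id]])

outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
logits = outputs.logits  # (batch, decoder sequence length, decoder vocab size)
```

Since the cross-attention layers start from random weights, such a combined model needs fine-tuning (e.g. on image captioning) before its generations are meaningful.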
......@@ -499,7 +499,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| UniSpeechSat | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Vision Encoder decoder | ❌ | ❌ | ✅ | ❌ | ❌ |
| Vision Encoder decoder | ❌ | ❌ | ✅ | ❌ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
......
......@@ -160,6 +160,13 @@ AutoModelForImageClassification
    :members:


AutoModelForVision2Seq
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.AutoModelForVision2Seq
    :members:

AutoModelForAudioClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......@@ -333,3 +340,10 @@ FlaxAutoModelForImageClassification
.. autoclass:: transformers.FlaxAutoModelForImageClassification
    :members:


FlaxAutoModelForVision2Seq
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxAutoModelForVision2Seq
    :members:
......@@ -39,3 +39,10 @@ VisionEncoderDecoderModel
.. autoclass:: transformers.VisionEncoderDecoderModel
    :members: forward, from_encoder_decoder_pretrained


FlaxVisionEncoderDecoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxVisionEncoderDecoderModel
    :members: __call__, from_encoder_decoder_pretrained
......@@ -603,6 +603,7 @@ if is_torch_available():
"MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"MODEL_FOR_VISION_2_SEQ_MAPPING",
"MODEL_MAPPING",
"MODEL_WITH_LM_HEAD_MAPPING",
"AutoModel",
......@@ -622,6 +623,7 @@ if is_torch_available():
"AutoModelForSpeechSeq2Seq",
"AutoModelForTableQuestionAnswering",
"AutoModelForTokenClassification",
"AutoModelForVision2Seq",
"AutoModelWithLMHead",
]
)
......@@ -1825,6 +1827,7 @@ if is_flax_available():
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING",
"FLAX_MODEL_MAPPING",
"FlaxAutoModel",
"FlaxAutoModelForCausalLM",
......@@ -1837,6 +1840,7 @@ if is_flax_available():
"FlaxAutoModelForSeq2SeqLM",
"FlaxAutoModelForSequenceClassification",
"FlaxAutoModelForTokenClassification",
"FlaxAutoModelForVision2Seq",
]
)
......@@ -1957,6 +1961,7 @@ if is_flax_available():
]
)
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
_import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel")
_import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"])
_import_structure["models.wav2vec2"].extend(
["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"]
......@@ -2457,6 +2462,7 @@ if TYPE_CHECKING:
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
AutoModel,
......@@ -2476,6 +2482,7 @@ if TYPE_CHECKING:
AutoModelForSpeechSeq2Seq,
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
AutoModelForVision2Seq,
AutoModelWithLMHead,
)
from .models.bart import (
......@@ -3482,6 +3489,7 @@ if TYPE_CHECKING:
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
FLAX_MODEL_MAPPING,
FlaxAutoModel,
FlaxAutoModelForCausalLM,
......@@ -3494,6 +3502,7 @@ if TYPE_CHECKING:
FlaxAutoModelForSeq2SeqLM,
FlaxAutoModelForSequenceClassification,
FlaxAutoModelForTokenClassification,
FlaxAutoModelForVision2Seq,
)
from .models.bart import (
FlaxBartForConditionalGeneration,
......@@ -3579,6 +3588,7 @@ if TYPE_CHECKING:
FlaxRobertaPreTrainedModel,
)
from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel
from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
from .models.wav2vec2 import (
FlaxWav2Vec2ForCTC,
......
......@@ -46,6 +46,7 @@ if is_torch_available():
"MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"MODEL_FOR_VISION_2_SEQ_MAPPING",
"MODEL_MAPPING",
"MODEL_WITH_LM_HEAD_MAPPING",
"AutoModel",
......@@ -65,6 +66,7 @@ if is_torch_available():
"AutoModelForSpeechSeq2Seq",
"AutoModelForTableQuestionAnswering",
"AutoModelForTokenClassification",
"AutoModelForVision2Seq",
"AutoModelWithLMHead",
]
......@@ -105,6 +107,7 @@ if is_flax_available():
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING",
"FLAX_MODEL_MAPPING",
"FlaxAutoModel",
"FlaxAutoModelForCausalLM",
......@@ -117,6 +120,7 @@ if is_flax_available():
"FlaxAutoModelForSeq2SeqLM",
"FlaxAutoModelForSequenceClassification",
"FlaxAutoModelForTokenClassification",
"FlaxAutoModelForVision2Seq",
]
......@@ -144,6 +148,7 @@ if TYPE_CHECKING:
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
AutoModel,
......@@ -163,6 +168,7 @@ if TYPE_CHECKING:
AutoModelForSpeechSeq2Seq,
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
AutoModelForVision2Seq,
AutoModelWithLMHead,
)
......@@ -203,6 +209,7 @@ if TYPE_CHECKING:
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
FLAX_MODEL_MAPPING,
FlaxAutoModel,
FlaxAutoModelForCausalLM,
......@@ -215,6 +222,7 @@ if TYPE_CHECKING:
FlaxAutoModelForSeq2SeqLM,
FlaxAutoModelForSequenceClassification,
FlaxAutoModelForTokenClassification,
FlaxAutoModelForVision2Seq,
)
else:
......
......@@ -239,6 +239,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("electra", "ELECTRA"),
("encoder-decoder", "Encoder decoder"),
("speech-encoder-decoder", "Speech Encoder decoder"),
("vision-encoder-decoder", "Vision Encoder decoder"),
("funnel", "Funnel Transformer"),
("lxmert", "LXMERT"),
("deberta-v2", "DeBERTa-v2"),
......
......@@ -244,6 +244,12 @@ MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    ]
)

MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
    ]
)

MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked LM mapping
......@@ -511,6 +517,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
......@@ -655,6 +662,13 @@ class AutoModelForObjectDetection(_BaseAutoModelClass):
AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")
class AutoModelForVision2Seq(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING


AutoModelForVision2Seq = auto_class_update(AutoModelForVision2Seq, head_doc="vision-to-text modeling")


class AutoModelForAudioClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
......
......@@ -100,6 +100,12 @@ FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    ]
)

FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("vision-encoder-decoder", "FlaxVisionEncoderDecoderModel"),
    ]
)

FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
......@@ -176,6 +182,7 @@ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
......@@ -279,3 +286,10 @@ class FlaxAutoModelForImageClassification(_BaseAutoModelClass):
FlaxAutoModelForImageClassification = auto_class_update(
    FlaxAutoModelForImageClassification, head_doc="image classification"
)


class FlaxAutoModelForVision2Seq(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING


FlaxAutoModelForVision2Seq = auto_class_update(FlaxAutoModelForVision2Seq, head_doc="vision-to-text modeling")
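
A short sketch of the new Flax auto class. The checkpoint name is taken from the PyTorch integration test further down and is assumed here to also provide Flax weights (otherwise `from_pt=True` would be required):

```python
from transformers import FlaxAutoModelForVision2Seq

# Any repo whose config type is "vision-encoder-decoder" resolves to
# FlaxVisionEncoderDecoderModel through the new mapping.
model = FlaxAutoModelForVision2Seq.from_pretrained("ydshieh/vit-gpt2-coco-en")
```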
......@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_torch_available
from ...file_utils import _LazyModule, is_flax_available, is_torch_available
_import_structure = {
......@@ -28,12 +28,18 @@ _import_structure = {
if is_torch_available():
    _import_structure["modeling_vision_encoder_decoder"] = ["VisionEncoderDecoderModel"]

if is_flax_available():
    _import_structure["modeling_flax_vision_encoder_decoder"] = ["FlaxVisionEncoderDecoderModel"]


if TYPE_CHECKING:
    from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig

    if is_torch_available():
        from .modeling_vision_encoder_decoder import VisionEncoderDecoderModel

    if is_flax_available():
        from .modeling_flax_vision_encoder_decoder import FlaxVisionEncoderDecoderModel

else:
    import sys
......
......@@ -27,8 +27,8 @@ logger = logging.get_logger(__name__)
class VisionEncoderDecoderConfig(PretrainedConfig):
    r"""
    :class:`~transformers.VisionEncoderDecoderConfig` is the configuration class to store the configuration of a
    :class:`~transformers.VisionEncoderDecoderModel`. It is used to instantiate an Encoder Decoder model according to
    the specified arguments, defining the encoder and decoder configs.
    :class:`~transformers.VisionEncoderDecoderModel`. It is used to instantiate a Vision-Encoder-Text-Decoder model
    according to the specified arguments, defining the encoder and decoder configs.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
......
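
As a sketch of how the configuration composes the two sub-configs (assuming a ViT encoder and a GPT-2 decoder):

```python
from transformers import GPT2Config, ViTConfig, VisionEncoderDecoderConfig, VisionEncoderDecoderModel

config_encoder = ViTConfig()
config_decoder = GPT2Config()

# Combine the two sub-configs; the decoder config is marked as a decoder with cross-attention.
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)

# A randomly initialized Vision-Encoder-Text-Decoder model following the combined configuration.
model = VisionEncoderDecoderModel(config=config)
```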
......@@ -70,8 +70,8 @@ VISION_ENCODER_DECODER_START_DOCSTRING = r"""
<https://arxiv.org/abs/2109.10282>`__ it is shown how leveraging large pretrained vision models for optical
character recognition (OCR) yields a significant performance improvement.
    After such an Vision-Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other
    models (see the examples for more information).
    After such a Vision-Encoder-Text-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any
    other model (see the examples for more information).

    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
......@@ -94,13 +94,6 @@ VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
Pixel values. Pixel values can be obtained using a feature extractor (e.g. if you use ViT as the encoder,
you should use :class:`~transformers.ViTFeatureExtractor`). See
:meth:`transformers.ViTFeatureExtractor.__call__` for details.
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Indices of decoder input sequence tokens in the vocabulary.
......@@ -130,10 +123,6 @@ VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`decoder_input_ids`
......@@ -165,8 +154,8 @@ VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
class VisionEncoderDecoderModel(PreTrainedModel):
    r"""
    :class:`~transformers.VisionEncoderDecoderModel` is a generic model class that will be instantiated as a
    transformer architecture with one of the base model classes of the library as encoder and another one as decoder
    when created with the :meth:`~transformers.AutoModel.from_pretrained` class method for the encoder and
    transformer architecture with one of the base vision model classes of the library as encoder and another one as
    decoder when created with the :meth:`~transformers.AutoModel.from_pretrained` class method for the encoder and
    :meth:`~transformers.AutoModelForCausalLM.from_pretrained` class method for the decoder.
    """
    config_class = VisionEncoderDecoderConfig
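
A sketch of the instantiation pattern described in the docstring above, with assumed ViT/GPT-2 checkpoints:

```python
from transformers import VisionEncoderDecoderModel

# Encoder weights come from a pretrained ViT, decoder weights from a pretrained GPT-2;
# the decoder's cross-attention weights are newly initialized.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "gpt2"
)

# Once combined, the model saves and loads like any other checkpoint.
model.save_pretrained("./vit-gpt2")
model = VisionEncoderDecoderModel.from_pretrained("./vit-gpt2")
```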
......@@ -186,6 +175,15 @@ class VisionEncoderDecoderModel(PreTrainedModel):
        if not isinstance(config, self.config_class):
            raise ValueError(f"Config: {config} has to be of type {self.config_class}")

        if config.decoder.cross_attention_hidden_size is not None:
            if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
                raise ValueError(
                    "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
                    "it has to be equal to the encoder's `hidden_size`. "
                    f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
                    f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
                )

        # initialize with config
        # make sure input & output embeddings are not tied
        config.tie_word_embeddings = False
......
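
A small sketch of the constraint enforced by the new check (assuming a ViT encoder, whose default `hidden_size` is 768, and a GPT-2 decoder):

```python
from transformers import GPT2Config, ViTConfig, VisionEncoderDecoderConfig, VisionEncoderDecoderModel

encoder_config = ViTConfig()   # hidden_size defaults to 768
decoder_config = GPT2Config()

# If set at all, the decoder's cross_attention_hidden_size must equal the encoder's
# hidden_size; a mismatched value would make VisionEncoderDecoderModel.__init__ raise
# a ValueError under the check added above.
decoder_config.cross_attention_hidden_size = encoder_config.hidden_size

config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
model = VisionEncoderDecoderModel(config=config)
```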
......@@ -174,6 +174,9 @@ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING = None
FLAX_MODEL_MAPPING = None
......@@ -276,6 +279,15 @@ class FlaxAutoModelForTokenClassification:
requires_backends(cls, ["flax"])
class FlaxAutoModelForVision2Seq:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxBartForConditionalGeneration:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
......@@ -949,6 +961,15 @@ class FlaxT5PreTrainedModel:
requires_backends(cls, ["flax"])
class FlaxVisionEncoderDecoderModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxViTForImageClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
......
......@@ -355,6 +355,9 @@ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = None
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
MODEL_FOR_VISION_2_SEQ_MAPPING = None
MODEL_MAPPING = None
......@@ -514,6 +517,15 @@ class AutoModelForTokenClassification:
requires_backends(cls, ["torch"])
class AutoModelForVision2Seq:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoModelWithLMHead:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
......
......@@ -29,7 +29,6 @@ from .test_modeling_flax_gpt2 import FlaxGPT2ModelTester
if is_flax_available():
    from transformers import (
        AutoConfig,
        AutoTokenizer,
        EncoderDecoderConfig,
        FlaxBertModel,
......@@ -350,12 +349,6 @@ class FlaxEncoderDecoderModelTest(unittest.TestCase):
    def get_from_encoderdecoder_pretrained_model(self):
        return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")

    def get_decoder_config(self):
        config = AutoConfig.from_pretrained("gpt2")
        config.is_decoder = True
        config.add_cross_attention = True
        return config

    def _check_configuration_tie(self, model):
        module = model.module.bind(model.params)
......
This diff is collapsed.
......@@ -34,6 +34,7 @@ if is_torch_available():
    import torch

    from transformers import (
        AutoTokenizer,
        BertLMHeadModel,
        DeiTModel,
        TrOCRForCausalLM,
......@@ -48,7 +49,7 @@ if is_torch_available():
if is_vision_available():
    from PIL import Image

    from transformers import TrOCRProcessor
    from transformers import TrOCRProcessor, ViTFeatureExtractor
@require_torch
......@@ -656,3 +657,69 @@ class TrOCRModelIntegrationTest(unittest.TestCase):
        ).to(torch_device)

        self.assertTrue(torch.allclose(logits[0, 0, :10], expected_slice, atol=1e-4))


@require_vision
@require_torch
class ViT2GPT2ModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference_coco_en(self):
        loc = "ydshieh/vit-gpt2-coco-en"
        feature_extractor = ViTFeatureExtractor.from_pretrained(loc)
        tokenizer = AutoTokenizer.from_pretrained(loc)
        model = VisionEncoderDecoderModel.from_pretrained(loc)
        model.to(torch_device)
        model.eval()

        # We will verify our results on an image of cute cats
        img = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
        pixel_values = feature_extractor(images=img, return_tensors="pt").pixel_values.to(torch_device)

        decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]]).to(torch_device)

        with torch.no_grad():
            logits = model(pixel_values, decoder_input_ids)[0].detach().cpu().numpy()

        # verify the logits
        expected_shape = (1, 1, model.config.decoder.vocab_size)
        self.assertEqual(logits.shape, expected_shape)

        EXPECTED_LOGIT_SLICE = np.array(
            [
                -38.705807,
                -30.639929,
                -31.41903,
                -39.012012,
                -38.38696,
                -34.887207,
                -33.290855,
                -35.68447,
                -38.508484,
                -36.124645,
            ]
        )
        max_diff = np.amax(np.abs(logits[0, 0, :10] - EXPECTED_LOGIT_SLICE))
        self.assertLessEqual(max_diff, 1e-4)

        def generate_step(pixel_values):
            outputs = model.generate(
                pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True, output_scores=True
            )
            output_ids = outputs.sequences
            preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
            preds = [pred.strip() for pred in preds]

            return preds, outputs.sequences_scores.detach().cpu().numpy()

        preds, scores = generate_step(pixel_values)

        EXPECTED_SCORES = np.array([-0.59562886])
        max_diff = np.amax(np.abs(scores - EXPECTED_SCORES))
        self.assertLessEqual(max_diff, 1e-4)

        # should produce
        # ["a cat laying on top of a couch next to another cat"]
        self.assertEqual(preds, ["a cat laying on top of a couch next to another cat"])
......@@ -187,6 +187,7 @@ def get_model_modules():
"modeling_flax_encoder_decoder",
"modeling_flax_utils",
"modeling_speech_encoder_decoder",
"modeling_flax_vision_encoder_decoder",
"modeling_transfo_xl_utilities",
"modeling_tf_auto",
"modeling_tf_encoder_decoder",
......