"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "5451f8896c23f006648aa8da852fec499dfe6000"
Unverified commit 95b3ec3b, authored by Yih-Dar and committed via GitHub.

Add FlaxVisionEncoderDecoderModel (#13359)



* Start the work on FlaxVisionEncoderDecoderModel

* Add FlaxVisionEncoderDecoderModel

* Add VisionEncoderDecoderConfig

* Make FlaxVisionEncoderDecoderModel visible to transformers

* Add test

* Fix wrong getattr usage

* Fix tests

* Add FlaxAutoModelForVision2Seq

* Expose FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING

* clean-up

* add integration test

* update expected logits

* update expected scores

* Add ViT2GPT2ModelIntegrationTest + some cleaning

* Add projection layer + PT/Flax equivalence tests

* Fix import

* minor changes

* make test slow again

* Apply suggestions

* Add modeling_flax_vision_encoder_decoder to _ignore_modules in get_model_modules()

* fix copies

* Apply suggestions from code review
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* split long strings in multiple lines

* decoder_input_ids can't be None

* Add back test_configuration_tie

* Remove attention_mask parameter

* fix test - encoder_last_hidden_state should be encoder_outputs.last_hidden_state instead of the projected vector

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Remove more encoder_attention_mask

* remove encoder_attention_mask when calling self.decode (in FlaxVisionEncoderDecoderModule)

* Fix style + pass 1s instead of None as encoder_attention_mask

* fix init_weights

* pass None for encoder_attention_mask

* pass 1s instead of None as encoder_attention_mask

* Fix doc style
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent a5030122
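
For orientation, here is a minimal, hedged sketch of the API this commit introduces; the checkpoint names are illustrative (any Flax ViT-style encoder and GPT-2-style decoder checkpoint should behave the same) and the snippet is not part of the diff itself.

```python
# Minimal usage sketch of the new FlaxVisionEncoderDecoderModel (not part of the diff).
# Checkpoint names and the local image path are illustrative assumptions.
import numpy as np
from PIL import Image

from transformers import FlaxVisionEncoderDecoderModel, ViTFeatureExtractor

# Compose an image-to-text model from a pretrained vision encoder and a
# pretrained autoregressive text decoder.
model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "gpt2"
)
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

image = Image.open("cats.png")  # any local RGB image
pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values

# decoder_input_ids must be provided explicitly (see "decoder_input_ids can't be None"
# in the commit message); here we start from the decoder's BOS token.
decoder_input_ids = np.array([[model.config.decoder.bos_token_id]], dtype="i4")
outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
print(outputs.logits.shape)  # (batch_size, target_length, decoder vocab size)
```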
@@ -499,7 +499,7 @@ Flax), PyTorch, and/or TensorFlow.
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | UniSpeechSat | ❌ | ❌ | ✅ | ❌ | ❌ |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
-| Vision Encoder decoder | ❌ | ❌ | ✅ | ❌ | ❌ |
+| Vision Encoder decoder | ❌ | ❌ | ✅ | ❌ | ✅ |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
...
@@ -160,6 +160,13 @@ AutoModelForImageClassification
     :members:
 
+AutoModelForVision2Seq
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.AutoModelForVision2Seq
+    :members:
+
+
 AutoModelForAudioClassification
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -333,3 +340,10 @@ FlaxAutoModelForImageClassification
 .. autoclass:: transformers.FlaxAutoModelForImageClassification
     :members:
+
+
+FlaxAutoModelForVision2Seq
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxAutoModelForVision2Seq
+    :members:
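
The newly documented auto classes are expected to resolve a `vision-encoder-decoder` checkpoint to the right concrete class. A short sketch, using the checkpoint exercised by the integration test later in this diff:

```python
# Sketch of the new auto classes documented above (not part of the diff).
from transformers import AutoModelForVision2Seq, FlaxAutoModelForVision2Seq

# Resolves to VisionEncoderDecoderModel via MODEL_FOR_VISION_2_SEQ_MAPPING.
pt_model = AutoModelForVision2Seq.from_pretrained("ydshieh/vit-gpt2-coco-en")

# Resolves to FlaxVisionEncoderDecoderModel; from_pt=True converts the PyTorch
# weights in case the checkpoint does not ship Flax weights.
flax_model = FlaxAutoModelForVision2Seq.from_pretrained("ydshieh/vit-gpt2-coco-en", from_pt=True)
```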
@@ -39,3 +39,10 @@ VisionEncoderDecoderModel
 .. autoclass:: transformers.VisionEncoderDecoderModel
     :members: forward, from_encoder_decoder_pretrained
+
+
+FlaxVisionEncoderDecoderModel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxVisionEncoderDecoderModel
+    :members: __call__, from_encoder_decoder_pretrained
@@ -603,6 +603,7 @@ if is_torch_available():
             "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
             "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
             "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
+            "MODEL_FOR_VISION_2_SEQ_MAPPING",
             "MODEL_MAPPING",
             "MODEL_WITH_LM_HEAD_MAPPING",
             "AutoModel",
@@ -622,6 +623,7 @@ if is_torch_available():
             "AutoModelForSpeechSeq2Seq",
             "AutoModelForTableQuestionAnswering",
             "AutoModelForTokenClassification",
+            "AutoModelForVision2Seq",
             "AutoModelWithLMHead",
         ]
     )
@@ -1825,6 +1827,7 @@ if is_flax_available():
             "FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
             "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
             "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
+            "FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING",
             "FLAX_MODEL_MAPPING",
             "FlaxAutoModel",
             "FlaxAutoModelForCausalLM",
@@ -1837,6 +1840,7 @@ if is_flax_available():
             "FlaxAutoModelForSeq2SeqLM",
             "FlaxAutoModelForSequenceClassification",
             "FlaxAutoModelForTokenClassification",
+            "FlaxAutoModelForVision2Seq",
         ]
     )
@@ -1957,6 +1961,7 @@ if is_flax_available():
         ]
     )
     _import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
+    _import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel")
     _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"])
     _import_structure["models.wav2vec2"].extend(
         ["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"]
@@ -2457,6 +2462,7 @@ if TYPE_CHECKING:
         MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
         MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
         MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+        MODEL_FOR_VISION_2_SEQ_MAPPING,
         MODEL_MAPPING,
         MODEL_WITH_LM_HEAD_MAPPING,
         AutoModel,
@@ -2476,6 +2482,7 @@ if TYPE_CHECKING:
         AutoModelForSpeechSeq2Seq,
         AutoModelForTableQuestionAnswering,
         AutoModelForTokenClassification,
+        AutoModelForVision2Seq,
         AutoModelWithLMHead,
     )
     from .models.bart import (
@@ -3482,6 +3489,7 @@ if TYPE_CHECKING:
         FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
         FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
         FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+        FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
         FLAX_MODEL_MAPPING,
         FlaxAutoModel,
         FlaxAutoModelForCausalLM,
@@ -3494,6 +3502,7 @@ if TYPE_CHECKING:
         FlaxAutoModelForSeq2SeqLM,
         FlaxAutoModelForSequenceClassification,
         FlaxAutoModelForTokenClassification,
+        FlaxAutoModelForVision2Seq,
     )
     from .models.bart import (
         FlaxBartForConditionalGeneration,
@@ -3579,6 +3588,7 @@ if TYPE_CHECKING:
         FlaxRobertaPreTrainedModel,
     )
     from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
+    from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel
     from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
     from .models.wav2vec2 import (
         FlaxWav2Vec2ForCTC,
...
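
With the top-level `__init__.py` changes above, all of the new symbols become importable from the package root; a quick smoke check (assuming both PyTorch and Flax are installed):

```python
# Smoke check for the new top-level exports (requires torch and flax installed).
from transformers import (
    FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
    MODEL_FOR_VISION_2_SEQ_MAPPING,
    AutoModelForVision2Seq,
    FlaxAutoModelForVision2Seq,
    FlaxVisionEncoderDecoderModel,
)

print(AutoModelForVision2Seq, FlaxAutoModelForVision2Seq, FlaxVisionEncoderDecoderModel)
```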
@@ -46,6 +46,7 @@ if is_torch_available():
         "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
         "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
         "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
+        "MODEL_FOR_VISION_2_SEQ_MAPPING",
         "MODEL_MAPPING",
         "MODEL_WITH_LM_HEAD_MAPPING",
         "AutoModel",
@@ -65,6 +66,7 @@ if is_torch_available():
         "AutoModelForSpeechSeq2Seq",
         "AutoModelForTableQuestionAnswering",
         "AutoModelForTokenClassification",
+        "AutoModelForVision2Seq",
         "AutoModelWithLMHead",
     ]
@@ -105,6 +107,7 @@ if is_flax_available():
         "FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
         "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
         "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
+        "FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING",
         "FLAX_MODEL_MAPPING",
         "FlaxAutoModel",
         "FlaxAutoModelForCausalLM",
@@ -117,6 +120,7 @@ if is_flax_available():
         "FlaxAutoModelForSeq2SeqLM",
         "FlaxAutoModelForSequenceClassification",
         "FlaxAutoModelForTokenClassification",
+        "FlaxAutoModelForVision2Seq",
     ]
@@ -144,6 +148,7 @@ if TYPE_CHECKING:
         MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
         MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
         MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+        MODEL_FOR_VISION_2_SEQ_MAPPING,
         MODEL_MAPPING,
         MODEL_WITH_LM_HEAD_MAPPING,
         AutoModel,
@@ -163,6 +168,7 @@ if TYPE_CHECKING:
         AutoModelForSpeechSeq2Seq,
         AutoModelForTableQuestionAnswering,
         AutoModelForTokenClassification,
+        AutoModelForVision2Seq,
         AutoModelWithLMHead,
     )
@@ -203,6 +209,7 @@ if TYPE_CHECKING:
         FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
         FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
         FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+        FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
         FLAX_MODEL_MAPPING,
         FlaxAutoModel,
         FlaxAutoModelForCausalLM,
@@ -215,6 +222,7 @@ if TYPE_CHECKING:
         FlaxAutoModelForSeq2SeqLM,
         FlaxAutoModelForSequenceClassification,
         FlaxAutoModelForTokenClassification,
+        FlaxAutoModelForVision2Seq,
     )
 else:
...
@@ -239,6 +239,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("electra", "ELECTRA"),
         ("encoder-decoder", "Encoder decoder"),
         ("speech-encoder-decoder", "Speech Encoder decoder"),
+        ("vision-encoder-decoder", "Vision Encoder decoder"),
         ("funnel", "Funnel Transformer"),
         ("lxmert", "LXMERT"),
         ("deberta-v2", "DeBERTa-v2"),
...
@@ -244,6 +244,12 @@ MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
     ]
 )
 
+MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
+    [
+        ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
+    ]
+)
+
 MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
     [
         # Model for Masked LM mapping
@@ -511,6 +517,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
 MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
     CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
 )
+MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
 MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
 MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
 MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
@@ -655,6 +662,13 @@ class AutoModelForObjectDetection(_BaseAutoModelClass):
 AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")
 
+
+class AutoModelForVision2Seq(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING
+
+
+AutoModelForVision2Seq = auto_class_update(AutoModelForVision2Seq, head_doc="vision-to-text modeling")
+
 
 class AutoModelForAudioClassification(_BaseAutoModelClass):
     _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
...
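
The new mapping plus `auto_class_update` is what lets the auto class dispatch on a checkpoint's config type. A rough sketch of the observable behaviour (not the library internals), using the checkpoint from the integration test in this PR:

```python
# AutoModelForVision2Seq inspects the config's model_type and instantiates the mapped class.
from transformers import AutoConfig, AutoModelForVision2Seq, VisionEncoderDecoderModel

config = AutoConfig.from_pretrained("ydshieh/vit-gpt2-coco-en")
assert config.model_type == "vision-encoder-decoder"

model = AutoModelForVision2Seq.from_config(config)  # no weights downloaded, random init
assert isinstance(model, VisionEncoderDecoderModel)
```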
@@ -100,6 +100,12 @@ FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
     ]
 )
 
+FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
+    [
+        ("vision-encoder-decoder", "FlaxVisionEncoderDecoderModel"),
+    ]
+)
+
 FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
     [
         # Model for Causal LM mapping
@@ -176,6 +182,7 @@ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
 FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
     CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
 )
+FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
 FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
 FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
     CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
@@ -279,3 +286,10 @@ class FlaxAutoModelForImageClassification(_BaseAutoModelClass):
 FlaxAutoModelForImageClassification = auto_class_update(
     FlaxAutoModelForImageClassification, head_doc="image classification"
 )
+
+
+class FlaxAutoModelForVision2Seq(_BaseAutoModelClass):
+    _model_mapping = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING
+
+
+FlaxAutoModelForVision2Seq = auto_class_update(FlaxAutoModelForVision2Seq, head_doc="vision-to-text modeling")
@@ -18,7 +18,7 @@
 from typing import TYPE_CHECKING
 
-from ...file_utils import _LazyModule, is_torch_available
+from ...file_utils import _LazyModule, is_flax_available, is_torch_available
 
 _import_structure = {
@@ -28,12 +28,18 @@ _import_structure = {
 if is_torch_available():
     _import_structure["modeling_vision_encoder_decoder"] = ["VisionEncoderDecoderModel"]
 
+if is_flax_available():
+    _import_structure["modeling_flax_vision_encoder_decoder"] = ["FlaxVisionEncoderDecoderModel"]
+
 if TYPE_CHECKING:
     from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig
 
     if is_torch_available():
         from .modeling_vision_encoder_decoder import VisionEncoderDecoderModel
 
+    if is_flax_available():
+        from .modeling_flax_vision_encoder_decoder import FlaxVisionEncoderDecoderModel
+
 else:
     import sys
...
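
Because the Flax class is only exported when Flax is installed, downstream code that wants to stay backend-optional can mirror the same guard; a small sketch:

```python
# User-side guard mirroring the conditional export above (sketch).
from transformers.file_utils import is_flax_available

if is_flax_available():
    from transformers import FlaxVisionEncoderDecoderModel
else:
    FlaxVisionEncoderDecoderModel = None  # skip Flax-only code paths
```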
@@ -27,8 +27,8 @@ logger = logging.get_logger(__name__)
 class VisionEncoderDecoderConfig(PretrainedConfig):
     r"""
     :class:`~transformers.VisionEncoderDecoderConfig` is the configuration class to store the configuration of a
-    :class:`~transformers.VisionEncoderDecoderModel`. It is used to instantiate an Encoder Decoder model according to
-    the specified arguments, defining the encoder and decoder configs.
+    :class:`~transformers.VisionEncoderDecoderModel`. It is used to instantiate a Vision-Encoder-Text-Decoder model
+    according to the specified arguments, defining the encoder and decoder configs.
 
     Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
     outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
...
@@ -70,8 +70,8 @@ VISION_ENCODER_DECODER_START_DOCSTRING = r"""
     <https://arxiv.org/abs/2109.10282>`__ it is shown how leveraging large pretrained vision models for optical
     character recognition (OCR) yields a significant performance improvement.
 
-    After such an Vision-Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other
-    models (see the examples for more information).
+    After such a Vision-Encoder-Text-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any
+    other models (see the examples for more information).
 
     This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
     methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
@@ -94,13 +94,6 @@ VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
             Pixel values. Pixel values can be obtained using a feature extractor (e.g. if you use ViT as the encoder,
             you should use :class:`~transformers.ViTFeatureExtractor`). See
             :meth:`transformers.ViTFeatureExtractor.__call__` for details.
-        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
-            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
-
-            - 1 for tokens that are **not masked**,
-            - 0 for tokens that are **masked**.
-
-            `What are attention masks? <../glossary.html#attention-mask>`__
         decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
             Indices of decoder input sequence tokens in the vocabulary.
@@ -130,10 +123,6 @@ VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
             If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
             (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
             instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
-        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
-            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
-            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
-            vectors than the model's internal embedding lookup matrix.
         decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`):
             Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded
             representation. This is useful if you want more control over how to convert :obj:`decoder_input_ids`
@@ -165,8 +154,8 @@ VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
 class VisionEncoderDecoderModel(PreTrainedModel):
     r"""
     :class:`~transformers.VisionEncoderDecoderModel` is a generic model class that will be instantiated as a
-    transformer architecture with one of the base model classes of the library as encoder and another one as decoder
-    when created with the :meth`~transformers.AutoModel.from_pretrained` class method for the encoder and
+    transformer architecture with one of the base vision model classes of the library as encoder and another one as
+    decoder when created with the :meth`~transformers.AutoModel.from_pretrained` class method for the encoder and
     :meth`~transformers.AutoModelForCausalLM.from_pretrained` class method for the decoder.
     """
     config_class = VisionEncoderDecoderConfig
@@ -186,6 +175,15 @@ class VisionEncoderDecoderModel(PreTrainedModel):
         if not isinstance(config, self.config_class):
             raise ValueError(f"Config: {config} has to be of type {self.config_class}")
 
+        if config.decoder.cross_attention_hidden_size is not None:
+            if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
+                raise ValueError(
+                    f"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
+                    f"it has to be equal to the encoder's `hidden_size`. "
+                    f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
+                    f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
+                )
+
         # initialize with config
         # make sure input & output embeddings is not tied
         config.tie_word_embeddings = False
...
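
To make the new `cross_attention_hidden_size` validation concrete, here is an illustrative (in-memory only) config combination that would now be rejected; the attribute is set by hand purely for demonstration:

```python
# Illustration of the new check (configs built in memory, no weights downloaded).
from transformers import GPT2Config, ViTConfig, VisionEncoderDecoderConfig, VisionEncoderDecoderModel

encoder_config = ViTConfig()  # hidden_size defaults to 768
decoder_config = GPT2Config()
decoder_config.cross_attention_hidden_size = 1024  # deliberately mismatched, for demonstration

config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)

try:
    VisionEncoderDecoderModel(config=config)
except ValueError as err:
    print(err)  # explains that 1024 does not match the encoder's hidden_size of 768
```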
@@ -174,6 +174,9 @@ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None
 FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
 
+FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING = None
+
+
 FLAX_MODEL_MAPPING = None
@@ -276,6 +279,15 @@ class FlaxAutoModelForTokenClassification:
         requires_backends(cls, ["flax"])
 
+class FlaxAutoModelForVision2Seq:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+
 class FlaxBartForConditionalGeneration:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])
@@ -949,6 +961,15 @@ class FlaxT5PreTrainedModel:
         requires_backends(cls, ["flax"])
 
+class FlaxVisionEncoderDecoderModel:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+
 class FlaxViTForImageClassification:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])
...
@@ -355,6 +355,9 @@ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = None
 MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
 
+MODEL_FOR_VISION_2_SEQ_MAPPING = None
+
+
 MODEL_MAPPING = None
@@ -514,6 +517,15 @@ class AutoModelForTokenClassification:
         requires_backends(cls, ["torch"])
 
+class AutoModelForVision2Seq:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class AutoModelWithLMHead:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
...
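
The dummy classes above follow the usual placeholder pattern: the symbol always exists, but touching it without the backend installed raises immediately. An illustrative stripped-down version of the same idea (names are illustrative, not the actual utility code):

```python
# Stripped-down sketch of the dummy-object pattern used above.
from transformers.file_utils import requires_backends


class FlaxOnlyPlaceholder:
    """Stands in for a Flax model class when Flax is not installed."""

    def __init__(self, *args, **kwargs):
        # no-op when Flax is available; raises an informative error otherwise
        requires_backends(self, ["flax"])
```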
@@ -29,7 +29,6 @@ from .test_modeling_flax_gpt2 import FlaxGPT2ModelTester
 if is_flax_available():
     from transformers import (
-        AutoConfig,
         AutoTokenizer,
         EncoderDecoderConfig,
         FlaxBertModel,
@@ -350,12 +349,6 @@ class FlaxEncoderDecoderModelTest(unittest.TestCase):
     def get_from_encoderdecoder_pretrained_model(self):
         return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
 
-    def get_decoder_config(self):
-        config = AutoConfig.from_pretrained("gpt2")
-        config.is_decoder = True
-        config.add_cross_attention = True
-        return config
-
     def _check_configuration_tie(self, model):
         module = model.module.bind(model.params)
...
This diff is collapsed.
@@ -34,6 +34,7 @@ if is_torch_available():
     import torch
 
     from transformers import (
+        AutoTokenizer,
         BertLMHeadModel,
         DeiTModel,
         TrOCRForCausalLM,
@@ -48,7 +49,7 @@ if is_torch_available():
 if is_vision_available():
     from PIL import Image
 
-    from transformers import TrOCRProcessor
+    from transformers import TrOCRProcessor, ViTFeatureExtractor
 
 
 @require_torch
@@ -656,3 +657,69 @@ class TrOCRModelIntegrationTest(unittest.TestCase):
         ).to(torch_device)
 
         self.assertTrue(torch.allclose(logits[0, 0, :10], expected_slice, atol=1e-4))
+
+
+@require_vision
+@require_torch
+class ViT2GPT2ModelIntegrationTest(unittest.TestCase):
+    @slow
+    def test_inference_coco_en(self):
+
+        loc = "ydshieh/vit-gpt2-coco-en"
+        feature_extractor = ViTFeatureExtractor.from_pretrained(loc)
+        tokenizer = AutoTokenizer.from_pretrained(loc)
+        model = VisionEncoderDecoderModel.from_pretrained(loc)
+        model.to(torch_device)
+        model.eval()
+
+        # We will verify our results on an image of cute cats
+        img = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+        pixel_values = feature_extractor(images=img, return_tensors="pt").pixel_values.to(torch_device)
+
+        decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]]).to(torch_device)
+
+        with torch.no_grad():
+            logits = model(pixel_values, decoder_input_ids)[0].detach().cpu().numpy()
+
+        # verify the logits
+        expected_shape = (1, 1, model.config.decoder.vocab_size)
+        self.assertEqual(logits.shape, expected_shape)
+
+        EXPECTED_LOGIT_SLICE = np.array(
+            [
+                -38.705807,
+                -30.639929,
+                -31.41903,
+                -39.012012,
+                -38.38696,
+                -34.887207,
+                -33.290855,
+                -35.68447,
+                -38.508484,
+                -36.124645,
+            ]
+        )
+        max_diff = np.amax(np.abs(logits[0, 0, :10] - EXPECTED_LOGIT_SLICE))
+        self.assertLessEqual(max_diff, 1e-4)
+
+        def generate_step(pixel_values):
+            outputs = model.generate(
+                pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True, output_scores=True
+            )
+            output_ids = outputs.sequences
+            preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+            preds = [pred.strip() for pred in preds]
+
+            return preds, outputs.sequences_scores.detach().cpu().numpy()
+
+        preds, scores = generate_step(pixel_values)
+
+        EXPECTED_SCORES = np.array([-0.59562886])
+        max_diff = np.amax(np.abs(scores - EXPECTED_SCORES))
+        self.assertLessEqual(max_diff, 1e-4)
+
+        # should produce
+        # ["a cat laying on top of a couch next to another cat"]
+        self.assertEqual(preds, ["a cat laying on top of a couch next to another cat"])
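
The commit message also mentions PT/Flax equivalence tests (the real ones live in the collapsed files). As a rough, hedged outline of what such a check looks like with the classes touched here:

```python
# Rough outline of a PT/Flax equivalence check (illustrative, not the PR's actual test).
import numpy as np
import torch

from transformers import FlaxVisionEncoderDecoderModel, VisionEncoderDecoderModel

loc = "ydshieh/vit-gpt2-coco-en"
pt_model = VisionEncoderDecoderModel.from_pretrained(loc).eval()
fx_model = FlaxVisionEncoderDecoderModel.from_pretrained(loc, from_pt=True)

pixel_values = np.random.randn(1, 3, 224, 224).astype(np.float32)
decoder_input_ids = np.array([[pt_model.config.decoder_start_token_id]])

with torch.no_grad():
    pt_logits = pt_model(
        pixel_values=torch.tensor(pixel_values),
        decoder_input_ids=torch.tensor(decoder_input_ids),
    ).logits.numpy()

fx_logits = np.asarray(fx_model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids).logits)

# fp32 outputs from the two frameworks should agree to a small tolerance
assert np.max(np.abs(pt_logits - fx_logits)) < 1e-4
```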
@@ -187,6 +187,7 @@ def get_model_modules():
         "modeling_flax_encoder_decoder",
         "modeling_flax_utils",
         "modeling_speech_encoder_decoder",
+        "modeling_flax_vision_encoder_decoder",
         "modeling_transfo_xl_utilities",
         "modeling_tf_auto",
         "modeling_tf_encoder_decoder",
...