Unverified Commit 5f3ea66b authored by Matt, committed by GitHub

Add TF port of BLIP (#22090)



* Initial commit

* more stash commit

* Yet another stash commit

* yet more stash commit

* Mostly working except for docs / repo consistency

* Stop importing model list from torch file

* Add TF BLIP models to docs

* Add auto classes

* Move get_text_features and get_image_features

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip_text.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/blip/test_modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/blip/test_modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/models/blip/test_modeling_tf_blip_text.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip_text.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Use channels_last convolutions in TF (better performance + compatibility); see the sketch after this commit message

* Remove _shape function

* Move multi-line statement to one line in PT + TF

* Specify tf.keras.layers instead of importing from it

* Remove test_gradient_checkpointing and empty test_training methods

* move some multi-line statements to one line

* Update docstring for generate

* Remove pruned heads set

* Remove self.seq_len_dim

* Fixed issues with loss computation, should resolve some tests. Also ensured that the PT version follows the config for output_attentions and output_hidden_states

* ensure original model follows config in more cases

* Skip the same cross-attention tests in the PT tests - didn't realize we did it twice!

* Add training args throughout the models and layers

* make fixup

* Fix docstring for inputs_embeds

* Add docstring for is_decoder

* Add docstrings to text models

* Remove redundant computation

* Add unpack_inputs / keras_serializable

* Add modeling_tf_blip to doctests

* Add config classes for keras serialization

* Changes to allow model porting with pt-to-tf

* Quick fix to decoder head and test tweaks

* Revert an issue with masking the embeddings outputs

* Allow missing keys in some equivalence tests (for unused layers)

* Add tf-pt equivalence tests back in

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip_text.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip_text.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* make fixup

* Refactor invert_attention_mask out into tf_utils

* Re-enable cross-tests on the PT side too

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent a515d0a7
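A minimal sketch of the channels_last convolution pattern referenced in the commit message (a standalone illustration with assumed patch and embedding sizes): HF image processors return NCHW pixel values, while TF's `Conv2D` prefers NHWC, so the inputs are transposed before the patch-embedding convolution.

```python
import tensorflow as tf

# Standalone sketch (assumed sizes: 16x16 patches, 768-dim embeddings).
pixel_values = tf.random.uniform((2, 3, 224, 224))  # NCHW, as returned by the image processors

patch_embedding = tf.keras.layers.Conv2D(
    filters=768, kernel_size=16, strides=16, data_format="channels_last"
)

# Transpose NCHW -> NHWC before the channels_last convolution.
patch_embeds = patch_embedding(tf.transpose(pixel_values, perm=(0, 2, 3, 1)))
print(patch_embeds.shape)  # (2, 14, 14, 768)
```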
......@@ -269,7 +269,7 @@ Flax), PyTorch, and/or TensorFlow.
| BiT | ❌ | ❌ | ✅ | ❌ | ❌ |
| Blenderbot | ✅ | ✅ | ✅ | ✅ | ✅ |
| BlenderbotSmall | ✅ | ✅ | ✅ | ✅ | ✅ |
| BLIP | ❌ | ❌ | ✅ | ❌ | ❌ |
| BLIP | ❌ | ❌ | ✅ | ✅ | ❌ |
| BLIP-2 | ❌ | ❌ | ✅ | ❌ | ❌ |
| BLOOM | ❌ | ✅ | ✅ | ❌ | ❌ |
| BridgeTower | ❌ | ❌ | ✅ | ❌ | ❌ |
......
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
......@@ -93,4 +93,40 @@ The original code can be found [here](https://github.com/salesforce/BLIP).
## BlipForQuestionAnswering
[[autodoc]] BlipForQuestionAnswering
- forward
\ No newline at end of file
- forward
## TFBlipModel
[[autodoc]] TFBlipModel
- call
- get_text_features
- get_image_features
## TFBlipTextModel
[[autodoc]] TFBlipTextModel
- call
## TFBlipVisionModel
[[autodoc]] TFBlipVisionModel
- call
## TFBlipForConditionalGeneration
[[autodoc]] TFBlipForConditionalGeneration
- call
## TFBlipForImageTextRetrieval
[[autodoc]] TFBlipForImageTextRetrieval
- call
## TFBlipForQuestionAnswering
[[autodoc]] TFBlipForQuestionAnswering
- call
\ No newline at end of file
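The TF classes above mirror their PyTorch counterparts. A minimal captioning sketch with the TF port (assuming the `Salesforce/blip-image-captioning-base` checkpoint):

```python
from PIL import Image
import requests
from transformers import AutoProcessor, TFBlipForConditionalGeneration

processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = TFBlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# `text` could also be passed as an optional prompt for the decoder to continue.
inputs = processor(images=image, return_tensors="tf")
outputs = model.generate(**inputs)
print(processor.decode(outputs[0], skip_special_tokens=True))
```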
......@@ -2903,6 +2903,18 @@ else:
_import_structure["models.blenderbot_small"].extend(
["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel", "TFBlenderbotSmallPreTrainedModel"]
)
_import_structure["models.blip"].extend(
[
"TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFBlipForConditionalGeneration",
"TFBlipForImageTextRetrieval",
"TFBlipForQuestionAnswering",
"TFBlipModel",
"TFBlipPreTrainedModel",
"TFBlipTextModel",
"TFBlipVisionModel",
]
)
_import_structure["models.camembert"].extend(
[
"TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -6145,6 +6157,16 @@ if TYPE_CHECKING:
TFBlenderbotSmallModel,
TFBlenderbotSmallPreTrainedModel,
)
from .models.blip import (
TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
TFBlipForConditionalGeneration,
TFBlipForImageTextRetrieval,
TFBlipForQuestionAnswering,
TFBlipModel,
TFBlipPreTrainedModel,
TFBlipTextModel,
TFBlipVisionModel,
)
from .models.camembert import (
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCamembertForCausalLM,
......
......@@ -196,7 +196,7 @@ class PTtoTFCommand(BaseTransformersCLICommand):
self._extra_commit_description = extra_commit_description
self._override_model_class = override_model_class
def get_inputs(self, pt_model, config):
def get_inputs(self, pt_model, tf_dummy_inputs, config):
"""
Returns the right inputs for the model, based on its signature.
"""
......@@ -255,7 +255,11 @@ class PTtoTFCommand(BaseTransformersCLICommand):
tf_input = processor(**processor_inputs, return_tensors="tf")
# Extra input requirements, in addition to the input modality
if config.is_encoder_decoder or (hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder")):
if (
config.is_encoder_decoder
or (hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder"))
or "decoder_input_ids" in tf_dummy_inputs
):
decoder_input_ids = np.asarray([[1], [1]], dtype=int) * (pt_model.config.decoder_start_token_id or 0)
pt_input.update({"decoder_input_ids": torch.tensor(decoder_input_ids)})
tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})
......@@ -306,18 +310,24 @@ class PTtoTFCommand(BaseTransformersCLICommand):
except AttributeError:
raise AttributeError(f"The TensorFlow equivalent of {architectures[0]} doesn't exist in transformers.")
# Load models and acquire a basic input compatible with the model.
# Check the TF dummy inputs to see what keys we need in the forward pass
tf_from_pt_model = tf_class.from_config(config)
tf_dummy_inputs = tf_from_pt_model.dummy_inputs
del tf_from_pt_model # Try to keep only one model in memory at a time
# Load the model and get some basic inputs
pt_model = pt_class.from_pretrained(self._local_dir)
pt_model.eval()
pt_input, tf_input = self.get_inputs(pt_model, config)
pt_input, tf_input = self.get_inputs(pt_model, tf_dummy_inputs, config)
with torch.no_grad():
pt_outputs = pt_model(**pt_input, output_hidden_states=True)
del pt_model # will no longer be used, and may have a large memory footprint
tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
tf_from_pt_outputs = tf_from_pt_model(**tf_input, output_hidden_states=True)
tf_from_pt_outputs = tf_from_pt_model(**tf_input, output_hidden_states=True, training=False)
# Confirms that cross loading PT weights into TF worked.
crossload_differences = self.find_pt_tf_differences(pt_outputs, tf_from_pt_outputs)
......
......@@ -406,6 +406,7 @@ def unpack_inputs(func):
func (`callable`):
The callable function of the TensorFlow model.
Returns:
A callable that wraps the original `func` with the behavior described above.
"""
......@@ -1157,6 +1158,38 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
"""
return cls(config, **kwargs)
def get_head_mask(self, head_mask: Optional[tf.Tensor], num_hidden_layers: int) -> tf.Tensor:
"""
Prepare the head mask if needed.
Args:
head_mask (`tf.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
num_hidden_layers (`int`):
The number of hidden layers in the model.
Returns:
`tf.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with
`[None]` for each layer.
"""
if head_mask is not None:
head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
else:
head_mask = [None] * num_hidden_layers
return head_mask
def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
"""-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
if head_mask.shape.rank == 1:
head_mask = head_mask[None, None, :, None, None]
head_mask = tf.repeat(head_mask, repeats=num_hidden_layers, axis=0)
elif head_mask.shape.rank == 2:
head_mask = head_mask[:, None, :, None, None]
assert head_mask.shape.rank == 5, f"head_mask.shape.rank != 5, instead {head_mask.shape.rank}"
head_mask = tf.cast(head_mask, tf.float32)  # switch to float if needed + fp16 compatibility
return head_mask
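# Illustrative shape check for the conversion above (assumed sizes: 12 layers, 8 heads):
#
#     >>> head_mask = tf.ones((12, 8))                      # [num_hidden_layers, num_heads]
#     >>> head_mask_5d = head_mask[:, None, :, None, None]  # [num_hidden_layers, 1, num_heads, 1, 1]
#     >>> head_mask_5d[0].shape                             # broadcasts over [batch, num_heads, seq_len, seq_len]
#     TensorShape([1, 8, 1, 1])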
def eager_serving(self, inputs):
"""
Method used for serving the model. Intended not to be compiled with a tf.function decorator so that we can use
......
......@@ -34,6 +34,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("bert", "TFBertModel"),
("blenderbot", "TFBlenderbotModel"),
("blenderbot-small", "TFBlenderbotSmallModel"),
("blip", "TFBlipModel"),
("camembert", "TFCamembertModel"),
("clip", "TFCLIPModel"),
("convbert", "TFConvBertModel"),
......@@ -213,6 +214,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Zero Shot Image Classification mapping
("blip", "TFBlipModel"),
("clip", "TFCLIPModel"),
]
)
......
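With these mappings registered, the TF BLIP classes also resolve through the auto API. A minimal sketch (assuming the `Salesforce/blip-image-captioning-base` checkpoint; `from_pt=True` may be needed if a checkpoint only ships PyTorch weights):

```python
from transformers import TFAutoModel

# The `blip` model type now maps to TFBlipModel in the TF auto mappings.
model = TFAutoModel.from_pretrained("Salesforce/blip-image-captioning-base")
print(type(model).__name__)  # TFBlipModel
```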
......@@ -13,7 +13,13 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tf_available,
is_torch_available,
is_vision_available,
)
_import_structure = {
......@@ -52,6 +58,23 @@ else:
"BlipForImageTextRetrieval",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_blip"] = [
"TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFBlipModel",
"TFBlipPreTrainedModel",
"TFBlipForConditionalGeneration",
"TFBlipForQuestionAnswering",
"TFBlipVisionModel",
"TFBlipTextModel",
"TFBlipForImageTextRetrieval",
]
if TYPE_CHECKING:
from .configuration_blip import BLIP_PRETRAINED_CONFIG_ARCHIVE_MAP, BlipConfig, BlipTextConfig, BlipVisionConfig
from .processing_blip import BlipProcessor
......@@ -81,6 +104,23 @@ if TYPE_CHECKING:
BlipVisionModel,
)
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_blip import (
TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
TFBlipForConditionalGeneration,
TFBlipForImageTextRetrieval,
TFBlipForQuestionAnswering,
TFBlipModel,
TFBlipPreTrainedModel,
TFBlipTextModel,
TFBlipVisionModel,
)
else:
import sys
......
......@@ -313,17 +313,12 @@ class BlipAttention(nn.Module):
bsz, tgt_len, embed_dim = hidden_states.size()
mixed_qkv = self.qkv(hidden_states)
mixed_qkv = (
self.qkv(hidden_states)
.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
query_states, key_states, value_states = (
mixed_qkv[0],
mixed_qkv[1],
mixed_qkv[2],
)
query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
......@@ -587,9 +582,7 @@ class BlipEncoder(nn.Module):
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
Embedded representation of the inputs. Should be float, not int tokens.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
......@@ -824,10 +817,7 @@ class BlipModel(BlipPreTrainedModel):
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_outputs = self.vision_model(
pixel_values=pixel_values,
return_dict=return_dict,
)
vision_outputs = self.vision_model(pixel_values=pixel_values, return_dict=return_dict)
pooled_output = vision_outputs[1] # pooled_output
image_features = self.visual_projection(pooled_output)
......@@ -993,6 +983,10 @@ class BlipForConditionalGeneration(BlipPreTrainedModel):
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
......@@ -1037,7 +1031,7 @@ class BlipForConditionalGeneration(BlipPreTrainedModel):
Overrides *generate* function to be able to use the model as a conditional generator
Parameters:
pixel_values (*torch.FloatTensor* of shape *(batch_size, image_width, image_height)*:
pixel_values (*torch.FloatTensor* of shape *(batch_size, num_channels, image_height, image_width)*):
Input image to be processed
input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
The sequence used as a prompt for the generation.
......@@ -1066,9 +1060,7 @@ class BlipForConditionalGeneration(BlipPreTrainedModel):
"""
batch_size = pixel_values.shape[0]
vision_outputs = self.vision_model(
pixel_values=pixel_values,
)
vision_outputs = self.vision_model(pixel_values=pixel_values)
image_embeds = vision_outputs[0]
......@@ -1198,6 +1190,10 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
......@@ -1266,7 +1262,7 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
Parameters:
input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*):
The sequence used as a prompt for the generation.
pixel_values (*torch.FloatTensor* of shape *(batch_size, image_width, image_height)*:
pixel_values (*torch.FloatTensor* of shape *(batch_size, num_channels, image_height, image_width)*):
Input image to be processed
attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`. `1` for
......@@ -1295,9 +1291,7 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
2
```
"""
vision_outputs = self.vision_model(
pixel_values=pixel_values,
)
vision_outputs = self.vision_model(pixel_values=pixel_values)
image_embeds = vision_outputs[0]
......@@ -1412,6 +1406,10 @@ class BlipForImageTextRetrieval(BlipPreTrainedModel):
```
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
......
# coding=utf-8
# Copyright 2023 The Salesforce Team Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TensorFlow BLIP model."""
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import tensorflow as tf
from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling
from ...modeling_tf_utils import (
DUMMY_INPUTS,
TFPreTrainedModel,
get_initializer,
get_tf_activation,
keras_serializable,
shape_list,
unpack_inputs,
)
from ...tf_utils import stable_softmax
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig
from .modeling_tf_blip_text import BLIP_TEXT_INPUTS_DOCSTRING, TFBlipTextLMHeadModel, TFBlipTextModel
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "Salesforce/blip-vqa-base"
TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
"Salesforce/blip-vqa-base",
"Salesforce/blip-vqa-capfit-large",
"Salesforce/blip-image-captioning-base",
"Salesforce/blip-image-captioning-large",
"Salesforce/blip-itm-base-coco",
"Salesforce/blip-itm-large-coco",
"Salesforce/blip-itm-base-flikr",
"Salesforce/blip-itm-large-flikr",
# See all BLIP models at https://huggingface.co/models?filter=blip
]
# Copied from transformers.models.clip.modeling_tf_clip.contrastive_loss
def contrastive_loss(logits: tf.Tensor) -> tf.Tensor:
return tf.math.reduce_mean(
tf.keras.metrics.sparse_categorical_crossentropy(
y_true=tf.range(shape_list(logits)[0]), y_pred=logits, from_logits=True
)
)
# Copied from transformers.models.clip.modeling_tf_clip.clip_loss with clip->blip
def blip_loss(similarity: tf.Tensor) -> tf.Tensor:
caption_loss = contrastive_loss(similarity)
image_loss = contrastive_loss(tf.transpose(similarity))
return (caption_loss + image_loss) / 2.0
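# Illustrative check of the symmetric loss above, assuming a 2x2 similarity matrix whose
# diagonal holds the matched image-text pairs:
#
#     >>> similarity = tf.constant([[4.0, 0.0], [0.0, 4.0]])
#     >>> round(float(blip_loss(similarity)), 3)  # low loss: matched pairs already dominate
#     0.018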
@dataclass
class TFBlipForConditionalGenerationModelOutput(ModelOutput):
"""
Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
last hidden states. This class also adds the loss term from the text decoder.
Args:
loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss from the text decoder.
decoder_logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
Prediction scores of the language modeling head of the text decoder model.
image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`, *optional*):
The image embeddings obtained after applying the Vision Transformer model to the input image.
last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed):
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[Tuple[tf.Tensor]] = None
decoder_logits: Optional[Tuple[tf.Tensor]] = None
image_embeds: Optional[tf.Tensor] = None
last_hidden_state: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFBlipTextVisionModelOutput(ModelOutput):
"""
Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
last hidden states. This class also adds the loss term from the text decoder.
Args:
loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss from the text decoder.
image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
The image embeddings obtained by applying the projection layer to the pooler_output.
last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[tf.Tensor] = None
image_embeds: Optional[tf.Tensor] = None
last_hidden_state: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFBlipImageTextMatchingModelOutput(ModelOutput):
"""
Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
last hidden states. This class also adds the loss term from the text decoder as well as the image-text similarity
scores.
Args:
itm_score (`tf.Tensor`):
The image-text similarity scores.
loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss from the text decoder.
image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
The image embeddings obtained by applying the projection layer to the pooler_output.
last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
vision_pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`, *optional*):
Last layer hidden-state of the vision-only branch of the model.
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
question_embeds (`tf.Tensor`):
The question embeddings obtained by the text projection layer.
"""
itm_score: Optional[tf.Tensor] = None
loss: Optional[tf.Tensor] = None
image_embeds: Optional[tf.Tensor] = None
last_hidden_state: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
vision_pooler_output: Optional[tf.Tensor] = None
attentions: Optional[Tuple[tf.Tensor]] = None
question_embeds: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFBlipOutput(ModelOutput):
"""
Args:
loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
Contrastive loss for image-text similarity.
logits_per_image (`tf.Tensor` of shape `(image_batch_size, text_batch_size)`):
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
similarity scores.
logits_per_text (`tf.Tensor` of shape `(text_batch_size, image_batch_size)`):
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
similarity scores.
text_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`):
The text embeddings obtained by applying the projection layer to the pooled output of [`TFBlipTextModel`].
image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`):
The image embeddings obtained by applying the projection layer to the pooled output of [`TFBlipVisionModel`].
text_model_output (`TFBaseModelOutputWithPooling`):
The output of the [`TFBlipTextModel`].
vision_model_output (`TFBaseModelOutputWithPooling`):
The output of the [`TFBlipVisionModel`].
"""
loss: Optional[tf.Tensor] = None
logits_per_image: tf.Tensor = None
logits_per_text: tf.Tensor = None
text_embeds: tf.Tensor = None
image_embeds: tf.Tensor = None
text_model_output: TFBaseModelOutputWithPooling = None
vision_model_output: TFBaseModelOutputWithPooling = None
def to_tuple(self) -> Tuple[Any]:
return tuple(
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
for k in self.keys()
)
class TFBlipVisionEmbeddings(tf.keras.layers.Layer):
def __init__(self, config: BlipVisionConfig, **kwargs):
super().__init__(**kwargs)
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = tf.keras.layers.Conv2D(
filters=self.embed_dim,
kernel_size=self.patch_size,
strides=self.patch_size,
kernel_initializer=get_initializer(self.config.initializer_range),
data_format="channels_last",
name="patch_embedding",
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
def build(self, input_shape):
self.class_embedding = self.add_weight(
shape=(1, 1, self.embed_dim),
initializer=get_initializer(self.config.initializer_range),
trainable=True,
name="class_embedding",
)
self.position_embedding = self.add_weight(
shape=(1, self.num_positions, self.embed_dim),
initializer=get_initializer(self.config.initializer_range),
trainable=True,
name="position_embedding",
)
def call(self, pixel_values: tf.Tensor) -> tf.Tensor:
# Input is channels-first, we transpose. PyTorch transposes after the conv because PyTorch
# likes channels-first convs.
batch_size = tf.shape(pixel_values)[0]
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
patch_embeds = self.patch_embedding(pixel_values)
patch_embeds = tf.reshape(patch_embeds, (batch_size, self.num_patches, -1))
class_embeds = tf.broadcast_to(self.class_embedding, (batch_size, 1, self.embed_dim))
embeddings = tf.concat([class_embeds, patch_embeds], axis=1)
embeddings = embeddings + self.position_embedding[:, : tf.shape(embeddings)[1], :]
return embeddings
# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPTextEmbeddings with CLIP->Blip
class TFBlipTextEmbeddings(tf.keras.layers.Layer):
def __init__(self, config: BlipTextConfig, **kwargs):
super().__init__(**kwargs)
self.embed_dim = config.hidden_size
self.config = config
def build(self, input_shape: tf.TensorShape):
with tf.name_scope("token_embedding"):
self.weight = self.add_weight(
shape=(self.config.vocab_size, self.embed_dim),
initializer=get_initializer(self.config.initializer_factor * self.config.initializer_range),
trainable=True,
name="weight",
)
with tf.name_scope("position_embedding"):
self.position_embedding = self.add_weight(
shape=(self.config.max_position_embeddings, self.embed_dim),
initializer=get_initializer(self.config.initializer_factor * self.config.initializer_range),
trainable=True,
name="embeddings",
)
super().build(input_shape)
def call(
self,
input_ids: tf.Tensor = None,
position_ids: tf.Tensor = None,
inputs_embeds: tf.Tensor = None,
) -> tf.Tensor:
"""
Applies embedding based on inputs tensor.
Returns:
final_embeddings (`tf.Tensor`): output embedding tensor.
"""
if input_ids is None and inputs_embeds is None:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
tf.debugging.assert_less(
input_ids,
tf.cast(self.config.vocab_size, dtype=input_ids.dtype),
message=(
"input_ids must be smaller than the embedding layer's input dimension (got"
f" {tf.math.reduce_max(input_ids)} >= {self.config.vocab_size})"
),
)
inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
input_shape = shape_list(inputs_embeds)[:-1]
if position_ids is None:
position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
position_embeds = tf.gather(params=self.position_embedding, indices=position_ids)
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
final_embeddings = inputs_embeds + position_embeds
return final_embeddings
class TFBlipAttention(tf.keras.layers.Layer):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = tf.keras.layers.Dropout(config.attention_dropout, name="dropout")
self.qkv = tf.keras.layers.Dense(
3 * self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="qkv"
)
self.projection = tf.keras.layers.Dense(
self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="projection"
)
def call(
self,
hidden_states: tf.Tensor,
head_mask: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = False,
training: Optional[bool] = None,
) -> Tuple[tf.Tensor, Optional[tf.Tensor], Optional[Tuple[tf.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, embed_dim = shape_list(hidden_states)
mixed_qkv = self.qkv(hidden_states)
mixed_qkv = tf.reshape(mixed_qkv, (bsz, tgt_len, 3, self.num_heads, self.head_dim))
mixed_qkv = tf.transpose(mixed_qkv, perm=(2, 0, 3, 1, 4))
query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = query_states @ tf.transpose(key_states, (0, 1, 3, 2))
attention_scores = attention_scores * self.scale
# Normalize the attention scores to probabilities.
attention_probs = stable_softmax(attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs, training=training)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = tf.transpose(attention_probs @ value_states, perm=(0, 2, 1, 3))
new_context_layer_shape = shape_list(context_layer)[:-2] + [self.embed_dim]
context_layer = tf.reshape(context_layer, new_context_layer_shape)
output = self.projection(context_layer)
outputs = (output, attention_probs) if output_attentions else (output, None)
return outputs
class TFBlipMLP(tf.keras.layers.Layer):
def __init__(self, config: BlipConfig, **kwargs):
super().__init__(**kwargs)
self.activation_fn = get_tf_activation(config.hidden_act)
in_proj_std = (config.hidden_size**-0.5) * ((2 * config.num_hidden_layers) ** -0.5)
fc_std = (2 * config.hidden_size) ** -0.5
self.fc1 = tf.keras.layers.Dense(
units=config.intermediate_size, kernel_initializer=get_initializer(fc_std), name="fc1"
)
self.fc2 = tf.keras.layers.Dense(
units=config.hidden_size, kernel_initializer=get_initializer(in_proj_std), name="fc2"
)
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.fc1(inputs=hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(inputs=hidden_states)
return hidden_states
class TFBlipEncoderLayer(tf.keras.layers.Layer):
def __init__(self, config: BlipConfig, **kwargs):
super().__init__(**kwargs)
self.embed_dim = config.hidden_size
self.self_attn = TFBlipAttention(config, name="self_attn")
self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1")
self.mlp = TFBlipMLP(config, name="mlp")
self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2")
def call(
self,
hidden_states: tf.Tensor,
attention_mask: tf.Tensor,
output_attentions: Optional[bool] = False,
training: Optional[bool] = None,
) -> Tuple[tf.Tensor]:
"""
Args:
hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`tf.Tensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
`(config.encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
head_mask=attention_mask,
output_attentions=output_attentions,
training=training,
)
hidden_states = hidden_states + residual
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = hidden_states + residual
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class TFBlipPreTrainedModel(TFPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = BlipConfig
base_model_prefix = "blip"
_keys_to_ignore_on_load_missing = [r"position_ids"]
BLIP_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
behavior.
Parameters:
config ([`BlipConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
BLIP_VISION_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`BlipImageProcessor`]. See [`BlipImageProcessor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
BLIP_INPUTS_DOCSTRING = r"""
Args:
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoProcessor`]. See [`BlipProcessor.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`BlipImageProcessor`]. See [`BlipImageProcessor.__call__`] for details.
return_loss (`bool`, *optional*):
Whether or not to return the contrastive loss.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@keras_serializable
class TFBlipEncoder(tf.keras.layers.Layer):
config_class = BlipConfig
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`BlipEncoderLayer`].
Args:
config (`BlipConfig`):
The corresponding vision configuration for the `BlipEncoder`.
"""
def __init__(self, config: BlipConfig, **kwargs):
super().__init__(**kwargs)
self.config = config
self.layers = [TFBlipEncoderLayer(config, name=f"layers_._{i}") for i in range(config.num_hidden_layers)]
@unpack_inputs
def call(
self,
inputs_embeds,
attention_mask: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = None,
) -> Union[Tuple, TFBaseModelOutput]:
r"""
Args:
inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
Embedded representation of the inputs. Should be float, not int tokens.
attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
training=training,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return TFBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
class TFBlipVisionModel(TFBlipPreTrainedModel):
main_input_name = "pixel_values"
config_class = BlipVisionConfig
def __init__(self, config: BlipVisionConfig, *args, **kwargs):
super().__init__(config, *args, **kwargs)
self.config = config
self.embeddings = TFBlipVisionEmbeddings(config, name="embeddings")
self.encoder = TFBlipEncoder(config, name="encoder")
self.post_layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="post_layernorm")
@property
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
"""
Dummy inputs to build the network.
Returns:
`Dict[str, tf.Tensor]`: The dummy inputs.
"""
VISION_DUMMY_INPUTS = tf.random.uniform(
shape=(len(DUMMY_INPUTS), 3, self.config.image_size, self.config.image_size), dtype=tf.float32
)
return {"pixel_values": VISION_DUMMY_INPUTS}
@tf.function(
input_signature=[
{
"pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
}
]
)
def serving(self, inputs: Dict[str, tf.Tensor]) -> TFBaseModelOutputWithPooling:
"""
Method used for serving the model.
Args:
inputs (`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
"""
output = self.call(inputs)
return self.serving_output(output)
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFBaseModelOutputWithPooling(
last_hidden_state=output.last_hidden_state,
pooler_output=output.pooler_output,
hidden_states=hs,
attentions=attns,
)
@unpack_inputs
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=BlipVisionConfig)
def call(
self,
pixel_values: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = None,
) -> Union[Tuple, TFBaseModelOutputWithPooling]:
r"""
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.embeddings(pixel_values)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.post_layernorm(last_hidden_state)
pooled_output = last_hidden_state[:, 0, :]
# TF gets confused if we call the layer with inputs of different ranks, so insert a singleton dimension
pooled_output = self.post_layernorm(tf.expand_dims(pooled_output, 1))
pooled_output = tf.squeeze(pooled_output, 1)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return TFBaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def get_input_embeddings(self):
return self.embeddings
class TFBlipMainLayer(tf.keras.layers.Layer):
config_class = BlipConfig
def __init__(self, config: BlipConfig, *args, **kwargs):
super().__init__(*args, **kwargs)
if not isinstance(config.text_config, BlipTextConfig):
raise ValueError(
"config.text_config is expected to be of type BlipTextConfig but is of type"
f" {type(config.text_config)}."
)
if not isinstance(config.vision_config, BlipVisionConfig):
raise ValueError(
"config.vision_config is expected to be of type BlipVisionConfig but is of type"
f" {type(config.vision_config)}."
)
text_config = config.text_config
vision_config = config.vision_config
self.projection_dim = config.projection_dim
self.text_embed_dim = text_config.hidden_size
self.vision_embed_dim = vision_config.hidden_size
self.text_model = TFBlipTextModel(text_config, name="text_model")
self.vision_model = TFBlipVisionModel(vision_config, name="vision_model")
self.visual_projection = tf.keras.layers.Dense(
self.projection_dim,
use_bias=False,
kernel_initializer=get_initializer(config.initializer_range),
name="visual_projection",
)
self.text_projection = tf.keras.layers.Dense(
self.projection_dim,
use_bias=False,
kernel_initializer=get_initializer(config.initializer_range),
name="text_projection",
)
self.config = config
def build(self, input_shape):
self.logit_scale = self.add_weight(
name="logit_scale",
shape=[],
initializer=tf.keras.initializers.Constant(self.config.logit_scale_init_value),
trainable=True,
)
@unpack_inputs
def call(
self,
input_ids: Optional[tf.Tensor] = None,
pixel_values: Optional[tf.Tensor] = None,
attention_mask: Optional[tf.Tensor] = None,
position_ids: Optional[tf.Tensor] = None,
return_loss: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = None,
) -> Union[Tuple, TFBlipOutput]:
# Use BLIP model's config for some fields (if specified) instead of those of vision & text components.
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
image_embeds = vision_outputs[1]
image_embeds = self.visual_projection(image_embeds)
text_embeds = text_outputs[1]
text_embeds = self.text_projection(text_embeds)
# normalized features
image_embeds = image_embeds / tf.norm(image_embeds, ord=2, axis=-1, keepdims=True)
text_embeds = text_embeds / tf.norm(text_embeds, ord=2, axis=-1, keepdims=True)
# cosine similarity as logits
logit_scale = tf.exp(self.logit_scale)
logits_per_text = tf.matmul(text_embeds, image_embeds, transpose_b=True) * logit_scale
logits_per_image = tf.transpose(logits_per_text)
loss = None
if return_loss:
loss = blip_loss(logits_per_text)
loss = tf.reshape(loss, (1,))
if not return_dict:
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
return ((loss,) + output) if loss is not None else output
return TFBlipOutput(
loss=loss,
logits_per_image=logits_per_image,
logits_per_text=logits_per_text,
text_embeds=text_embeds,
image_embeds=image_embeds,
text_model_output=text_outputs,
vision_model_output=vision_outputs,
)
class TFBlipModel(TFBlipPreTrainedModel):
config_class = BlipConfig
_keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"]
main_input_name = "input_ids"
def __init__(self, config: BlipConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.blip = TFBlipMainLayer(config, name="blip")
@property
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
"""
Dummy inputs to build the network.
Returns:
`Dict[str, tf.Tensor]`: The dummy inputs.
"""
VISION_DUMMY_INPUTS = tf.random.uniform(
shape=(len(DUMMY_INPUTS), 3, self.config.vision_config.image_size, self.config.vision_config.image_size),
dtype=tf.float32,
)
return {
"input_ids": tf.constant(DUMMY_INPUTS, dtype=tf.int32),
"pixel_values": VISION_DUMMY_INPUTS,
}
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
}
]
)
def serving(self, inputs: Dict[str, tf.Tensor]) -> TFBlipOutput:
"""
Method used for serving the model.
Args:
inputs (`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
"""
output = self.call(inputs)
return self.serving_output(output)
def serving_output(self, output: TFBlipOutput) -> TFBlipOutput:
return TFBlipOutput(
logits_per_image=output.logits_per_image,
logits_per_text=output.logits_per_text,
text_embeds=output.text_embeds,
image_embeds=output.image_embeds,
)
@unpack_inputs
@add_start_docstrings_to_model_forward(BLIP_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFBlipOutput, config_class=BlipConfig)
def call(
self,
input_ids: Optional[tf.Tensor] = None,
pixel_values: Optional[tf.Tensor] = None,
attention_mask: Optional[tf.Tensor] = None,
position_ids: Optional[tf.Tensor] = None,
return_loss: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = None,
) -> Union[Tuple, TFBlipOutput]:
r"""
Returns:
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, TFBlipModel
>>> model = TFBlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(
... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="tf", padding=True
... )
>>> outputs = model(**inputs)
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
>>> probs = tf.nn.softmax(logits_per_image, axis=1) # we can take the softmax to get the label probabilities
```"""
outputs = self.blip(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
return_loss=return_loss,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
return outputs
@add_start_docstrings_to_model_forward(BLIP_TEXT_INPUTS_DOCSTRING)
def get_text_features(
self,
input_ids: Optional[tf.Tensor] = None,
attention_mask: Optional[tf.Tensor] = None,
position_ids: Optional[tf.Tensor] = None,
return_dict: Optional[bool] = None,
) -> tf.Tensor:
r"""
Returns:
text_features (`tf.Tensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by applying
the projection layer to the pooled output of [`TFBlipTextModel`].
Examples:
```python
>>> from transformers import AutoProcessor, TFBlipModel
>>> model = TFBlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
>>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="tf")
>>> text_features = model.get_text_features(**inputs)
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
text_outputs = self.blip.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
return_dict=return_dict,
)
pooled_output = text_outputs[1]
text_features = self.blip.text_projection(pooled_output)
return text_features
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
def get_image_features(
self,
pixel_values: Optional[tf.Tensor] = None,
return_dict: Optional[bool] = None,
) -> tf.Tensor:
r"""
Returns:
image_features (`tf.Tensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by applying
the projection layer to the pooled output of [`TFBlipVisionModel`].
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, TFBlipModel
>>> model = TFBlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="tf")
>>> image_features = model.get_image_features(**inputs)
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_outputs = self.blip.vision_model(pixel_values=pixel_values, return_dict=return_dict)
pooled_output = vision_outputs[1] # pooled_output
image_features = self.blip.visual_projection(pooled_output)
return image_features
@add_start_docstrings(
"""
BLIP Model for image captioning. The model consists of a vision encoder and a text decoder. One can optionally pass
`input_ids` to the model, which serve as a text prompt, to make the text decoder continue the prompt. Otherwise,
the decoder starts generating the caption from the [BOS] (beginning-of-sequence) token.
""",
BLIP_START_DOCSTRING,
)
class TFBlipForConditionalGeneration(TFBlipPreTrainedModel):
config_class = BlipConfig
_keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"]
main_input_name = "pixel_values"
def __init__(self, config: BlipConfig, *args, **kwargs):
super().__init__(config, *args, **kwargs)
self.vision_model = TFBlipVisionModel(config.vision_config, name="vision_model")
self.text_decoder = TFBlipTextLMHeadModel(config.text_config, name="text_decoder")
self.decoder_input_ids = config.text_config.bos_token_id
self.decoder_pad_token_id = config.text_config.pad_token_id
def get_input_embeddings(self) -> tf.keras.layers.Layer:
return self.vision_model.embeddings.patch_embedding
@property
def dummy_inputs(self):
input_ids = tf.constant(DUMMY_INPUTS, dtype=tf.int32)
VISION_DUMMY_INPUTS = tf.random.uniform(
shape=(len(DUMMY_INPUTS), 3, self.config.vision_config.image_size, self.config.vision_config.image_size),
dtype=tf.float32,
)
return {"input_ids": input_ids, "pixel_values": VISION_DUMMY_INPUTS}
@tf.function(
input_signature=[
{
"pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
}
]
)
def serving(self, inputs: Dict[str, tf.Tensor]) -> TFBlipForConditionalGenerationModelOutput:
"""
Method used for serving the model.
Args:
inputs (`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
"""
output = self.call(inputs)
return self.serving_output(output)
def serving_output(
self, output: TFBlipForConditionalGenerationModelOutput
) -> TFBlipForConditionalGenerationModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFBlipForConditionalGenerationModelOutput(
last_hidden_state=output.last_hidden_state,
image_embeds=output.image_embeds,
hidden_states=hs,
attentions=attns,
)
@unpack_inputs
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFBlipForConditionalGenerationModelOutput, config_class=BlipConfig)
def call(
self,
pixel_values: tf.Tensor,
input_ids: Optional[tf.Tensor] = None,
attention_mask: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
labels: Optional[tf.Tensor] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = None,
) -> Union[Tuple, TFBlipForConditionalGenerationModelOutput]:
r"""
Returns:
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, TFBlipForConditionalGeneration
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
>>> model = TFBlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = "A picture of"
>>> inputs = processor(images=image, text=text, return_tensors="tf")
>>> outputs = model(**inputs)
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
image_embeds = vision_outputs[0]
outputs = self.text_decoder(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=image_embeds,
labels=labels,
return_dict=return_dict,
training=training,
)
if not return_dict:
outputs = (outputs[0], outputs[1], image_embeds, vision_outputs[0]) + vision_outputs[2:]
return tuple(output for output in outputs if output is not None)
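# If the decoder returned a scalar loss, reshape it to (1,) so callers always receive a 1-D loss tensor.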
if outputs.loss is not None and outputs.loss.shape.rank == 0:
outputs.loss = tf.reshape(outputs.loss, (1,))
return TFBlipForConditionalGenerationModelOutput(
loss=outputs.loss,
decoder_logits=outputs.logits,
image_embeds=image_embeds,
last_hidden_state=vision_outputs.last_hidden_state,
hidden_states=vision_outputs.hidden_states,
attentions=vision_outputs.attentions,
)
def generate(
self,
pixel_values: tf.Tensor,
input_ids: Optional[tf.Tensor] = None,
attention_mask: Optional[tf.Tensor] = None,
**generate_kwargs,
) -> tf.Tensor:
r"""
Overrides *generate* function to be able to use the model as a conditional generator
Parameters:
pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, image_height, image_width)`):
Input image to be processed.
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
The sequence used as a prompt for the generation.
attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`. `1` for
tokens that are NOT MASKED, `0` for MASKED tokens.
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, TFBlipForConditionalGeneration
>>> model = TFBlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="tf")
>>> outputs = model.generate(**inputs)
>>> print(processor.decode(outputs[0], skip_special_tokens=True))
two cats are laying on a couch
```
"""
batch_size = pixel_values.shape[0]
vision_outputs = self.vision_model(pixel_values=pixel_values)
image_embeds = vision_outputs[0]
image_attention_mask = tf.ones(shape_list(image_embeds)[:-1], dtype=tf.int32)
if isinstance(input_ids, list):
input_ids = tf.convert_to_tensor(input_ids, dtype=tf.int32)
elif input_ids is None:
input_ids = tf.convert_to_tensor(
[[self.decoder_input_ids, self.config.text_config.eos_token_id]], dtype=tf.int32
)
input_ids = tf.tile(input_ids, (batch_size, 1))
# PyTorch: input_ids[:, 0] = self.config.text_config.bos_token_id
input_ids = tf.concat(
[tf.ones((batch_size, 1), dtype=tf.int32) * self.config.text_config.bos_token_id, input_ids[:, 1:]], axis=1
)
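# tf.Tensor does not support in-place item assignment, so the first column is rebuilt with tf.concat instead.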
attention_mask = attention_mask[:, :-1] if attention_mask is not None else None
outputs = self.text_decoder.generate(
input_ids=input_ids[:, :-1],
eos_token_id=self.config.text_config.sep_token_id,
pad_token_id=self.config.text_config.pad_token_id,
attention_mask=attention_mask,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
**generate_kwargs,
)
return outputs
@add_start_docstrings(
"""
BLIP Model for visual question answering. The model consists of a vision encoder, a text encoder as well as a text
decoder. The vision encoder will encode the input image, the text encoder will encode the input question together
with the encoding of the image, and the text decoder will output the answer to the question.
""",
BLIP_START_DOCSTRING,
)
class TFBlipForQuestionAnswering(TFBlipPreTrainedModel):
config_class = BlipConfig
_keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"]
def __init__(self, config: BlipConfig, *args, **kwargs):
super().__init__(config, *args, **kwargs)
self.vision_model = TFBlipVisionModel(config.vision_config, name="vision_model")
self.text_encoder = TFBlipTextModel(config.text_config, name="text_encoder", add_pooling_layer=False)
self.text_decoder = TFBlipTextLMHeadModel(config.text_config, name="text_decoder")
self.decoder_pad_token_id = config.text_config.pad_token_id
self.decoder_start_token_id = config.text_config.bos_token_id
def get_input_embeddings(self) -> tf.keras.layers.Layer:
return self.vision_model.embeddings.patch_embedding
@property
def dummy_inputs(self):
input_ids = tf.constant(DUMMY_INPUTS, dtype=tf.int32)
VISION_DUMMY_INPUTS = tf.random.uniform(
shape=(len(DUMMY_INPUTS), 3, self.config.vision_config.image_size, self.config.vision_config.image_size),
dtype=tf.float32,
)
return {"input_ids": input_ids, "pixel_values": VISION_DUMMY_INPUTS, "decoder_input_ids": input_ids}
@tf.function(
input_signature=[
{
"pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
}
]
)
def serving(self, inputs: Dict[str, tf.Tensor]) -> TFBlipTextVisionModelOutput:
"""
Method used for serving the model.
Args:
inputs (`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
"""
output = self.call(inputs)
return self.serving_output(output)
def serving_output(self, output: TFBlipTextVisionModelOutput) -> TFBlipTextVisionModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFBlipTextVisionModelOutput(
image_embeds=output.image_embeds,
last_hidden_state=output.last_hidden_state,
hidden_states=hs,
attentions=attns,
)
# Adapted from transformers.models.t5.modeling_tf_t5.TFT5PreTrainedModel._shift_right
def _shift_right(self, input_ids):
decoder_start_token_id = self.decoder_start_token_id
pad_token_id = self.decoder_pad_token_id
if decoder_start_token_id is None or pad_token_id is None:
raise ValueError("decoder_start_token_id and pad_token_id must be defined!")
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
start_tokens = tf.cast(start_tokens, input_ids.dtype) # Ensure compatible dtypes for concatenation
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where(
shifted_input_ids == -100,
tf.cast(tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids.dtype),
shifted_input_ids,
)
# "Verify that `labels` has only positive values and -100"
tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=shifted_input_ids.dtype))
return shifted_input_ids
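# Illustrative sketch (hypothetical token ids): with decoder_start_token_id=30522 and pad_token_id=0,
# labels [[2129, 2116, 102]] are shifted to [[30522, 2129, 2116]], and any -100 entries in the shifted
# tensor are replaced by the pad token id before being fed to the text decoder.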
@unpack_inputs
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFBlipTextVisionModelOutput, config_class=BlipVisionConfig)
def call(
self,
input_ids: tf.Tensor,
pixel_values: tf.Tensor,
decoder_input_ids: Optional[tf.Tensor] = None,
decoder_attention_mask: Optional[tf.Tensor] = None,
attention_mask: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
labels: Optional[tf.Tensor] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = None,
) -> Union[Tuple, TFBlipTextVisionModelOutput]:
r"""
Returns:
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, TFBlipForQuestionAnswering
>>> model = TFBlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> # training
>>> text = "How many cats are in the picture?"
>>> label = "2"
>>> inputs = processor(images=image, text=text, return_tensors="tf")
>>> labels = processor(text=label, return_tensors="tf").input_ids
>>> inputs["labels"] = labels
>>> outputs = model(**inputs)
>>> loss = outputs.loss
>>> # inference
>>> text = "How many cats are in the picture?"
>>> inputs = processor(images=image, text=text, return_tensors="tf")
>>> outputs = model.generate(**inputs)
>>> print(processor.decode(outputs[0], skip_special_tokens=True))
2
```"""
if labels is None and decoder_input_ids is None:
raise ValueError(
"Either `decoder_input_ids` or `labels` should be passed when calling `forward` with"
" `TFBlipForQuestionAnswering`. if you are training the model make sure that `labels` is passed, if you"
" are using the model for inference make sure that `decoder_input_ids` is passed or call `generate`"
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
image_embeds = vision_outputs[0]
image_attention_mask = tf.ones(shape_list(image_embeds)[:-1], dtype=tf.int64)
question_embeds = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
return_dict=return_dict,
training=training,
)
question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
if labels is not None and decoder_input_ids is None:
# get decoder inputs from shifting lm labels to the right - this is used in training mode
decoder_input_ids = self._shift_right(labels)
# replace possible -100 values in labels by `pad_token_id`
labels = tf.where(labels == self.decoder_pad_token_id, -100, labels)
answer_output = self.text_decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=question_embeds,
encoder_attention_mask=attention_mask,
labels=labels,
return_dict=return_dict,
training=training,
)
if labels is not None:
decoder_loss = tf.reduce_mean(answer_output.loss) if return_dict else tf.reduce_mean(answer_output[0])
else:
decoder_loss = None
if not return_dict:
outputs = (decoder_loss, image_embeds, vision_outputs[0]) + vision_outputs[2:]
return tuple(output for output in outputs if output is not None)
return TFBlipTextVisionModelOutput(
loss=decoder_loss,
image_embeds=image_embeds,
last_hidden_state=vision_outputs.last_hidden_state,
hidden_states=vision_outputs.hidden_states,
attentions=vision_outputs.attentions,
)
def generate(
self,
input_ids: tf.Tensor,
pixel_values: tf.Tensor,
attention_mask: Optional[tf.Tensor] = None,
**generate_kwargs,
) -> tf.Tensor:
r"""
Overrides *generate* function to be able to use the model as a conditional generator
Parameters:
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, image_height, image_width)`):
Input image to be processed.
attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`. `1` for
tokens that are NOT MASKED, `0` for MASKED tokens.
generate_kwargs (dict, *optional*):
Additional arguments passed to the `generate` function of the decoder
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, TFBlipForQuestionAnswering
>>> model = TFBlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = "How many cats are in the picture?"
>>> inputs = processor(images=image, text=text, return_tensors="tf")
>>> outputs = model.generate(**inputs)
>>> print(processor.decode(outputs[0], skip_special_tokens=True))
2
```
"""
vision_outputs = self.vision_model(pixel_values=pixel_values)
image_embeds = vision_outputs[0]
image_attention_mask = tf.ones(shape_list(image_embeds)[:-1], dtype=tf.int32)
if isinstance(input_ids, list):
input_ids = tf.convert_to_tensor(input_ids, dtype=tf.int32)
question_outputs = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
return_dict=False,
)
question_embeds = question_outputs[0]
question_attention_mask = tf.ones(shape_list(question_embeds)[:-1], dtype=tf.int32)
bos_ids = tf.fill(
(tf.shape(question_embeds)[0], 1), value=tf.cast(self.decoder_start_token_id, input_ids.dtype)
)
outputs = self.text_decoder.generate(
input_ids=bos_ids,
eos_token_id=self.config.text_config.sep_token_id,
pad_token_id=self.config.text_config.pad_token_id,
encoder_hidden_states=question_embeds,
encoder_attention_mask=question_attention_mask,
**generate_kwargs,
)
return outputs
@add_start_docstrings(
"""
BLIP Model with a vision and text projector, and a classification head on top. The model is used in the context of
image-text retrieval. Given an image and a text, the model returns the probability of the text being relevant to
the image.
""",
BLIP_START_DOCSTRING,
)
class TFBlipForImageTextRetrieval(TFBlipPreTrainedModel):
config_class = BlipConfig
def __init__(self, config: BlipConfig, *args, **kwargs):
super().__init__(config, *args, **kwargs)
self.vision_model = TFBlipVisionModel(config.vision_config, name="vision_model")
self.text_encoder = TFBlipTextModel(config.text_config, name="text_encoder", add_pooling_layer=False)
# vision projection layer
self.vision_proj = tf.keras.layers.Dense(
config.image_text_hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
name="vision_proj",
)
# text projection layer
self.text_proj = tf.keras.layers.Dense(
config.image_text_hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
name="text_proj",
)
# image text matching head
self.itm_head = tf.keras.layers.Dense(
2, kernel_initializer=get_initializer(config.initializer_range), name="itm_head"
)
self.decoder_pad_token_id = (
config.text_config.pad_token_id
if not hasattr(config, "decoder_pad_token_id")
else config.decoder_pad_token_id
)
self.decoder_start_token_id = (
config.text_config.bos_token_id
if not hasattr(config, "decoder_start_token_id")
else config.decoder_start_token_id
)
def get_input_embeddings(self) -> tf.keras.layers.Layer:
return self.vision_model.embeddings.patch_embedding
@property
def dummy_inputs(self):
input_ids = tf.constant(DUMMY_INPUTS, dtype=tf.int32)
VISION_DUMMY_INPUTS = tf.random.uniform(
shape=(len(DUMMY_INPUTS), 3, self.config.vision_config.image_size, self.config.vision_config.image_size),
dtype=tf.float32,
)
return {"input_ids": input_ids, "pixel_values": VISION_DUMMY_INPUTS}
@tf.function(
input_signature=[
{
"pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
}
]
)
def serving(self, inputs: Dict[str, tf.Tensor]) -> TFBlipImageTextMatchingModelOutput:
"""
Method used for serving the model.
Args:
inputs (`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
"""
output = self.call(inputs)
return self.serving_output(output)
def serving_output(self, output: TFBlipImageTextMatchingModelOutput) -> TFBlipImageTextMatchingModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFBlipImageTextMatchingModelOutput(
itm_score=output.itm_score,
last_hidden_state=output.last_hidden_state,
hidden_states=hs,
attentions=attns,
question_embeds=output.question_embeds,
)
@unpack_inputs
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFBlipImageTextMatchingModelOutput, config_class=BlipVisionConfig)
def call(
self,
input_ids: tf.Tensor,
pixel_values: Optional[tf.Tensor] = None,
use_itm_head: Optional[bool] = True,
attention_mask: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = None,
) -> Union[Tuple, TFBlipImageTextMatchingModelOutput]:
r"""
Returns:
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, TFBlipForImageTextRetrieval
>>> model = TFBlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-itm-base-coco")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = "an image of a cat"
>>> inputs = processor(images=image, text=text, return_tensors="tf")
>>> outputs = model(**inputs)
```
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
image_embeds = vision_outputs[0]
image_atts = tf.ones(shape_list(image_embeds)[:-1], dtype=tf.int64)
# Matt: In PyTorch, only one path (itm/non-itm) is taken. However, in TensorFlow this can result in
# some layers not being built! To avoid this, we always call both paths, then use an if statement to select
# which output to pass to the final output. The unnecessary nodes will be pruned from the final graph, but
# not before the layers have all been built correctly.
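# (In eager mode both branches do execute, so this trades a small amount of extra compute for correctly-built weights.)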
itm_question_embeds = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=return_dict,
training=training,
)
itm_question_embeds = itm_question_embeds[0] if not return_dict else itm_question_embeds.last_hidden_state
itm_output = self.itm_head(itm_question_embeds[:, 0, :])
no_itm_question_embeds = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=return_dict,
training=training,
)
no_itm_question_embeds = (
no_itm_question_embeds[0] if not return_dict else no_itm_question_embeds.last_hidden_state
)
image_feat, _ = tf.linalg.normalize(self.vision_proj(image_embeds[:, 0, :]), ord=2, axis=-1)
text_feat, _ = tf.linalg.normalize(self.text_proj(no_itm_question_embeds[:, 0, :]), ord=2, axis=-1)
no_itm_output = tf.matmul(image_feat, text_feat, transpose_b=True)
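# Both feature sets are L2-normalized above, so this matmul is the cosine similarity between every image and every text in the batch.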
if use_itm_head:
output = itm_output
question_embeds = itm_question_embeds
else:
output = no_itm_output
question_embeds = no_itm_question_embeds
if not return_dict:
outputs = (output, vision_outputs[0]) + vision_outputs[2:] + (question_embeds,)
return tuple(output for output in outputs if output is not None)
return TFBlipImageTextMatchingModelOutput(
itm_score=output,
last_hidden_state=vision_outputs.last_hidden_state,
hidden_states=vision_outputs.hidden_states,
attentions=vision_outputs.attentions,
question_embeds=question_embeds,
)
# coding=utf-8
# Copyright 2023 The Salesforce Team Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the BSD-3-clause license (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Dict, Optional, Tuple
import tensorflow as tf
from ...modeling_tf_outputs import (
TFBaseModelOutputWithPastAndCrossAttentions,
TFBaseModelOutputWithPoolingAndCrossAttentions,
TFCausalLMOutputWithCrossAttentions,
)
from ...modeling_tf_utils import (
DUMMY_INPUTS,
TFPreTrainedModel,
get_initializer,
get_tf_activation,
keras_serializable,
shape_list,
unpack_inputs,
)
from ...tf_utils import invert_attention_mask, stable_softmax
from ...utils import add_start_docstrings_to_model_forward, logging
from .configuration_blip import BlipTextConfig
logger = logging.get_logger(__name__)
BLIP_TEXT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoProcessor`]. See [`BlipProcessor.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L52
class TFBlipTextEmbeddings(tf.keras.layers.Layer):
"""Construct the embeddings from word and position embeddings."""
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.word_embeddings = tf.keras.layers.Embedding(
config.vocab_size,
config.hidden_size,
embeddings_initializer=get_initializer(config.initializer_range),
name="word_embeddings",
)
self.position_embeddings = tf.keras.layers.Embedding(
config.max_position_embeddings,
config.hidden_size,
embeddings_initializer=get_initializer(config.initializer_range),
name="position_embeddings",
)
# self.LayerNorm is not snake-cased to stick with PyTorch model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob, name="dropout")
self.position_ids = tf.expand_dims(tf.range(config.max_position_embeddings), 0)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.config = config
def call(self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0, training=None):
if input_ids is not None:
input_shape = tf.shape(input_ids)
else:
input_shape = tf.shape(inputs_embeds)[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
if inputs_embeds is None:
# Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
# indices on GPU, returning zeros instead. This is a dangerous silent behavior.
tf.debugging.assert_less(
input_ids,
tf.cast(self.config.vocab_size, dtype=input_ids.dtype),
message=(
"input_ids must be smaller than the embedding layer's input dimension (got"
f" {tf.math.reduce_max(input_ids)} >= {self.config.vocab_size})"
),
)
inputs_embeds = self.word_embeddings(input_ids)
embeddings = inputs_embeds
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings, training=training)
return embeddings
# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L97
class TFBlipTextSelfAttention(tf.keras.layers.Layer):
def __init__(self, config, is_cross_attention, **kwargs):
super().__init__(**kwargs)
self.config = config
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention heads (%d)"
% (config.hidden_size, config.num_attention_heads)
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
)
self.key = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
)
self.value = tf.keras.layers.Dense(
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
)
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = tf.keras.layers.Embedding(
2 * config.max_position_embeddings - 1, self.attention_head_size
)
def transpose_for_scores(self, x):
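# Reshape [batch, seq_len, all_head_size] into [batch, num_heads, seq_len, head_size] so attention is computed per head.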
new_x_shape = tf.concat(
[tf.shape(x)[:-1], tf.constant([self.num_attention_heads, self.attention_head_size], dtype=tf.int32)],
axis=0,
)
x = tf.reshape(x, new_x_shape)
return tf.transpose(x, perm=(0, 2, 1, 3))
def call(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
training=None,
):
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = shape_list(hidden_states)[1]
position_ids_l = tf.expand_dims(tf.range(seq_length, dtype=tf.int64), 1)
position_ids_r = tf.expand_dims(tf.range(seq_length, dtype=tf.int64), 0)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = tf.cast(positional_embedding, query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = tf.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = tf.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = tf.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
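# The scores above are scaled by 1/sqrt(head_size), as in standard scaled dot-product attention.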
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in BlipTextModel forward() function)
attention_scores = attention_scores + tf.cast(attention_mask, attention_scores.dtype)
# Normalize the attention scores to probabilities.
attention_probs = stable_softmax(attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs_dropped = self.dropout(attention_probs, training=training)
# Mask heads if we want to
if head_mask is not None:
attention_probs_dropped = attention_probs_dropped * head_mask
context_layer = attention_probs_dropped @ value_layer
context_layer = tf.transpose(context_layer, perm=(0, 2, 1, 3))
new_context_layer_shape = shape_list(context_layer)[:-2] + [self.all_head_size]
context_layer = tf.reshape(context_layer, new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
outputs = outputs + (past_key_value,)
return outputs
class TFBlipTextSelfOutput(tf.keras.layers.Layer):
def __init__(self, config: BlipTextConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: Optional[bool] = None) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.dropout(inputs=hidden_states, training=training)
hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
return hidden_states
# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L242
class TFBlipTextAttention(tf.keras.layers.Layer):
def __init__(self, config, is_cross_attention=False, **kwargs):
super().__init__(**kwargs)
self.self = TFBlipTextSelfAttention(config, is_cross_attention, name="self")
# "output" is a protected attribute on TF models
self.self_output = TFBlipTextSelfOutput(config, name="output")
def call(
self,
hidden_states: tf.Tensor,
attention_mask: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
output_attentions: Optional[bool] = False,
training: Optional[bool] = None,
):
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
training=training,
)
attention_output = self.self_output(self_outputs[0], hidden_states, training=training)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->BlipText
class TFBlipTextIntermediate(tf.keras.layers.Layer):
def __init__(self, config: BlipTextConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class TFBlipTextOutput(tf.keras.layers.Layer):
def __init__(self, config: BlipTextConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.dropout(inputs=hidden_states, training=training)
hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
return hidden_states
class TFBlipTextLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.config = config
self.attention = TFBlipTextAttention(config, name="attention")
if self.config.is_decoder:
self.crossattention = TFBlipTextAttention(
config, is_cross_attention=self.config.is_decoder, name="crossattention"
)
self.intermediate = TFBlipTextIntermediate(config, name="intermediate")
self.self_output = TFBlipTextOutput(config, name="output")
def call(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
training=None,
):
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
training=training,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
if encoder_hidden_states is not None:
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions=output_attentions,
training=training,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
intermediate_output = self.intermediate(attention_output)
layer_output = self.self_output(intermediate_output, attention_output, training=training)
outputs = (layer_output,) + outputs
outputs = outputs + (present_key_value,)
return outputs
# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L386
@keras_serializable
class TFBlipTextEncoder(tf.keras.layers.Layer):
config_class = BlipTextConfig
def __init__(self, config, name=None, **kwargs):
super().__init__(name=name, **kwargs)
self.config = config
self.layer = [TFBlipTextLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
@unpack_inputs
def call(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
training=None,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.is_decoder else None
next_decoder_cache = () if use_cache else None
for i in range(self.config.num_hidden_layers):
layer_module = self.layer[i]
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
training=training,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->BlipText
class TFBlipTextPooler(tf.keras.layers.Layer):
def __init__(self, config: BlipTextConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
units=config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
activation="tanh",
name="dense",
)
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(inputs=first_token_tensor)
return pooled_output
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPredictionHeadTransform with Bert->BlipText
class TFBlipTextPredictionHeadTransform(tf.keras.layers.Layer):
def __init__(self, config: BlipTextConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
units=config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)
if isinstance(config.hidden_act, str):
self.transform_act_fn = get_tf_activation(config.hidden_act)
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(inputs=hidden_states)
return hidden_states
class TFBlipTextLMPredictionHead(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.transform = TFBlipTextPredictionHeadTransform(config, name="transform")
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = tf.keras.layers.Dense(
config.vocab_size,
kernel_initializer=get_initializer(config.initializer_range),
name="decoder",
use_bias=False,
)
self.config = config
def build(self, input_shape):
self.bias = self.add_weight(name="bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)
super().build(input_shape)
def call(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states) + self.bias
return hidden_states
class TFBlipTextOnlyMLMHead(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.predictions = TFBlipTextLMPredictionHead(config, name="predictions")
def call(self, sequence_output: tf.Tensor) -> tf.Tensor:
prediction_scores = self.predictions(sequence_output)
return prediction_scores
# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L548
class TFBlipTextPreTrainedModel(TFPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = BlipTextConfig
base_model_prefix = "bert"
_keys_to_ignore_on_load_missing = [r"position_ids"]
# Adapted from https://github.com/salesforce/BLIP/blob/3a29b7410476bf5f2ba0955827390eb6ea1f4f9d/models/med.py#L571
class TFBlipTextModel(TFBlipTextPreTrainedModel):
"""
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in [Attention is
all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. To behave as a decoder, the model needs to be
initialized with the `is_decoder` argument set to `True`; an `encoder_hidden_states` is then expected as an input
to the forward pass.
"""
def __init__(self, config, add_pooling_layer=True, name=None, **kwargs):
super().__init__(config, name=name, **kwargs)
self.config = config
self.embeddings = TFBlipTextEmbeddings(config, name="embeddings")
self.encoder = TFBlipTextEncoder(config, name="encoder")
self.pooler = TFBlipTextPooler(config, name="pooler") if add_pooling_layer else None
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
}
]
)
def serving(self, inputs: Dict[str, tf.Tensor]) -> TFBaseModelOutputWithPoolingAndCrossAttentions:
output = self.call(inputs)
return self.serving_output(output)
def serving_output(
self, output: TFBaseModelOutputWithPoolingAndCrossAttentions
) -> TFBaseModelOutputWithPoolingAndCrossAttentions:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFBaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=output.last_hidden_state,
pooler_output=output.pooler_output,
hidden_states=hs,
attentions=attns,
)
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
@tf.function
def get_extended_attention_mask(
self, attention_mask: tf.Tensor, input_shape: Tuple[int], is_decoder: bool
) -> tf.Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
Arguments:
attention_mask (`tf.Tensor`):
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
input_shape (`Tuple[int]`):
The shape of the input to the model.
is_decoder (`bool`):
Whether the model is used as a decoder.
Returns:
`tf.Tensor` The extended attention mask, with the same dtype as `attention_mask.dtype`.
"""
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if not isinstance(attention_mask, tf.Tensor):
attention_mask = tf.convert_to_tensor(attention_mask) # Catches NumPy inputs that haven't been cast yet
if attention_mask.shape.rank == 3:
extended_attention_mask = attention_mask[:, None, :, :]
elif attention_mask.shape.rank == 2:
# Provided a padding mask of dimensions [batch_size, seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
if is_decoder:
batch_size, seq_length = input_shape
seq_ids = tf.range(seq_length, dtype=attention_mask.dtype)
causal_mask = tf.broadcast_to(seq_ids, (batch_size, seq_length, seq_length)) <= seq_ids[None, :, None]
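# `causal_mask` is a boolean lower-triangular mask: entry [b, i, j] is True iff j <= i, e.g. for
# seq_length=3 each batch entry is [[T, F, F], [T, T, F], [T, T, T]].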
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
if shape_list(causal_mask)[1] < shape_list(attention_mask)[1]:
prefix_seq_len = tf.shape(attention_mask)[1] - tf.shape(causal_mask)[1]
causal_mask = tf.concat(
[
tf.ones((batch_size, seq_length, prefix_seq_len), dtype=causal_mask.dtype),
causal_mask,
],
axis=-1,
)
extended_attention_mask = (
tf.cast(causal_mask[:, None, :, :], attention_mask.dtype) * attention_mask[:, None, None, :]
)
else:
extended_attention_mask = attention_mask[:, None, None, :]
else:
raise ValueError(
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
input_shape, attention_mask.shape
)
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
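# For example, an attention_mask row of [1, 1, 0] becomes [0.0, 0.0, -10000.0] after the transformation below.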
extended_attention_mask = tf.cast(extended_attention_mask, self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
@add_start_docstrings_to_model_forward(BLIP_TEXT_INPUTS_DOCSTRING)
@unpack_inputs
def call(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
is_decoder=False,
training=None,
):
r"""
encoder_hidden_states (`tf.Tensor`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (`tf.Tensor`, *optional*):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (`tuple(tuple(tf.Tensor))`, *optional*):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
batch_size, seq_length = input_shape
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
batch_size, seq_length = input_shape
elif encoder_embeds is not None:
input_shape = shape_list(encoder_embeds)[:-1]
batch_size, seq_length = input_shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
attention_mask = tf.ones((batch_size, seq_length + past_key_values_length))
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: tf.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, is_decoder)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if encoder_hidden_states is not None:
if isinstance(encoder_hidden_states, list):
encoder_batch_size, encoder_sequence_length, _ = shape_list(encoder_hidden_states[0])
else:
encoder_batch_size, encoder_sequence_length, _ = shape_list(encoder_hidden_states)
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if isinstance(encoder_attention_mask, list):
encoder_extended_attention_mask = [invert_attention_mask(mask) for mask in encoder_attention_mask]
elif encoder_attention_mask is None:
encoder_attention_mask = tf.ones(encoder_hidden_shape)
encoder_extended_attention_mask = invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if encoder_embeds is None:
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
else:
embedding_output = encoder_embeds
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return TFBaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L811
class TFBlipTextLMHeadModel(TFBlipTextPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
self.bert = TFBlipTextModel(config, add_pooling_layer=False, name="bert")
self.cls = TFBlipTextOnlyMLMHead(config, name="cls")
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
@property
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
"""
Dummy inputs to build the network.
Returns:
`Dict[str, tf.Tensor]`: The dummy inputs.
"""
return {"input_ids": tf.constant(DUMMY_INPUTS, dtype=tf.int32)}
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
}
]
)
def serving(self, inputs: Dict[str, tf.Tensor]) -> TFCausalLMOutputWithCrossAttentions:
"""
Method used for serving the model.
Args:
inputs (`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
"""
output = self.call(inputs)
return self.serving_output(output)
def serving_output(self, output: TFCausalLMOutputWithCrossAttentions) -> TFCausalLMOutputWithCrossAttentions:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFCausalLMOutputWithCrossAttentions(
logits=output.logits,
cross_attentions=output.cross_attentions,
hidden_states=hs,
attentions=attns,
)
@add_start_docstrings_to_model_forward(BLIP_TEXT_INPUTS_DOCSTRING)
@unpack_inputs
def call(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
return_logits=False,
is_decoder=True,
training=None,
):
r"""
encoder_hidden_states (`tf.Tensor`, *optional*): Sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is
configured as a decoder.
encoder_attention_mask (`tf.Tensor`, *optional*):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
labels (`tf.Tensor`, *optional*):
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
past_key_values (`tuple(tuple(tf.Tensor))`, *optional*):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
use_cache = False
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
is_decoder=is_decoder,
training=training,
)
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
if return_logits:
return prediction_scores[:, :-1, :]
lm_loss = None
if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
shifted_prediction_scores = prediction_scores[:, :-1, :]
shifted_prediction_scores = tf.reshape(shifted_prediction_scores, (-1, self.config.vocab_size))
labels = labels[:, 1:]
labels = tf.reshape(labels, (-1,))
# Keras won't give us label smoothing for sparse CE, so we de-sparsify things here
one_hot_labels = tf.one_hot(labels, depth=self.config.vocab_size, dtype=tf.float32)
loss_fct = tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1, reduction="none")
masked_positions = tf.cast(tf.not_equal(labels, -100), dtype=tf.float32)
lm_loss = loss_fct(one_hot_labels, shifted_prediction_scores)
lm_loss *= masked_positions
lm_loss = tf.reduce_sum(lm_loss, axis=0) / tf.math.count_nonzero(masked_positions, dtype=tf.float32)
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((lm_loss,) + output) if lm_loss is not None else output
return TFCausalLMOutputWithCrossAttentions(
loss=lm_loss,
logits=prediction_scores,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = tf.ones(input_shape)
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
"is_decoder": True,
}
def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
return reordered_past
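For reference, the loss computed above works around the fact that Keras' sparse categorical cross-entropy has no `label_smoothing` argument: labels are de-sparsified to one-hot vectors, smoothed dense cross-entropy is computed per token, and `-100` positions are masked out before averaging. A minimal, self-contained sketch of the same pattern (made-up shapes, not tied to the BLIP code):
import tensorflow as tf
vocab_size = 10
logits = tf.random.normal((4, vocab_size))  # flattened, already-shifted logits: (num_tokens, vocab)
labels = tf.constant([3, 7, -100, 1])  # -100 marks positions that must not contribute to the loss
one_hot_labels = tf.one_hot(labels, depth=vocab_size, dtype=tf.float32)  # -100 becomes an all-zero row
loss_fct = tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1, reduction="none")
per_token_loss = loss_fct(one_hot_labels, logits)  # shape (num_tokens,)
mask = tf.cast(tf.not_equal(labels, -100), dtype=tf.float32)
loss = tf.reduce_sum(per_token_loss * mask) / tf.math.count_nonzero(mask, dtype=tf.float32)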
@@ -453,9 +453,7 @@ class Blip2Encoder(nn.Module):
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
Embedded representation of the inputs. Should be float, not int tokens.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
@@ -68,3 +68,31 @@ def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional
# TODO: When the issue linked above gets sorted, add a check on TF version here and use the original function if
# it has the fix. After we drop the support for unfixed versions, remove this function.
return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)
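As a sanity check on the workaround above: adding a constant offset to every logit leaves the softmax unchanged mathematically, so in eager mode `stable_softmax` matches `tf.nn.softmax` to within float tolerance; the shift only exists to sidestep the XLA issue referenced in the comment. A quick sketch (assuming `stable_softmax` as defined above):
import numpy as np
import tensorflow as tf
logits = tf.constant([[1.0, 2.0, 3.0], [-1e9, -1e9, 0.0]])
np.testing.assert_allclose(stable_softmax(logits, axis=-1).numpy(), tf.nn.softmax(logits, axis=-1).numpy(), atol=1e-6)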
def invert_attention_mask(encoder_attention_mask: tf.Tensor) -> tf.Tensor:
"""
Invert an attention mask (e.g., switches 0. and 1.).
Args:
encoder_attention_mask (`tf.Tensor`): An attention mask.
Returns:
`tf.Tensor`: The inverted attention mask.
"""
if not isinstance(encoder_attention_mask, tf.Tensor):
encoder_attention_mask = tf.convert_to_tensor(encoder_attention_mask) # Catches stray NumPy inputs
if encoder_attention_mask.shape.rank == 3:
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
if encoder_attention_mask.shape.rank == 2:
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
# /transformer/transformer_layers.py#L270
# encoder_extended_attention_mask = (encoder_extended_attention_mask ==
# encoder_extended_attention_mask.transpose(-1, -2))
encoder_extended_attention_mask = (
tf.cast(1, encoder_attention_mask.dtype) - encoder_extended_attention_mask
) * encoder_extended_attention_mask.dtype.min
return encoder_extended_attention_mask
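A quick usage sketch for the helper above: a `(batch, seq_len)` 0/1 padding mask becomes a broadcastable additive mask holding `0` at attended positions and `dtype.min` at padded ones, which is then added to the raw attention scores before the softmax (illustrative shapes only):
import tensorflow as tf
mask = tf.constant([[1.0, 1.0, 1.0, 0.0]])  # (batch, seq_len); 0 marks padding
extended_mask = invert_attention_mask(mask)  # (batch, 1, 1, seq_len), additive
scores = tf.random.normal((1, 4, 4, 4))  # (batch, num_heads, query_len, key_len)
probs = tf.nn.softmax(scores + extended_mask, axis=-1)  # padded keys receive ~0 probability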
@@ -556,6 +556,58 @@ class TFBlenderbotSmallPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["tf"])
TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFBlipForConditionalGeneration(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFBlipForImageTextRetrieval(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFBlipForQuestionAnswering(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFBlipModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFBlipPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFBlipTextModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFBlipVisionModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
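These dummy classes exist so that importing the TF BLIP names in an environment without TensorFlow fails lazily, with a message from `requires_backends`, instead of breaking `import transformers` itself. A hypothetical session on a machine where transformers is installed but TensorFlow is not:
from transformers import TFBlipModel  # succeeds: resolves to the dummy class above
TFBlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
# raises ImportError, telling the user that TFBlipModel requires the TensorFlow backend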
@@ -342,6 +342,9 @@ class BlipTextModelTest(ModelTesterMixin, unittest.TestCase):
model = BlipTextModel.from_pretrained(model_name)
self.assertIsNotNone(model)
def test_pt_tf_model_equivalence(self):
super().test_pt_tf_model_equivalence(allow_missing_keys=True)
class BlipModelTester:
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
@@ -524,6 +527,9 @@ class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
model = BlipModel.from_pretrained(model_name)
self.assertIsNotNone(model)
def test_pt_tf_model_equivalence(self):
super().test_pt_tf_model_equivalence(allow_missing_keys=True)
class BlipTextRetrievalModelTester:
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
@@ -164,3 +164,6 @@ class BlipTextModelTest(ModelTesterMixin, unittest.TestCase):
for model_name in BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = BlipTextModel.from_pretrained(model_name)
self.assertIsNotNone(model)
def test_pt_tf_model_equivalence(self):
super().test_pt_tf_model_equivalence(allow_missing_keys=True)
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow Blip model. """
import inspect
import tempfile
import unittest
import numpy as np
import requests
from transformers import BlipConfig, BlipTextConfig, BlipVisionConfig
from transformers.testing_utils import require_tf, require_vision, slow
from transformers.utils import is_tf_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available():
import tensorflow as tf
from transformers import (
TFBlipForConditionalGeneration,
TFBlipForImageTextRetrieval,
TFBlipForQuestionAnswering,
TFBlipModel,
TFBlipTextModel,
TFBlipVisionModel,
)
from transformers.models.blip.modeling_tf_blip import TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
from PIL import Image
from transformers import BlipProcessor
class TFBlipVisionModelTester:
def __init__(
self,
parent,
batch_size=12,
image_size=30,
patch_size=2,
num_channels=3,
is_training=True,
hidden_size=32,
projection_dim=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
dropout=0.1,
attention_dropout=0.1,
initializer_range=1e-10,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.is_training = is_training
self.hidden_size = hidden_size
self.projection_dim = projection_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.dropout = dropout
self.attention_dropout = attention_dropout
self.initializer_range = initializer_range
self.scope = scope
# in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches + 1
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
config = self.get_config()
return config, pixel_values
def get_config(self):
return BlipVisionConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
hidden_size=self.hidden_size,
projection_dim=self.projection_dim,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
dropout=self.dropout,
attention_dropout=self.attention_dropout,
initializer_range=self.initializer_range,
)
def create_and_check_model(self, config, pixel_values):
model = TFBlipVisionModel(config=config)
result = model(pixel_values)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
image_size = (self.image_size, self.image_size)
patch_size = (self.patch_size, self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_tf
class TFBlipVisionModelTest(TFModelTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_tf_common.py, as Blip does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (TFBlipVisionModel,) if is_tf_available() else ()
fx_compatible = False
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFBlipVisionModelTester(self)
self.config_tester = ConfigTester(self, config_class=BlipVisionConfig, has_text_modality=False, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip(reason="Blip does not use inputs_embeds")
def test_inputs_embeds(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.call)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (tf.keras.layers.Layer))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, tf.keras.layers.Layer))
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="BlipVisionModel has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip(reason="BlipVisionModel has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_to_base(self):
pass
@slow
def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
try:
model = TFBlipVisionModel.from_pretrained(model_name)
except OSError:
model = TFBlipVisionModel.from_pretrained(model_name, from_pt=True)
self.assertIsNotNone(model)
class TFBlipTextModelTester:
def __init__(
self,
parent,
batch_size=12,
seq_length=7,
is_training=True,
use_input_mask=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
projection_dim=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
dropout=0.1,
attention_dropout=0.1,
max_position_embeddings=512,
initializer_range=0.02,
bos_token_id=0,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.projection_dim = projection_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.dropout = dropout
self.attention_dropout = attention_dropout
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.scope = scope
self.bos_token_id = bos_token_id
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
if input_mask is not None:
input_mask = input_mask.numpy()
batch_size, seq_length = input_mask.shape
rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
input_mask[batch_idx, :start_index] = 1
input_mask[batch_idx, start_index:] = 0
input_mask = tf.convert_to_tensor(input_mask)
config = self.get_config()
return config, input_ids, input_mask
def get_config(self):
return BlipTextConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
projection_dim=self.projection_dim,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
dropout=self.dropout,
attention_dropout=self.attention_dropout,
max_position_embeddings=self.max_position_embeddings,
initializer_range=self.initializer_range,
bos_token_id=self.bos_token_id,
)
def create_and_check_model(self, config, input_ids, input_mask):
model = TFBlipTextModel(config=config)
result = model(input_ids, attention_mask=input_mask, training=False)
result = model(input_ids, training=False)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, input_mask = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_tf
class TFBlipTextModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFBlipTextModel,) if is_tf_available() else ()
fx_compatible = False
test_pruning = False
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFBlipTextModelTester(self)
self.config_tester = ConfigTester(self, config_class=BlipTextConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="Blip does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="BlipTextModel has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip(reason="BlipTextModel has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_to_base(self):
pass
@slow
def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
try:
model = TFBlipTextModel.from_pretrained(model_name)
except OSError:
model = TFBlipTextModel.from_pretrained(model_name, from_pt=True)
self.assertIsNotNone(model)
def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
super().test_pt_tf_model_equivalence(allow_missing_keys=allow_missing_keys)
class TFBlipModelTester:
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
if text_kwargs is None:
text_kwargs = {}
if vision_kwargs is None:
vision_kwargs = {}
self.parent = parent
self.text_model_tester = TFBlipTextModelTester(parent, **text_kwargs)
self.vision_model_tester = TFBlipVisionModelTester(parent, **vision_kwargs)
self.is_training = is_training
def prepare_config_and_inputs(self):
text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
config = self.get_config()
return config, input_ids, attention_mask, pixel_values
def get_config(self):
return BlipConfig.from_text_vision_configs(
self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
)
def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
model = TFBlipModel(config)
result = model(input_ids, pixel_values, attention_mask, training=False)
self.parent.assertEqual(
result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
)
self.parent.assertEqual(
result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, attention_mask, pixel_values = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"return_loss": True,
}
return config, inputs_dict
@require_tf
class TFBlipModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (TFBlipModel,) if is_tf_available() else ()
pipeline_model_mapping = (
{"feature-extraction": TFBlipModel, "image-to-text": TFBlipForConditionalGeneration}
if is_tf_available()
else {}
)
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_attention_outputs = False
test_onnx = False
def setUp(self):
self.model_tester = TFBlipModelTester(self)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="Hidden_states is tested in individual model tests")
def test_hidden_states_output(self):
pass
@unittest.skip(reason="Inputs_embeds is tested in individual model tests")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="Retain_grad is tested in individual model tests")
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip(reason="BlipModel does not have input/output embeddings")
def test_model_common_attributes(self):
pass
def test_load_vision_text_config(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# Save BlipConfig and check if we can load BlipVisionConfig from it
with tempfile.TemporaryDirectory() as tmp_dir_name:
config.save_pretrained(tmp_dir_name)
vision_config = BlipVisionConfig.from_pretrained(tmp_dir_name)
self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())
# Save BlipConfig and check if we can load BlipTextConfig from it
with tempfile.TemporaryDirectory() as tmp_dir_name:
config.save_pretrained(tmp_dir_name)
text_config = BlipTextConfig.from_pretrained(tmp_dir_name)
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
@slow
def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFBlipModel.from_pretrained(model_name, from_pt=True)
self.assertIsNotNone(model)
def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
super().test_pt_tf_model_equivalence(allow_missing_keys=allow_missing_keys)
class BlipTextRetrievalModelTester:
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
if text_kwargs is None:
text_kwargs = {}
if vision_kwargs is None:
vision_kwargs = {}
self.parent = parent
self.text_model_tester = TFBlipTextModelTester(parent, **text_kwargs)
self.vision_model_tester = TFBlipVisionModelTester(parent, **vision_kwargs)
self.is_training = is_training
def prepare_config_and_inputs(self):
text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
config = self.get_config()
return config, input_ids, attention_mask, pixel_values
def get_config(self):
return BlipConfig.from_text_vision_configs(
self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
)
def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
model = TFBlipModel(config)
result = model(input_ids, pixel_values, attention_mask, training=False)
self.parent.assertEqual(
result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
)
self.parent.assertEqual(
result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, attention_mask, pixel_values = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
}
return config, inputs_dict
class BlipTextImageModelsModelTester:
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
if text_kwargs is None:
text_kwargs = {}
if vision_kwargs is None:
vision_kwargs = {}
self.parent = parent
self.text_model_tester = TFBlipTextModelTester(parent, **text_kwargs)
self.vision_model_tester = TFBlipVisionModelTester(parent, **vision_kwargs)
self.is_training = is_training
def prepare_config_and_inputs(self):
text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
config = self.get_config()
return config, input_ids, attention_mask, pixel_values
def get_config(self):
return BlipConfig.from_text_vision_configs(
self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
)
def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
model = TFBlipModel(config)
result = model(input_ids, pixel_values, attention_mask, training=False)
self.parent.assertEqual(
result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
)
self.parent.assertEqual(
result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, attention_mask, pixel_values = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"labels": input_ids,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
}
return config, inputs_dict
@require_tf
@require_vision
class BlipVQAModelTest(unittest.TestCase):
all_model_classes = (TFBlipForQuestionAnswering,) if is_tf_available() else ()
def setUp(self):
self.model_tester = TFBlipModelTester(self)
def _prepare_inputs_for_vqa(self):
_, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
inputs_dict["labels"] = inputs_dict["input_ids"]
inputs_dict.pop("return_loss")
return inputs_dict
def test_class_name_consistency(self):
"""
Tests that all VQA models have a class name that ends with "ForQuestionAnswering"
"""
for model_class in self.all_model_classes:
model = model_class(self.model_tester.get_config())
self.assertTrue(
model.__class__.__name__.endswith("ForQuestionAnswering"),
f"Class name should end with 'ForVisualQuestionAnswering' got {model.__class__.__name__}",
)
def test_training(self):
"""
Tests that all VQA models can be trained on a single batch
"""
for model_class in self.all_model_classes:
model = model_class(self.model_tester.get_config())
loss = model(**self._prepare_inputs_for_vqa(), training=True).loss
self.assertIsNotNone(loss, "Loss should not be None")
@require_tf
class TFBlipTextRetrievalModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFBlipForImageTextRetrieval,) if is_tf_available() else ()
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_attention_outputs = False
test_onnx = False
def setUp(self):
self.model_tester = BlipTextRetrievalModelTester(self)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="Hidden_states is tested in individual model tests")
def test_hidden_states_output(self):
pass
@unittest.skip(reason="Inputs_embeds is tested in individual model tests")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="Retain_grad is tested in individual model tests")
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip(reason="BlipModel does not have input/output embeddings")
def test_model_common_attributes(self):
pass
def test_training(self):
if not self.model_tester.is_training:
return
for model_class in self.all_model_classes[:-1]:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
model = model_class(config)
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
# hardcode labels to be the same as input_ids
inputs["labels"] = inputs["input_ids"]
loss = model(**inputs, training=True).loss
self.assertTrue(loss is not None)
def test_load_vision_text_config(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# Save BlipConfig and check if we can load BlipVisionConfig from it
with tempfile.TemporaryDirectory() as tmp_dir_name:
config.save_pretrained(tmp_dir_name)
vision_config = BlipVisionConfig.from_pretrained(tmp_dir_name)
self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())
# Save BlipConfig and check if we can load BlipTextConfig from it
with tempfile.TemporaryDirectory() as tmp_dir_name:
config.save_pretrained(tmp_dir_name)
text_config = BlipTextConfig.from_pretrained(tmp_dir_name)
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
@slow
def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFBlipModel.from_pretrained(model_name, from_pt=True)
self.assertIsNotNone(model)
@unittest.skip(reason="Tested in individual model tests")
def test_compile_tf_model(self):
pass
@unittest.skip("Model doesn't have a clean loss output.")
def test_keras_fit(self):
pass
@require_tf
class TFBlipTextImageModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFBlipForConditionalGeneration, TFBlipForQuestionAnswering) if is_tf_available() else ()
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_attention_outputs = False
test_onnx = False
def setUp(self):
self.model_tester = BlipTextImageModelsModelTester(self)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="Hidden_states is tested in individual model tests")
def test_hidden_states_output(self):
pass
@unittest.skip(reason="Inputs_embeds is tested in individual model tests")
def test_inputs_embeds(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.call)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
if model.config.is_encoder_decoder:
expected_arg_names = [
"input_ids",
"attention_mask",
"decoder_input_ids",
"decoder_attention_mask",
]
expected_arg_names.extend(
["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names
else ["encoder_outputs"]
)
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
else:
expected_arg_names = (
["input_ids"] if model_class != TFBlipForConditionalGeneration else ["pixel_values"]
)
self.assertListEqual(arg_names[:1], expected_arg_names)
@unittest.skip(reason="Tested in individual model tests")
def test_compile_tf_model(self):
pass
@unittest.skip("Has some odd input names!")
def test_keras_fit(self):
pass
@unittest.skip(reason="Retain_grad is tested in individual model tests")
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip(reason="BlipModel does not have input/output embeddings")
def test_model_common_attributes(self):
pass
def test_training(self):
if not self.model_tester.is_training:
return
for model_class in self.all_model_classes[:-1]:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
model = model_class(config)
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
# hardcode labels to be the same as input_ids
inputs["labels"] = inputs["input_ids"]
loss = model(**inputs, training=True).loss
self.assertIsNotNone(loss)
def test_load_vision_text_config(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# Save BlipConfig and check if we can load BlipVisionConfig from it
with tempfile.TemporaryDirectory() as tmp_dir_name:
config.save_pretrained(tmp_dir_name)
vision_config = BlipVisionConfig.from_pretrained(tmp_dir_name)
self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())
# Save BlipConfig and check if we can load BlipTextConfig from it
with tempfile.TemporaryDirectory() as tmp_dir_name:
config.save_pretrained(tmp_dir_name)
text_config = BlipTextConfig.from_pretrained(tmp_dir_name)
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
@slow
def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
try:
model = TFBlipModel.from_pretrained(model_name)
except OSError:
model = TFBlipModel.from_pretrained(model_name, from_pt=True)
self.assertIsNotNone(model)
# We will verify our results on an image of cute cats
def prepare_img():
url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@require_vision
@require_tf
@slow
class TFBlipModelIntegrationTest(unittest.TestCase):
def test_inference_image_captioning(self):
model = TFBlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", from_pt=True)
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
image = prepare_img()
# image only
inputs = processor(images=image, return_tensors="tf")
predictions = model.generate(**inputs)
# Test output
self.assertEqual(
predictions[0].numpy().tolist(), [30522, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]
)
# image and context
context = ["a picture of"]
inputs = processor(images=image, text=context, return_tensors="tf")
predictions = model.generate(**inputs)
# Test output
self.assertEqual(
predictions[0].numpy().tolist(),
[30522, 1037, 3861, 1997, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102],
)
def test_inference_vqa(self):
model = TFBlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base", from_pt=True)
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
image = prepare_img()
text = "how many dogs are in the picture?"
inputs = processor(image, text=text, return_tensors="tf")
out = model.generate(**inputs)
# Test output
self.assertEqual(out[0].numpy().tolist(), [30522, 1015, 102])
def test_inference_itm(self):
model = TFBlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco", from_pt=True)
processor = BlipProcessor.from_pretrained("Salesforce/blip-itm-base-coco")
image = prepare_img()
text = "A woman and her dog sitting in a beach"
inputs = processor(image, text, return_tensors="tf")
out_itm = model(**inputs)
out = model(**inputs, use_itm_head=False, training=False)
expected_scores = tf.convert_to_tensor([[0.9798, 0.0202]])
self.assertTrue(np.allclose(tf.nn.softmax(out_itm[0]).numpy(), expected_scores, rtol=1e-3, atol=1e-3))
self.assertTrue(np.allclose(out[0], tf.convert_to_tensor([[0.5053]]), rtol=1e-3, atol=1e-3))
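For readability, the token ids asserted in `test_inference_image_captioning` above can be mapped back to text with the same processor; a small illustrative snippet reusing `processor` and `predictions` from that test:
caption = processor.decode(predictions[0], skip_special_tokens=True)
print(caption)  # expected to read roughly "a woman sitting on the beach with her dog"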
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow Blip model. """
import unittest
import numpy as np
from transformers import BlipTextConfig
from transformers.testing_utils import require_tf, slow
from transformers.utils import is_tf_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
import tensorflow as tf
from transformers import TFBlipTextModel
from transformers.models.blip.modeling_tf_blip import TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST
class BlipTextModelTester:
def __init__(
self,
parent,
batch_size=12,
seq_length=7,
is_training=True,
use_input_mask=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
projection_dim=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
dropout=0.1,
attention_dropout=0.1,
max_position_embeddings=512,
initializer_range=0.02,
bos_token_id=0,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.projection_dim = projection_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.dropout = dropout
self.attention_dropout = attention_dropout
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.scope = scope
self.bos_token_id = bos_token_id
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
if input_mask is not None:
input_mask = input_mask.numpy()
batch_size, seq_length = input_mask.shape
rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
input_mask[batch_idx, :start_index] = 1
input_mask[batch_idx, start_index:] = 0
config = self.get_config()
return config, input_ids, tf.convert_to_tensor(input_mask)
def get_config(self):
return BlipTextConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
projection_dim=self.projection_dim,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
dropout=self.dropout,
attention_dropout=self.attention_dropout,
max_position_embeddings=self.max_position_embeddings,
initializer_range=self.initializer_range,
bos_token_id=self.bos_token_id,
)
def create_and_check_model(self, config, input_ids, input_mask):
model = TFBlipTextModel(config=config)
result = model(input_ids, attention_mask=input_mask, training=False)
result = model(input_ids, training=False)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, input_mask = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_tf
class BlipTextModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFBlipTextModel,) if is_tf_available() else ()
test_onnx = False
test_pruning = False
test_head_masking = False
def setUp(self):
self.model_tester = BlipTextModelTester(self)
self.config_tester = ConfigTester(self, config_class=BlipTextConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_training(self):
pass
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="Blip does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="BlipTextModel has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip(reason="BlipTextModel has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_to_base(self):
pass
@slow
def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
try:
model = TFBlipTextModel.from_pretrained(model_name)
except OSError:
model = TFBlipTextModel.from_pretrained(model_name, from_pt=True)
self.assertIsNotNone(model)
def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
super().test_pt_tf_model_equivalence(allow_missing_keys=allow_missing_keys)
@@ -1984,7 +1984,7 @@ class ModelTesterMixin:
self.check_pt_tf_outputs(tf_outputs, pt_outputs, type(pt_model))
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self):
def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
import transformers
for model_class in self.all_model_classes:
@@ -2036,8 +2036,12 @@ class ModelTesterMixin:
# Check we can load pt model in tf and vice-versa with model => model functions
# Here requires `tf_inputs_dict` to build `tf_model`
tf_inputs_dict = self.prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
tf_model = transformers.load_pytorch_model_in_tf2_model(
tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
)
pt_model = transformers.load_tf2_model_in_pytorch_model(
pt_model, tf_model, allow_missing_keys=allow_missing_keys
)
# Original test: check without `labels`
self.check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
@@ -2049,11 +2053,15 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
torch.save(pt_model.state_dict(), pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
)
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
tf_model.save_weights(tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
)
# Original test: check without `labels`
self.check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
@@ -668,7 +668,7 @@ class TFModelTesterMixin:
self.check_pt_tf_outputs(tf_outputs, pt_outputs, type(tf_model))
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self):
def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
import transformers
for model_class in self.all_model_classes:
@@ -703,8 +703,12 @@ class TFModelTesterMixin:
tf_inputs_dict_with_labels = None
# Check we can load pt model in tf and vice-versa with model => model functions
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
tf_model = transformers.load_pytorch_model_in_tf2_model(
tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
)
pt_model = transformers.load_tf2_model_in_pytorch_model(
pt_model, tf_model, allow_missing_keys=allow_missing_keys
)
# Original test: check without `labels`
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
@@ -716,11 +720,15 @@ class TFModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
torch.save(pt_model.state_dict(), pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
)
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
tf_model.save_weights(tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
)
# Original test: check without `labels`
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
@@ -791,7 +799,7 @@ class TFModelTesterMixin:
name="pixel_values",
dtype="float32",
)
elif model_class.__name__ in ["TFCLIPModel", "TFGroupViTModel"]:
elif model_class.__name__ in ["TFCLIPModel", "TFGroupViTModel", "TFBlipModel"]:
inputs = {
"input_ids": tf.keras.Input(batch_shape=(3, max_input), name="input_ids", dtype="int32"),
"pixel_values": tf.keras.Input(
@@ -1792,6 +1800,8 @@ class TFModelTesterMixin:
for model_class in self.all_model_classes:
model = model_class(config)
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=False)
if "labels" in tf_inputs_dict:
return  # Skip models that require labels in their forward pass; they can't be handled by this test
tf_inputs_dict = {
key: val
for key, val in tf_inputs_dict.items()
@@ -1805,7 +1815,7 @@ class TFModelTesterMixin:
test_batch = next(iter(tf_dataset))
if isinstance(test_batch, tf.Tensor):
self.assertEqual(len(test_batch), len(input_dataset)) # Assert we didn't lose any data
else:
elif isinstance(test_batch, dict):
# Assert we discarded the unwanted extra column but kept everything else
self.assertEqual(len(test_batch), len(input_dataset.features) - 1)
self.assertNotIn("extra_unwanted_column", test_batch)
@@ -145,6 +145,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
"TFSegformerDecodeHead", # Not a regular model.
"AltRobertaModel", # Building part of bigger (tested) model.
"BlipTextLMHeadModel", # No need to test it as it is tested by BlipTextVision models
"TFBlipTextLMHeadModel", # No need to test it as it is tested by BlipTextVision models
"BridgeTowerTextModel", # No need to test it as it is tested by BridgeTowerModel model.
"BridgeTowerVisionModel", # No need to test it as it is tested by BridgeTowerModel model.
"SpeechT5Decoder", # Building part of bigger (tested) model.
@@ -205,6 +206,12 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"BlipVisionModel",
"BlipTextLMHeadModel",
"BlipTextModel",
"TFBlipForConditionalGeneration",
"TFBlipForImageTextRetrieval",
"TFBlipForQuestionAnswering",
"TFBlipVisionModel",
"TFBlipTextLMHeadModel",
"TFBlipTextModel",
"Swin2SRForImageSuperResolution",
"BridgeTowerForImageAndTextRetrieval",
"BridgeTowerForMaskedLM",