"maint/vscode:/vscode.git/clone" did not exist on "a9d823b812c47a09129ec11fa7293eac3431aecd"
Unverified Commit 5f3ea66b authored by Matt's avatar Matt Committed by GitHub
Browse files

Add TF port of BLIP (#22090)



* Initial commit

* more stash commit

* Yet another stash commit

* yet more stash commit

* Mostly working except for docs / repo consistency

* Stop importing model list from torch file

* Add TF BLIP models to docs

* Add auto classes

* Move get_text_features and get_image_features

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip_text.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/blip/test_modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/blip/test_modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/models/blip/test_modeling_tf_blip_text.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip_text.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Use channels_last convolutions in TF (better performance + compatibility; see the sketch after the commit message)

* Remove _shape function

* Move multi-line statement to one line in PT + TF

* Specify tf.keras.layers instead of importing from it

* Remove test_gradient_checkpointing and empty test_training methods

* move some multi-line statements to one line

* Update docstring for generate

* Remove pruned heads set

* Remove self.seq_len_dim

* Fixed issues with loss computation, should resolve some tests. Also ensured that the PT version follows the config for output_attentions and output_hidden_states

* ensure original model follows config in more cases

* Skip the same cross-attention tests in the PT tests - didn't realize we did it twice!

* Add training args throughout the models and layers

* make fixup

* Fix docstring for inputs_embeds

* Add docstring for is_decoder

* Add docstrings to text models

* Remove redundant computation

* Add unpack_inputs / keras_serializable

* Add modeling_tf_blip to doctests

* Add config classes for keras serialization

* Changes to allow model porting with pt-to-tf

* Quick fix to decoder head and test tweaks

* Revert an issue with masking the embeddings outputs

* Allow missing keys in some equivalence tests (for unused layers)

* Add tf-pt equivalence tests back in

* Update src/transformers/models/blip/modeling_tf_blip.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip_text.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip_text.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* make fixup

* Refactor invert_attention_mask out into tf_utils

* Re-enable cross-tests on the PT side too

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent a515d0a7
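A minimal sketch of the channels_last point from the commit message above - illustrative only, not the committed BLIP code; the layer and sizes are made up:

import tensorflow as tf

# PyTorch image processors emit pixel_values as NCHW (batch, channels, height, width);
# TF convolutions default to NHWC (channels_last), which is also the faster and more
# widely supported layout, so a TF port typically transposes once at the model boundary.
pixel_values = tf.random.uniform((2, 3, 224, 224))            # NCHW, as handed over from PT
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))  # -> NHWC (channels_last)

patch_embedding = tf.keras.layers.Conv2D(
    filters=768, kernel_size=16, strides=16, data_format="channels_last"
)
print(patch_embedding(pixel_values).shape)  # (2, 14, 14, 768)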
......@@ -269,7 +269,7 @@ Flax), PyTorch, and/or TensorFlow.
| BiT | ❌ | ❌ | ✅ | ❌ | ❌ |
| Blenderbot | ✅ | ✅ | ✅ | ✅ | ✅ |
| BlenderbotSmall | ✅ | ✅ | ✅ | ✅ | ✅ |
| BLIP | ❌ | ❌ | ✅ | ❌ | ❌ |
| BLIP | ❌ | ❌ | ✅ | ✅ | ❌ |
| BLIP-2 | ❌ | ❌ | ✅ | ❌ | ❌ |
| BLOOM | ❌ | ✅ | ✅ | ❌ | ❌ |
| BridgeTower | ❌ | ❌ | ✅ | ❌ | ❌ |
......
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
......@@ -93,4 +93,40 @@ The original code can be found [here](https://github.com/salesforce/BLIP).
## BlipForQuestionAnswering
[[autodoc]] BlipForQuestionAnswering
- forward
\ No newline at end of file
- forward
## TFBlipModel
[[autodoc]] TFBlipModel
- call
- get_text_features
- get_image_features
## TFBlipTextModel
[[autodoc]] TFBlipTextModel
- call
## TFBlipVisionModel
[[autodoc]] TFBlipVisionModel
- call
## TFBlipForConditionalGeneration
[[autodoc]] TFBlipForConditionalGeneration
- call
## TFBlipForImageTextRetrieval
[[autodoc]] TFBlipForImageTextRetrieval
- call
## TFBlipForQuestionAnswering
[[autodoc]] TFBlipForQuestionAnswering
- call
\ No newline at end of file
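A short captioning example for the TF classes documented above - a sketch assuming the PT checkpoint Salesforce/blip-image-captioning-base can be cross-loaded into TF:

import requests
from PIL import Image
from transformers import BlipProcessor, TFBlipForConditionalGeneration

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = TFBlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# Any RGB image works; this COCO URL is just an example input.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="tf")
out = model.generate(**inputs)
print(processor.decode(out[0], skip_special_tokens=True))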
......@@ -2903,6 +2903,18 @@ else:
_import_structure["models.blenderbot_small"].extend(
["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel", "TFBlenderbotSmallPreTrainedModel"]
)
_import_structure["models.blip"].extend(
[
"TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFBlipForConditionalGeneration",
"TFBlipForImageTextRetrieval",
"TFBlipForQuestionAnswering",
"TFBlipModel",
"TFBlipPreTrainedModel",
"TFBlipTextModel",
"TFBlipVisionModel",
]
)
_import_structure["models.camembert"].extend(
[
"TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -6145,6 +6157,16 @@ if TYPE_CHECKING:
TFBlenderbotSmallModel,
TFBlenderbotSmallPreTrainedModel,
)
from .models.blip import (
TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
TFBlipForConditionalGeneration,
TFBlipForImageTextRetrieval,
TFBlipForQuestionAnswering,
TFBlipModel,
TFBlipPreTrainedModel,
TFBlipTextModel,
TFBlipVisionModel,
)
from .models.camembert import (
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCamembertForCausalLM,
......
......@@ -196,7 +196,7 @@ class PTtoTFCommand(BaseTransformersCLICommand):
self._extra_commit_description = extra_commit_description
self._override_model_class = override_model_class
def get_inputs(self, pt_model, config):
def get_inputs(self, pt_model, tf_dummy_inputs, config):
"""
Returns the right inputs for the model, based on its signature.
"""
......@@ -255,7 +255,11 @@ class PTtoTFCommand(BaseTransformersCLICommand):
tf_input = processor(**processor_inputs, return_tensors="tf")
# Extra input requirements, in addition to the input modality
if config.is_encoder_decoder or (hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder")):
if (
config.is_encoder_decoder
or (hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder"))
or "decoder_input_ids" in tf_dummy_inputs
):
decoder_input_ids = np.asarray([[1], [1]], dtype=int) * (pt_model.config.decoder_start_token_id or 0)
pt_input.update({"decoder_input_ids": torch.tensor(decoder_input_ids)})
tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})
......@@ -306,18 +310,24 @@ class PTtoTFCommand(BaseTransformersCLICommand):
except AttributeError:
raise AttributeError(f"The TensorFlow equivalent of {architectures[0]} doesn't exist in transformers.")
# Load models and acquire a basic input compatible with the model.
# Check the TF dummy inputs to see what keys we need in the forward pass
tf_from_pt_model = tf_class.from_config(config)
tf_dummy_inputs = tf_from_pt_model.dummy_inputs
del tf_from_pt_model # Try to keep only one model in memory at a time
# Load the model and get some basic inputs
pt_model = pt_class.from_pretrained(self._local_dir)
pt_model.eval()
pt_input, tf_input = self.get_inputs(pt_model, config)
pt_input, tf_input = self.get_inputs(pt_model, tf_dummy_inputs, config)
with torch.no_grad():
pt_outputs = pt_model(**pt_input, output_hidden_states=True)
del pt_model # will no longer be used, and may have a large memory footprint
tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
tf_from_pt_outputs = tf_from_pt_model(**tf_input, output_hidden_states=True)
tf_from_pt_outputs = tf_from_pt_model(**tf_input, output_hidden_states=True, training=False)
# Confirms that cross loading PT weights into TF worked.
crossload_differences = self.find_pt_tf_differences(pt_outputs, tf_from_pt_outputs)
......
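A condensed sketch of the new dummy-inputs probe in pt-to-tf - it builds a TF BLIP model from a default config, which is what from_config does in the hunk above:

from transformers import BlipConfig, TFBlipForConditionalGeneration

# Instantiate from config alone (random weights) and inspect the dummy inputs;
# if "decoder_input_ids" appears here, the CLI now knows to feed decoder ids
# even though BLIP's config does not set is_encoder_decoder.
config = BlipConfig()
tf_model = TFBlipForConditionalGeneration(config)
print(sorted(tf_model.dummy_inputs.keys()))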
......@@ -406,6 +406,7 @@ def unpack_inputs(func):
func (`callable`):
The callable function of the TensorFlow model.
Returns:
A callable that wraps the original `func` with the behavior described above.
"""
......@@ -1157,6 +1158,38 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
"""
return cls(config, **kwargs)
def get_head_mask(self, head_mask: Optional[tf.Tensor], num_hidden_layers: int) -> tf.Tensor:
"""
Prepare the head mask if needed.
Args:
head_mask (`tf.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
num_hidden_layers (`int`):
The number of hidden layers in the model.
Returns:
`tf.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with
`[None]` for each layer.
"""
if head_mask is not None:
head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
else:
head_mask = [None] * num_hidden_layers
return head_mask
def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
"""-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
if head_mask.shape.rank == 1:
head_mask = head_mask[None, None, :, None, None]
head_mask = tf.repeat(head_mask, repeats=num_hidden_layers, axis=0)
elif head_mask.shape.rank == 2:
head_mask = head_mask[:, None, :, None, None]
assert head_mask.shape.rank == 5, f"head_mask.shape.rank != 5, instead {head_mask.shape.rank}"
head_mask = tf.cast(head_mask, tf.float32) # switch to float if need + fp16 compatibility
return head_mask
def eager_serving(self, inputs):
"""
Method used for serving the model. Intended not to be compiled with a tf.function decorator so that we can use
......
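To make the shapes concrete, a small standalone sketch of what _convert_head_mask_to_5d does for a rank-1 mask:

import tensorflow as tf

# Keep heads 0, 2 and 3 of 4, in each of 2 layers. The rank-5 result broadcasts
# against per-layer attention probabilities of shape (batch, num_heads, seq, seq).
head_mask = tf.constant([1.0, 0.0, 1.0, 1.0])    # shape (4,)
mask_5d = head_mask[None, None, :, None, None]   # shape (1, 1, 4, 1, 1)
mask_5d = tf.repeat(mask_5d, repeats=2, axis=0)  # shape (2, 1, 4, 1, 1), one slice per layer
print(mask_5d.shape)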
......@@ -34,6 +34,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("bert", "TFBertModel"),
("blenderbot", "TFBlenderbotModel"),
("blenderbot-small", "TFBlenderbotSmallModel"),
("blip", "TFBlipModel"),
("camembert", "TFCamembertModel"),
("clip", "TFCLIPModel"),
("convbert", "TFConvBertModel"),
......@@ -213,6 +214,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Zero Shot Image Classification mapping
("blip", "TFBlipModel"),
("clip", "TFCLIPModel"),
]
)
......
......@@ -13,7 +13,13 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tf_available,
is_torch_available,
is_vision_available,
)
_import_structure = {
......@@ -52,6 +58,23 @@ else:
"BlipForImageTextRetrieval",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_blip"] = [
"TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFBlipModel",
"TFBlipPreTrainedModel",
"TFBlipForConditionalGeneration",
"TFBlipForQuestionAnswering",
"TFBlipVisionModel",
"TFBlipTextModel",
"TFBlipForImageTextRetrieval",
]
if TYPE_CHECKING:
from .configuration_blip import BLIP_PRETRAINED_CONFIG_ARCHIVE_MAP, BlipConfig, BlipTextConfig, BlipVisionConfig
from .processing_blip import BlipProcessor
......@@ -81,6 +104,23 @@ if TYPE_CHECKING:
BlipVisionModel,
)
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_blip import (
TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
TFBlipForConditionalGeneration,
TFBlipForImageTextRetrieval,
TFBlipForQuestionAnswering,
TFBlipModel,
TFBlipPreTrainedModel,
TFBlipTextModel,
TFBlipVisionModel,
)
else:
import sys
......
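A quick smoke test of the guarded imports above - with TF installed the lazy module resolves the real classes; without it, the dummy objects (see the utils/dummy_tf_objects.py hunk further down) are served and raise on use:

from transformers import TFBlipModel
from transformers.utils import is_tf_available

if is_tf_available():
    print(TFBlipModel.__module__)  # transformers.models.blip.modeling_tf_blip
else:
    try:
        TFBlipModel()  # DummyObject __init__ calls requires_backends(self, ["tf"])
    except ImportError as err:
        print(err)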
......@@ -313,17 +313,12 @@ class BlipAttention(nn.Module):
bsz, tgt_len, embed_dim = hidden_states.size()
mixed_qkv = self.qkv(hidden_states)
mixed_qkv = (
self.qkv(hidden_states)
.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
query_states, key_states, value_states = (
mixed_qkv[0],
mixed_qkv[1],
mixed_qkv[2],
)
query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
......@@ -587,9 +582,7 @@ class BlipEncoder(nn.Module):
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
Embedded representation of the inputs. Should be float, not int tokens.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
......@@ -824,10 +817,7 @@ class BlipModel(BlipPreTrainedModel):
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_outputs = self.vision_model(
pixel_values=pixel_values,
return_dict=return_dict,
)
vision_outputs = self.vision_model(pixel_values=pixel_values, return_dict=return_dict)
pooled_output = vision_outputs[1] # pooled_output
image_features = self.visual_projection(pooled_output)
......@@ -993,6 +983,10 @@ class BlipForConditionalGeneration(BlipPreTrainedModel):
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
......@@ -1037,7 +1031,7 @@ class BlipForConditionalGeneration(BlipPreTrainedModel):
Overrides *generate* function to be able to use the model as a conditional generator
Parameters:
pixel_values (*torch.FloatTensor* of shape *(batch_size, image_width, image_height)*:
pixel_values (*torch.FloatTensor* of shape *(batch_size, num_channels, image_height, image_width)*:
Input image to be processed
input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
The sequence used as a prompt for the generation.
......@@ -1066,9 +1060,7 @@ class BlipForConditionalGeneration(BlipPreTrainedModel):
"""
batch_size = pixel_values.shape[0]
vision_outputs = self.vision_model(
pixel_values=pixel_values,
)
vision_outputs = self.vision_model(pixel_values=pixel_values)
image_embeds = vision_outputs[0]
......@@ -1198,6 +1190,10 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
......@@ -1266,7 +1262,7 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
Parameters:
input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*):
The sequence used as a prompt for the generation.
pixel_values (*torch.FloatTensor* of shape *(batch_size, image_width, image_height)*:
pixel_values (*torch.FloatTensor* of shape *(batch_size, num_channels, image_height, image_width)*:
Input image to be processed
attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`. `1` for
......@@ -1295,9 +1291,7 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
2
```
"""
vision_outputs = self.vision_model(
pixel_values=pixel_values,
)
vision_outputs = self.vision_model(pixel_values=pixel_values)
image_embeds = vision_outputs[0]
......@@ -1412,6 +1406,10 @@ class BlipForImageTextRetrieval(BlipPreTrainedModel):
```
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
......
This diff is collapsed.
This diff is collapsed.
......@@ -453,9 +453,7 @@ class Blip2Encoder(nn.Module):
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
Embedded representation of the inputs. Should be float, not int tokens.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
......
......@@ -68,3 +68,31 @@ def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional
# TODO: When the issue linked above gets sorted, add a check on TF version here and use the original function if
# it has the fix. After we drop the support for unfixed versions, remove this function.
return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)
def invert_attention_mask(encoder_attention_mask: tf.Tensor) -> tf.Tensor:
"""
Invert an attention mask (e.g., switches 0. and 1.).
Args:
encoder_attention_mask (`tf.Tensor`): An attention mask.
Returns:
`tf.Tensor`: The inverted attention mask.
"""
if not isinstance(encoder_attention_mask, tf.Tensor):
encoder_attention_mask = tf.convert_to_tensor(encoder_attention_mask) # Catches stray NumPy inputs
if encoder_attention_mask.shape.rank == 3:
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
if encoder_attention_mask.shape.rank == 2:
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
# /transformer/transformer_layers.py#L270
# encoder_extended_attention_mask = (encoder_extended_attention_mask ==
# encoder_extended_attention_mask.transpose(-1, -2))
encoder_extended_attention_mask = (
tf.cast(1, encoder_attention_mask.dtype) - encoder_extended_attention_mask
) * encoder_extended_attention_mask.dtype.min
return encoder_extended_attention_mask
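A usage sketch for the refactored helper (assuming it is imported from transformers.tf_utils, where stable_softmax above lives):

import tensorflow as tf
from transformers.tf_utils import invert_attention_mask

# 1.0 marks tokens to attend to, 0.0 marks padding. The helper returns an
# additive mask: 0 where attention is allowed, a large negative value where not.
mask = tf.constant([[1.0, 1.0, 0.0, 0.0]])
extended = invert_attention_mask(mask)  # shape (1, 1, 1, 4)
print(extended.numpy())                 # [[[[0. 0. -3.4e38 -3.4e38]]]]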
......@@ -556,6 +556,58 @@ class TFBlenderbotSmallPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["tf"])
TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFBlipForConditionalGeneration(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFBlipForImageTextRetrieval(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFBlipForQuestionAnswering(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFBlipModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFBlipPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFBlipTextModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFBlipVisionModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
......@@ -342,6 +342,9 @@ class BlipTextModelTest(ModelTesterMixin, unittest.TestCase):
model = BlipTextModel.from_pretrained(model_name)
self.assertIsNotNone(model)
def test_pt_tf_model_equivalence(self):
super().test_pt_tf_model_equivalence(allow_missing_keys=True)
class BlipModelTester:
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
......@@ -524,6 +527,9 @@ class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
model = BlipModel.from_pretrained(model_name)
self.assertIsNotNone(model)
def test_pt_tf_model_equivalence(self):
super().test_pt_tf_model_equivalence(allow_missing_keys=True)
class BlipTextRetrievalModelTester:
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
......
......@@ -164,3 +164,6 @@ class BlipTextModelTest(ModelTesterMixin, unittest.TestCase):
for model_name in BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = BlipTextModel.from_pretrained(model_name)
self.assertIsNotNone(model)
def test_pt_tf_model_equivalence(self):
super().test_pt_tf_model_equivalence(allow_missing_keys=True)
This diff is collapsed.
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow Blip model. """
import unittest
import numpy as np
from transformers import BlipTextConfig
from transformers.testing_utils import require_tf, slow
from transformers.utils import is_tf_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
import tensorflow as tf
from transformers import TFBlipTextModel
from transformers.models.blip.modeling_tf_blip import TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST
class BlipTextModelTester:
def __init__(
self,
parent,
batch_size=12,
seq_length=7,
is_training=True,
use_input_mask=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
projection_dim=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
dropout=0.1,
attention_dropout=0.1,
max_position_embeddings=512,
initializer_range=0.02,
bos_token_id=0,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.projection_dim = projection_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.dropout = dropout
self.attention_dropout = attention_dropout
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.scope = scope
self.bos_token_id = bos_token_id
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
if input_mask is not None:
input_mask = input_mask.numpy()
batch_size, seq_length = input_mask.shape
rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
input_mask[batch_idx, :start_index] = 1
input_mask[batch_idx, start_index:] = 0
config = self.get_config()
return config, input_ids, tf.convert_to_tensor(input_mask)
def get_config(self):
return BlipTextConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
projection_dim=self.projection_dim,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
dropout=self.dropout,
attention_dropout=self.attention_dropout,
max_position_embeddings=self.max_position_embeddings,
initializer_range=self.initializer_range,
bos_token_id=self.bos_token_id,
)
def create_and_check_model(self, config, input_ids, input_mask):
model = TFBlipTextModel(config=config)
result = model(input_ids, attention_mask=input_mask, training=False)
result = model(input_ids, training=False)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, input_mask = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_tf
class BlipTextModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFBlipTextModel,) if is_tf_available() else ()
test_onnx = False
test_pruning = False
test_head_masking = False
def setUp(self):
self.model_tester = BlipTextModelTester(self)
self.config_tester = ConfigTester(self, config_class=BlipTextConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_training(self):
pass
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="Blip does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="BlipTextModel has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip(reason="BlipTextModel has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_to_base(self):
pass
@slow
def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
try:
model = TFBlipTextModel.from_pretrained(model_name)
except OSError:
model = TFBlipTextModel.from_pretrained(model_name, from_pt=True)
self.assertIsNotNone(model)
def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
super().test_pt_tf_model_equivalence(allow_missing_keys=allow_missing_keys)
......@@ -1984,7 +1984,7 @@ class ModelTesterMixin:
self.check_pt_tf_outputs(tf_outputs, pt_outputs, type(pt_model))
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self):
def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
import transformers
for model_class in self.all_model_classes:
......@@ -2036,8 +2036,12 @@ class ModelTesterMixin:
# Check we can load pt model in tf and vice-versa with model => model functions
# Here requires `tf_inputs_dict` to build `tf_model`
tf_inputs_dict = self.prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
tf_model = transformers.load_pytorch_model_in_tf2_model(
tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
)
pt_model = transformers.load_tf2_model_in_pytorch_model(
pt_model, tf_model, allow_missing_keys=allow_missing_keys
)
# Original test: check without `labels`
self.check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
......@@ -2049,11 +2053,15 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
torch.save(pt_model.state_dict(), pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
)
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
tf_model.save_weights(tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
)
# Original test: check without `labels`
self.check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
......
......@@ -668,7 +668,7 @@ class TFModelTesterMixin:
self.check_pt_tf_outputs(tf_outputs, pt_outputs, type(tf_model))
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self):
def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
import transformers
for model_class in self.all_model_classes:
......@@ -703,8 +703,12 @@ class TFModelTesterMixin:
tf_inputs_dict_with_labels = None
# Check we can load pt model in tf and vice-versa with model => model functions
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
tf_model = transformers.load_pytorch_model_in_tf2_model(
tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
)
pt_model = transformers.load_tf2_model_in_pytorch_model(
pt_model, tf_model, allow_missing_keys=allow_missing_keys
)
# Original test: check without `labels`
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
......@@ -716,11 +720,15 @@ class TFModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
torch.save(pt_model.state_dict(), pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
)
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
tf_model.save_weights(tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
)
# Original test: check without `labels`
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
......@@ -791,7 +799,7 @@ class TFModelTesterMixin:
name="pixel_values",
dtype="float32",
)
elif model_class.__name__ in ["TFCLIPModel", "TFGroupViTModel"]:
elif model_class.__name__ in ["TFCLIPModel", "TFGroupViTModel", "TFBlipModel"]:
inputs = {
"input_ids": tf.keras.Input(batch_shape=(3, max_input), name="input_ids", dtype="int32"),
"pixel_values": tf.keras.Input(
......@@ -1792,6 +1800,8 @@ class TFModelTesterMixin:
for model_class in self.all_model_classes:
model = model_class(config)
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=False)
if "labels" in tf_inputs_dict:
return # This is some kinda funky decoder model that needs labels in its forward pass
tf_inputs_dict = {
key: val
for key, val in tf_inputs_dict.items()
......@@ -1805,7 +1815,7 @@ class TFModelTesterMixin:
test_batch = next(iter(tf_dataset))
if isinstance(test_batch, tf.Tensor):
self.assertEqual(len(test_batch), len(input_dataset)) # Assert we didn't lose any data
else:
elif isinstance(test_batch, dict):
# Assert we discarded the unwanted extra column but kept everything else
self.assertEqual(len(test_batch), len(input_dataset.features) - 1)
self.assertNotIn("extra_unwanted_column", test_batch)
......
......@@ -145,6 +145,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
"TFSegformerDecodeHead", # Not a regular model.
"AltRobertaModel", # Building part of bigger (tested) model.
"BlipTextLMHeadModel", # No need to test it as it is tested by BlipTextVision models
"TFBlipTextLMHeadModel", # No need to test it as it is tested by BlipTextVision models
"BridgeTowerTextModel", # No need to test it as it is tested by BridgeTowerModel model.
"BridgeTowerVisionModel", # No need to test it as it is tested by BridgeTowerModel model.
"SpeechT5Decoder", # Building part of bigger (tested) model.
......@@ -205,6 +206,12 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"BlipVisionModel",
"BlipTextLMHeadModel",
"BlipTextModel",
"TFBlipForConditionalGeneration",
"TFBlipForImageTextRetrieval",
"TFBlipForQuestionAnswering",
"TFBlipVisionModel",
"TFBlipTextLMHeadModel",
"TFBlipTextModel",
"Swin2SRForImageSuperResolution",
"BridgeTowerForImageAndTextRetrieval",
"BridgeTowerForMaskedLM",
......