Unverified commit 5f3ea66b authored by Matt, committed by GitHub

Add TF port of BLIP (#22090)



* Initial commit

* more stash commit

* Yet another stash commit

* yet more stash commit

* Mostly working except for docs / repo consistency

* Stop importing model list from torch file

* Add TF BLIP models to docs

* Add auto classes

* Move get_text_features and get_image_features

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip_text.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/blip/test_modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/blip/test_modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/models/blip/test_modeling_tf_blip_text.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip_text.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Use channels_last convolutions in TF (better performance + compatibility)

* Remove _shape function

* Move multi-line statement to one line in PT + TF

* Specify tf.keras.layers instead of importing from it

* Remove test_gradient_checkpointing and empty test_training methods

* move some multi-line statements to one line

* Update docstring for generate

* Remove pruned heads set

* Remove self.seq_len_dim

* Fixed issues with loss computation, should resolve some tests. Also ensured that the PT version follows the config for output_attentions and output_hidden_states

* ensure original model follows config in more cases

* Skip the same cross-attention tests in the PT tests - didn't realize we did it twice!

* Add training args throughout the models and layers

* make fixup

* Fix docstring for inputs_embeds

* Add docstring for is_decoder

* Add docstrings to text models

* Remove redundant computation

* Add unpack_inputs / keras_serializable

* Add modeling_tf_blip to doctests

* Add config classes for keras serialization

* Changes to allow model porting with pt-to-tf

* Quick fix to decoder head and test tweaks

* Revert an issue with masking the embeddings outputs

* Allow missing keys in some equivalence tests (for unused layers)

* Add tf-pt equivalence tests back in

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip_text.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip_text.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* make fixup

* Refactor invert_attention_mask out into tf_utils

* Re-enable cross-tests on the PT side too

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent a515d0a7
@@ -269,7 +269,7 @@ Flax), PyTorch, and/or TensorFlow.
| BiT | ❌ | ❌ | ✅ | ❌ | ❌ |
| Blenderbot | ✅ | ✅ | ✅ | ✅ | ✅ |
| BlenderbotSmall | ✅ | ✅ | ✅ | ✅ | ✅ |
| BLIP | ❌ | ❌ | ✅ | ✅ | ❌ |
| BLIP-2 | ❌ | ❌ | ✅ | ❌ | ❌ |
| BLOOM | ❌ | ✅ | ✅ | ❌ | ❌ |
| BridgeTower | ❌ | ❌ | ✅ | ❌ | ❌ |
...
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
@@ -93,4 +93,40 @@ The original code can be found [here](https://github.com/salesforce/BLIP).
## BlipForQuestionAnswering
[[autodoc]] BlipForQuestionAnswering
    - forward
\ No newline at end of file
## TFBlipModel

[[autodoc]] TFBlipModel
    - call
    - get_text_features
    - get_image_features

## TFBlipTextModel

[[autodoc]] TFBlipTextModel
    - call

## TFBlipVisionModel

[[autodoc]] TFBlipVisionModel
    - call

## TFBlipForConditionalGeneration

[[autodoc]] TFBlipForConditionalGeneration
    - call

## TFBlipForImageTextRetrieval

[[autodoc]] TFBlipForImageTextRetrieval
    - call

## TFBlipForQuestionAnswering

[[autodoc]] TFBlipForQuestionAnswering
    - call
\ No newline at end of file
@@ -2903,6 +2903,18 @@ else:
    _import_structure["models.blenderbot_small"].extend(
        ["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel", "TFBlenderbotSmallPreTrainedModel"]
    )
    _import_structure["models.blip"].extend(
        [
            "TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
            "TFBlipForConditionalGeneration",
            "TFBlipForImageTextRetrieval",
            "TFBlipForQuestionAnswering",
            "TFBlipModel",
            "TFBlipPreTrainedModel",
            "TFBlipTextModel",
            "TFBlipVisionModel",
        ]
    )
    _import_structure["models.camembert"].extend(
        [
            "TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -6145,6 +6157,16 @@ if TYPE_CHECKING:
        TFBlenderbotSmallModel,
        TFBlenderbotSmallPreTrainedModel,
    )
    from .models.blip import (
        TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFBlipForConditionalGeneration,
        TFBlipForImageTextRetrieval,
        TFBlipForQuestionAnswering,
        TFBlipModel,
        TFBlipPreTrainedModel,
        TFBlipTextModel,
        TFBlipVisionModel,
    )
    from .models.camembert import (
        TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFCamembertForCausalLM,
...
@@ -196,7 +196,7 @@ class PTtoTFCommand(BaseTransformersCLICommand):
        self._extra_commit_description = extra_commit_description
        self._override_model_class = override_model_class

    def get_inputs(self, pt_model, tf_dummy_inputs, config):
        """
        Returns the right inputs for the model, based on its signature.
        """
@@ -255,7 +255,11 @@ class PTtoTFCommand(BaseTransformersCLICommand):
        tf_input = processor(**processor_inputs, return_tensors="tf")

        # Extra input requirements, in addition to the input modality
        if (
            config.is_encoder_decoder
            or (hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder"))
            or "decoder_input_ids" in tf_dummy_inputs
        ):
            decoder_input_ids = np.asarray([[1], [1]], dtype=int) * (pt_model.config.decoder_start_token_id or 0)
            pt_input.update({"decoder_input_ids": torch.tensor(decoder_input_ids)})
            tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})
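The dummy `decoder_input_ids` above are just two one-token sequences filled with the decoder start token. A minimal NumPy sketch, using a hypothetical `decoder_start_token_id` of 2 (the CLI reads the real value from `pt_model.config`):

```python
import numpy as np

# Hypothetical start-token id, for illustration only
decoder_start_token_id = 2

# Two batch entries, one token each, all set to the start token (falls back to 0 if unset)
decoder_input_ids = np.asarray([[1], [1]], dtype=int) * (decoder_start_token_id or 0)
print(decoder_input_ids.tolist())  # [[2], [2]]
```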
@@ -306,18 +310,24 @@ class PTtoTFCommand(BaseTransformersCLICommand):
        except AttributeError:
            raise AttributeError(f"The TensorFlow equivalent of {architectures[0]} doesn't exist in transformers.")

        # Check the TF dummy inputs to see what keys we need in the forward pass
        tf_from_pt_model = tf_class.from_config(config)
        tf_dummy_inputs = tf_from_pt_model.dummy_inputs

        del tf_from_pt_model  # Try to keep only one model in memory at a time

        # Load the model and get some basic inputs
        pt_model = pt_class.from_pretrained(self._local_dir)
        pt_model.eval()

        pt_input, tf_input = self.get_inputs(pt_model, tf_dummy_inputs, config)
        with torch.no_grad():
            pt_outputs = pt_model(**pt_input, output_hidden_states=True)
        del pt_model  # will no longer be used, and may have a large memory footprint

        tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
        tf_from_pt_outputs = tf_from_pt_model(**tf_input, output_hidden_states=True, training=False)

        # Confirms that cross loading PT weights into TF worked.
        crossload_differences = self.find_pt_tf_differences(pt_outputs, tf_from_pt_outputs)
...
@@ -406,6 +406,7 @@ def unpack_inputs(func):
        func (`callable`):
            The callable function of the TensorFlow model.

    Returns:
        A callable that wraps the original `func` with the behavior described above.
    """
@@ -1157,6 +1158,38 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
        """
        return cls(config, **kwargs)

    def get_head_mask(self, head_mask: Optional[tf.Tensor], num_hidden_layers: int) -> tf.Tensor:
        """
        Prepare the head mask if needed.

        Args:
            head_mask (`tf.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
                The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
            num_hidden_layers (`int`):
                The number of hidden layers in the model.

        Returns:
            `tf.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with
            `[None]` for each layer.
        """
        if head_mask is not None:
            head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
        else:
            head_mask = [None] * num_hidden_layers

        return head_mask

    def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
        """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
        if head_mask.shape.rank == 1:
            head_mask = head_mask[None, None, :, None, None]
            head_mask = tf.repeat(head_mask, repeats=num_hidden_layers, axis=0)
        elif head_mask.shape.rank == 2:
            head_mask = head_mask[:, None, :, None, None]
        assert head_mask.shape.rank == 5, f"head_mask rank != 5, instead {head_mask.shape.rank}"
        head_mask = tf.cast(head_mask, tf.float32)  # switch to float if needed + fp16 compatibility
        return head_mask
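The 5D broadcast performed by `_convert_head_mask_to_5d` can be traced with NumPy (toy sizes chosen for illustration; `tf.repeat` behaves like `np.repeat` here):

```python
import numpy as np

num_hidden_layers, num_heads = 3, 4
head_mask = np.ones(num_heads, dtype=np.float32)  # rank 1: one keep/discard flag per head

# rank 1 -> [1, 1, num_heads, 1, 1], then repeated once per layer
mask_5d = head_mask[None, None, :, None, None]
mask_5d = np.repeat(mask_5d, repeats=num_hidden_layers, axis=0)
print(mask_5d.shape)  # (3, 1, 4, 1, 1): broadcasts over batch and both seq_length axes
```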
    def eager_serving(self, inputs):
        """
        Method used for serving the model. Intended not to be compiled with a tf.function decorator so that we can use
...
@@ -34,6 +34,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
        ("bert", "TFBertModel"),
        ("blenderbot", "TFBlenderbotModel"),
        ("blenderbot-small", "TFBlenderbotSmallModel"),
        ("blip", "TFBlipModel"),
        ("camembert", "TFCamembertModel"),
        ("clip", "TFCLIPModel"),
        ("convbert", "TFConvBertModel"),
@@ -213,6 +214,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Zero Shot Image Classification mapping
        ("blip", "TFBlipModel"),
        ("clip", "TFCLIPModel"),
    ]
)
...
@@ -13,7 +13,13 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_tf_available,
    is_torch_available,
    is_vision_available,
)

_import_structure = {
@@ -52,6 +58,23 @@ else:
        "BlipForImageTextRetrieval",
    ]
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_tf_blip"] = [
        "TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
        "TFBlipModel",
        "TFBlipPreTrainedModel",
        "TFBlipForConditionalGeneration",
        "TFBlipForQuestionAnswering",
        "TFBlipVisionModel",
        "TFBlipTextModel",
        "TFBlipForImageTextRetrieval",
    ]
if TYPE_CHECKING:
    from .configuration_blip import BLIP_PRETRAINED_CONFIG_ARCHIVE_MAP, BlipConfig, BlipTextConfig, BlipVisionConfig
    from .processing_blip import BlipProcessor
@@ -81,6 +104,23 @@ if TYPE_CHECKING:
            BlipVisionModel,
        )
    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_tf_blip import (
            TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFBlipForConditionalGeneration,
            TFBlipForImageTextRetrieval,
            TFBlipForQuestionAnswering,
            TFBlipModel,
            TFBlipPreTrainedModel,
            TFBlipTextModel,
            TFBlipVisionModel,
        )
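The guarded-import pattern used in both halves of this file can be reduced to a small self-contained sketch (the exception class below is a stand-in for the one transformers defines in `utils`, and the exported names are illustrative):

```python
# Stand-in for transformers.utils.OptionalDependencyNotAvailable
class OptionalDependencyNotAvailable(Exception):
    pass


def is_tf_available():
    try:
        import tensorflow  # noqa: F401

        return True
    except ImportError:
        return False


try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    tf_objects = []  # dummy objects get exported instead
else:
    tf_objects = ["TFBlipModel"]  # the real TF classes get exported
```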
else:
    import sys
...
@@ -313,17 +313,12 @@ class BlipAttention(nn.Module):
        bsz, tgt_len, embed_dim = hidden_states.size()

        mixed_qkv = (
            self.qkv(hidden_states)
            .reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
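The fused-QKV split in the chained call above can be traced shape-by-shape with NumPy (toy sizes; a plain matmul stands in for the `self.qkv` linear layer):

```python
import numpy as np

bsz, tgt_len, embed_dim, num_heads = 2, 5, 8, 2
head_dim = embed_dim // num_heads

hidden_states = np.ones((bsz, tgt_len, embed_dim), dtype=np.float32)
qkv_weight = np.ones((embed_dim, 3 * embed_dim), dtype=np.float32)  # stands in for self.qkv

# (bsz, seq, 3*embed) -> (bsz, seq, 3, heads, head_dim) -> (3, bsz, heads, seq, head_dim)
mixed_qkv = (hidden_states @ qkv_weight).reshape(bsz, tgt_len, 3, num_heads, head_dim)
mixed_qkv = mixed_qkv.transpose(2, 0, 3, 1, 4)  # NumPy transpose plays the role of torch permute
query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
print(query_states.shape)  # (2, 2, 5, 4): (bsz, heads, seq, head_dim)
```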
@@ -587,9 +582,7 @@ class BlipEncoder(nn.Module):
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded representation of the inputs. Should be float, not int tokens.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
@@ -824,10 +817,7 @@ class BlipModel(BlipPreTrainedModel):
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(pixel_values=pixel_values, return_dict=return_dict)

        pooled_output = vision_outputs[1]  # pooled_output
        image_features = self.visual_projection(pooled_output)
@@ -993,6 +983,10 @@ class BlipForConditionalGeneration(BlipPreTrainedModel):
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
@@ -1037,7 +1031,7 @@ class BlipForConditionalGeneration(BlipPreTrainedModel):
        Overrides *generate* function to be able to use the model as a conditional generator

        Parameters:
            pixel_values (*torch.FloatTensor* of shape *(batch_size, num_channels, image_height, image_width)*):
                Input image to be processed
            input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
                The sequence used as a prompt for the generation.
@@ -1066,9 +1060,7 @@ class BlipForConditionalGeneration(BlipPreTrainedModel):
        """
        batch_size = pixel_values.shape[0]

        vision_outputs = self.vision_model(pixel_values=pixel_values)

        image_embeds = vision_outputs[0]
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
vision_outputs = self.vision_model( vision_outputs = self.vision_model(
pixel_values=pixel_values, pixel_values=pixel_values,
...@@ -1266,7 +1262,7 @@ class BlipForQuestionAnswering(BlipPreTrainedModel): ...@@ -1266,7 +1262,7 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
        Parameters:
            input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*):
                The sequence used as a prompt for the generation.
            pixel_values (*torch.FloatTensor* of shape *(batch_size, num_channels, image_height, image_width)*):
                Input image to be processed
            attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`. `1` for
@@ -1295,9 +1291,7 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
        2
        ```
        """
        vision_outputs = self.vision_model(pixel_values=pixel_values)

        image_embeds = vision_outputs[0]
@@ -1412,6 +1406,10 @@ class BlipForImageTextRetrieval(BlipPreTrainedModel):
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
...
...
@@ -453,9 +453,7 @@ class Blip2Encoder(nn.Module):
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded representation of the inputs. Should be float, not int tokens.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
...
@@ -68,3 +68,31 @@ def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional
    # TODO: When the issue linked above gets sorted, add a check on TF version here and use the original function if
    # it has the fix. After we drop the support for unfixed versions, remove this function.
    return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)
def invert_attention_mask(encoder_attention_mask: tf.Tensor) -> tf.Tensor:
    """
    Invert an attention mask (e.g., switches 0. and 1.).

    Args:
        encoder_attention_mask (`tf.Tensor`): An attention mask.

    Returns:
        `tf.Tensor`: The inverted attention mask.
    """
    if not isinstance(encoder_attention_mask, tf.Tensor):
        encoder_attention_mask = tf.convert_to_tensor(encoder_attention_mask)  # Catches stray NumPy inputs
    if encoder_attention_mask.shape.rank == 3:
        encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
    if encoder_attention_mask.shape.rank == 2:
        encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
    # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
    # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
    # /transformer/transformer_layers.py#L270
    # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
    # encoder_extended_attention_mask.transpose(-1, -2))
    encoder_extended_attention_mask = (
        tf.cast(1, encoder_attention_mask.dtype) - encoder_extended_attention_mask
    ) * encoder_extended_attention_mask.dtype.min

    return encoder_extended_attention_mask
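The arithmetic at the end is the standard additive-mask trick: attended positions become 0 and masked positions become the dtype's most negative value, so they vanish after softmax. A NumPy illustration of the rank-2 branch:

```python
import numpy as np

mask = np.array([[1.0, 1.0, 0.0]], dtype=np.float32)  # 1 = attend, 0 = padding
extended = mask[:, None, None, :]  # (batch, 1, 1, seq), as in the rank-2 branch above
inverted = (1.0 - extended) * np.finfo(np.float32).min
print(inverted[0, 0, 0])  # [0, 0, ~-3.4e38]; added to attention logits before softmax
```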
@@ -556,6 +556,58 @@ class TFBlenderbotSmallPreTrainedModel(metaclass=DummyObject):
        requires_backends(self, ["tf"])
TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST = None


class TFBlipForConditionalGeneration(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])


class TFBlipForImageTextRetrieval(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])


class TFBlipForQuestionAnswering(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])


class TFBlipModel(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])


class TFBlipPreTrainedModel(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])


class TFBlipTextModel(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])


class TFBlipVisionModel(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])
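These placeholders rely on the `DummyObject` metaclass so that touching a TF class without TensorFlow installed fails with a clear message. A toy sketch of the idea (the real implementation lives in `transformers.utils` and routes through `requires_backends`; only class-level attribute access is covered here, and `FakeTFModel` is a hypothetical class):

```python
# Toy stand-in for transformers' DummyObject metaclass
class DummyObject(type):
    def __getattribute__(cls, key):
        if key.startswith("_"):
            return super().__getattribute__(key)
        raise ImportError(f"{cls.__name__} requires the TensorFlow backend to be installed.")


class FakeTFModel(metaclass=DummyObject):  # hypothetical class for illustration
    _backends = ["tf"]


try:
    FakeTFModel.from_pretrained  # any public attribute access fails fast
except ImportError as err:
    print(err)  # FakeTFModel requires the TensorFlow backend to be installed.
```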
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
...
@@ -342,6 +342,9 @@ class BlipTextModelTest(ModelTesterMixin, unittest.TestCase):
            model = BlipTextModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

    def test_pt_tf_model_equivalence(self):
        super().test_pt_tf_model_equivalence(allow_missing_keys=True)


class BlipModelTester:
    def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
@@ -524,6 +527,9 @@ class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
            model = BlipModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

    def test_pt_tf_model_equivalence(self):
        super().test_pt_tf_model_equivalence(allow_missing_keys=True)


class BlipTextRetrievalModelTester:
    def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
...
@@ -164,3 +164,6 @@ class BlipTextModelTest(ModelTesterMixin, unittest.TestCase):
        for model_name in BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = BlipTextModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

    def test_pt_tf_model_equivalence(self):
        super().test_pt_tf_model_equivalence(allow_missing_keys=True)
...
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow Blip model. """
import unittest
import numpy as np
from transformers import BlipTextConfig
from transformers.testing_utils import require_tf, slow
from transformers.utils import is_tf_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
if is_tf_available():
    import tensorflow as tf

    from transformers import TFBlipTextModel
    from transformers.models.blip.modeling_tf_blip import TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST
class BlipTextModelTester:
def __init__(
self,
parent,
batch_size=12,
seq_length=7,
is_training=True,
use_input_mask=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
projection_dim=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
dropout=0.1,
attention_dropout=0.1,
max_position_embeddings=512,
initializer_range=0.02,
bos_token_id=0,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.projection_dim = projection_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.dropout = dropout
self.attention_dropout = attention_dropout
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.scope = scope
self.bos_token_id = bos_token_id
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
if input_mask is not None:
input_mask = input_mask.numpy()
batch_size, seq_length = input_mask.shape
rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
input_mask[batch_idx, :start_index] = 1
input_mask[batch_idx, start_index:] = 0
config = self.get_config()
return config, input_ids, tf.convert_to_tensor(input_mask)
def get_config(self):
return BlipTextConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
projection_dim=self.projection_dim,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
dropout=self.dropout,
attention_dropout=self.attention_dropout,
max_position_embeddings=self.max_position_embeddings,
initializer_range=self.initializer_range,
bos_token_id=self.bos_token_id,
)
def create_and_check_model(self, config, input_ids, input_mask):
model = TFBlipTextModel(config=config)
result = model(input_ids, attention_mask=input_mask, training=False)
result = model(input_ids, training=False)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, input_mask = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_tf
class BlipTextModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFBlipTextModel,) if is_tf_available() else ()
test_onnx = False
test_pruning = False
test_head_masking = False
def setUp(self):
self.model_tester = BlipTextModelTester(self)
self.config_tester = ConfigTester(self, config_class=BlipTextConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_training(self):
pass
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="Blip does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="BlipTextModel has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip(reason="BlipTextModel has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_to_base(self):
pass
@slow
def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
try:
model = TFBlipTextModel.from_pretrained(model_name)
except OSError:
model = TFBlipTextModel.from_pretrained(model_name, from_pt=True)
self.assertIsNotNone(model)
def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
super().test_pt_tf_model_equivalence(allow_missing_keys=allow_missing_keys)
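The mask built in `prepare_config_and_inputs` above keeps the attended positions as a contiguous prefix of each row: a random start index is drawn, everything before it is set to 1 and everything after it to 0. A minimal standalone sketch of the same idea (plain NumPy, no TF dependency; the function name is ours, not part of the test suite):

```python
import numpy as np


def make_prefix_attention_mask(batch_size: int, seq_length: int, seed: int = 0) -> np.ndarray:
    """Build a mask where each row attends to a random-length prefix (1s)
    and masks out the rest (0s), mirroring the tester's construction."""
    rng = np.random.default_rng(seed)
    mask = np.zeros((batch_size, seq_length), dtype=np.int64)
    # Start indices in [1, seq_length - 1), so every row has at least one
    # attended position and at least one masked position.
    starts = rng.integers(1, seq_length - 1, size=(batch_size,))
    for i, start in enumerate(starts):
        mask[i, :start] = 1
    return mask
```

This matches the tester's use of `np.random.randint(1, seq_length - 1, ...)`: both bounds conventions are low-inclusive, high-exclusive.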
@@ -1984,7 +1984,7 @@ class ModelTesterMixin:
        self.check_pt_tf_outputs(tf_outputs, pt_outputs, type(pt_model))

    @is_pt_tf_cross_test
-    def test_pt_tf_model_equivalence(self):
+    def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
        import transformers

        for model_class in self.all_model_classes:
@@ -2036,8 +2036,12 @@ class ModelTesterMixin:
            # Check we can load pt model in tf and vice-versa with model => model functions
            # Here requires `tf_inputs_dict` to build `tf_model`
            tf_inputs_dict = self.prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
-            tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
-            pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
+            tf_model = transformers.load_pytorch_model_in_tf2_model(
+                tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
+            )
+            pt_model = transformers.load_tf2_model_in_pytorch_model(
+                pt_model, tf_model, allow_missing_keys=allow_missing_keys
+            )

            # Original test: check without `labels`
            self.check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
@@ -2049,11 +2053,15 @@ class ModelTesterMixin:
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
                torch.save(pt_model.state_dict(), pt_checkpoint_path)
-                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
+                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
+                    tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )

                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
                tf_model.save_weights(tf_checkpoint_path)
-                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
+                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
+                    pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )

                # Original test: check without `labels`
                self.check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
...
@@ -668,7 +668,7 @@ class TFModelTesterMixin:
        self.check_pt_tf_outputs(tf_outputs, pt_outputs, type(tf_model))

    @is_pt_tf_cross_test
-    def test_pt_tf_model_equivalence(self):
+    def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
        import transformers

        for model_class in self.all_model_classes:
@@ -703,8 +703,12 @@ class TFModelTesterMixin:
                tf_inputs_dict_with_labels = None

            # Check we can load pt model in tf and vice-versa with model => model functions
-            tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
-            pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
+            tf_model = transformers.load_pytorch_model_in_tf2_model(
+                tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
+            )
+            pt_model = transformers.load_tf2_model_in_pytorch_model(
+                pt_model, tf_model, allow_missing_keys=allow_missing_keys
+            )

            # Original test: check without `labels`
            self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
@@ -716,11 +720,15 @@ class TFModelTesterMixin:
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
                torch.save(pt_model.state_dict(), pt_checkpoint_path)
-                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
+                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
+                    tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )

                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
                tf_model.save_weights(tf_checkpoint_path)
-                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
+                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
+                    pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )

                # Original test: check without `labels`
                self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
@@ -791,7 +799,7 @@ class TFModelTesterMixin:
                    name="pixel_values",
                    dtype="float32",
                )
-            elif model_class.__name__ in ["TFCLIPModel", "TFGroupViTModel"]:
+            elif model_class.__name__ in ["TFCLIPModel", "TFGroupViTModel", "TFBlipModel"]:
                inputs = {
                    "input_ids": tf.keras.Input(batch_shape=(3, max_input), name="input_ids", dtype="int32"),
                    "pixel_values": tf.keras.Input(
@@ -1792,6 +1800,8 @@ class TFModelTesterMixin:
        for model_class in self.all_model_classes:
            model = model_class(config)
            tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=False)
+            if "labels" in tf_inputs_dict:
+                return  # This is some kinda funky decoder model that needs labels in its forward pass
            tf_inputs_dict = {
                key: val
                for key, val in tf_inputs_dict.items()
@@ -1805,7 +1815,7 @@ class TFModelTesterMixin:
            test_batch = next(iter(tf_dataset))
            if isinstance(test_batch, tf.Tensor):
                self.assertEqual(len(test_batch), len(input_dataset))  # Assert we didn't lose any data
-            else:
+            elif isinstance(test_batch, dict):
                # Assert we discarded the unwanted extra column but kept everything else
                self.assertEqual(len(test_batch), len(input_dataset.features) - 1)
                self.assertNotIn("extra_unwanted_column", test_batch)
...
@@ -145,6 +145,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
    "TFSegformerDecodeHead",  # Not a regular model.
    "AltRobertaModel",  # Building part of bigger (tested) model.
    "BlipTextLMHeadModel",  # No need to test it as it is tested by BlipTextVision models
+    "TFBlipTextLMHeadModel",  # No need to test it as it is tested by BlipTextVision models
    "BridgeTowerTextModel",  # No need to test it as it is tested by BridgeTowerModel model.
    "BridgeTowerVisionModel",  # No need to test it as it is tested by BridgeTowerModel model.
    "SpeechT5Decoder",  # Building part of bigger (tested) model.
@@ -205,6 +206,12 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
    "BlipVisionModel",
    "BlipTextLMHeadModel",
    "BlipTextModel",
+    "TFBlipForConditionalGeneration",
+    "TFBlipForImageTextRetrieval",
+    "TFBlipForQuestionAnswering",
+    "TFBlipVisionModel",
+    "TFBlipTextLMHeadModel",
+    "TFBlipTextModel",
    "Swin2SRForImageSuperResolution",
    "BridgeTowerForImageAndTextRetrieval",
    "BridgeTowerForMaskedLM",
...