"git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "64ee3e57389bce76e22f4da7ae5716553a604ce3"
Unverified Commit 94306352 authored by Alazar, committed by GitHub

Port IDEFICS to tensorflow (#26870)



* Initial commit

* Just a copy of modeling_idefics.py that will be ported to TF

* - Prepend TF to the name of all classes
- Convert pytorch ops to TF (not all operations are converted yet)

* Add TF imports

* Add autotranslated files

* Add TF classes to model_tf_auto.py

* Add the TF classes in model_doc

* include auto-translated code

* Adopted from auto-translated version

* Add a forgotten super().build

* Add test code for TF version.

* Fix indentation and load pytorch weights for now

* Some fixes. Many tests are still failing but some are passing now.

- I have added TODOs for some of the hacks I made to unblock myself
  and I will address them soon
- I have hacked processing_idefics.py locally to support TF temporarily

* Add ALL_LAYERNORM_LAYERS to match pytorch

* Revert "Add ALL_LAYERNORM_LAYERS to match pytorch"

This reverts commit 7e0a35119b4d7a6284d04d8c543fba1b29e573c9 as it
is not needed in the tf implementation.

* Fix freeze_relevant_params()

* Some more fixes

* Fix test_attention_outputs

* Add tf stuff to processing_idefics.py

processing_idefics.py supports both pytorch and tf now.

test_processor_idefics.py for pytorch is passing, so I didn't break anything,
but there are still some issues with TF. I also need to add TF tests to
test_processor_idefics.py.

* Pass return_tensors to image processing code and fix test

* Pass return_tensors to the image processor __init__

* Fix several test cases

- Make input to some of the forward pass of type `TFModelInputType`
- Decorate main layer forward pass with `@unpack_inputs`
- Decorate main layer with `@keras_serializable`
- Pass `inputs` to TFIdeficsModel

* Some more fixes forgotten in last commit

* Fix processing code and vision_tf.py

* Fix perceiver bug

* Import from

* Auto-add build() methods + style pass

* Fix build() errors due to `None` being passed as shape to some layers

* Change name in TFIdeficsForVisionText2Text to attribute in IdeficsForVisionText2Text

* Fix pytorch weights load for tf2

Many `name=` arguments were missing from the weight-initialization code;
the sketch below shows why the names matter for cross-loading.
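For context, a minimal sketch (illustrative only, not this PR's diff): Keras derives each variable's path from the `name=` of the layers that own it, and the PT-to-TF cross-loader matches those paths against the PyTorch state dict, so a sublayer left unnamed gets an autogenerated name that no PyTorch key can match.

```python
import tensorflow as tf

class Block(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Without name="q_proj", Keras would auto-name this layer (e.g. "dense"),
        # and a PyTorch weight stored under "q_proj.weight" could not be matched.
        self.q_proj = tf.keras.layers.Dense(8, name="q_proj")

    def call(self, x):
        return self.q_proj(x)

block = Block(name="block")
block(tf.zeros((1, 8)))  # run once so the variables are created
print([w.name for w in block.weights])
# prints something like ['block/q_proj/kernel:0', 'block/q_proj/bias:0']
```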

* Attempt to fix CI

* Add back accidentally removed line

* Remove torch-specific stuff from the TF test file

* make fix-copies, make style, remove autotranslated files

* Fixes to imports/docstrings

* Let's try the `from __future__` import in desperation

* Fix the core random_attention_mask fn to match the torch/flax behaviour

* Clean random_attention_mask up correctly

* Remove torch-only test

* Fix loss shape, couple of nits

* make style

* Don't test for OOB embeddings because IDEFICS uses those deliberately

* Fix loss computation to handle masking

* Fix test failures when flattening

* Fix some test failures

- Add the cross-attention gate, which was missing and wasn't being passed around
- Fix overwriting of image_attention_mask due to a hack I had for dummy inputs

* Add a proper stateless scaled_dot_product_attention

* make style

* Adding missing attribute from the PyTorch version

* Small cleanups to the decoupled linear layer in case that helps

* Pass epsilon to LayerNormalization

* Attempt to fix pytorch weight cross-loading for TFIdeficsEmbedding

* Fix a bug in TFIdeficsGatedCrossAttentionLayer

* Patching up build() methods

* Constant self.inv_freq

* Constant self.inv_freq

* First working version

The TF implementation works now; there was a bug in TFIdeficsDecoupledLinear
where the weights were mis-initialized as (in_features, out_features)
when they should have been (out_features, in_features); a sketch of the two
shape conventions follows below.

I have tested this so far with tiny-random and idefics-9b-instruct
and it gives correct output.

I also dumped the final outputs for both pytorch and TF
and they are identical.
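A minimal sketch of the two conventions (illustrative, not this PR's code): `torch.nn.Linear` stores its weight as `(out_features, in_features)`, while a Keras `Dense` kernel is `(in_features, out_features)`, so a layer that creates its weight matrix by hand must pick the orientation that matches its matmul.

```python
import numpy as np
import tensorflow as tf
import torch

in_features, out_features = 4, 3

pt_linear = torch.nn.Linear(in_features, out_features, bias=False)
print(tuple(pt_linear.weight.shape))  # (3, 4) -> (out_features, in_features)

tf_dense = tf.keras.layers.Dense(out_features, use_bias=False)
tf_dense.build((None, in_features))
print(tuple(tf_dense.kernel.shape))   # (4, 3) -> (in_features, out_features)

# Cross-loading the PyTorch weight into the Dense kernel needs a transpose:
tf_dense.kernel.assign(pt_linear.weight.detach().numpy().T)
x = np.random.rand(2, in_features).astype(np.float32)
np.testing.assert_allclose(
    tf_dense(x).numpy(),
    pt_linear(torch.from_numpy(x)).detach().numpy(),
    rtol=1e-5, atol=1e-6,
)
```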

* Fix some test failures

* remove print statement

* Fix return_tensors

* Fix CI test failure check_code_quality

* Attempt to fix CI failures by running `make fixup`

The hardcoded IDs in test_modeling_tf_idefics.py are for the integration
test; they make that file unreadable and should probably be moved to a separate file.

* Attempt to fix tests_pr_documentation_tests

* Fix a test failure in test_image_processing_idefics.py

* Fix test test_pt_tf_model_equivalence

* Fix a few failures

* Tiny fix

* Some minor fixes

* Remove a duplicate test

* Override a few test failures for IDEFICS

- `test_keras_save_load` is passing now
- `test_compile_tf_model` is still failing

* Fix processing_idefics.py after rebase

* Guard import keras with is_tf_available

* fix check code quality

* fix check code quality

* Minor fixes

* Skip test_save_load temporarily

This test passed on my local box but fails on the CI; skipping it
for now to see if there are other remaining failures on the CI.

* Run `ruff format tests src utils`

* Fix last failing test, `test_compile_tf_model`

* Add fixes for vision_tf.py

I forgot to add this file in last commit.

* Minor fixes

* Replace "<<<" with "<<" for doc tests

IDEFICS-9B is too big for the doctest runner, so don't run it there

* Make code more readable

* Fix bug after code review

I added a layer_norm_eps to IdeficsConfig but I don't even need it
since the vision config has a layer_norm_eps.

* Fix after code review

Use original code tokenizer.convert_tokens_to_ids

* Keep PyTorch as the default return_tensors

* Fixes to modeling_tf after code review

* Fixes from code review

- Remove all references of `TF_IDEFICS_PRETRAINED_MODEL_ARCHIVE_LIST`
- Pass 1e-5 to LayerNormalization in perceiver

* Run ruff

* Undo a change

* Refactor processing code after Matt's suggestion

* Remove TODO's that aren't needed anymore

* For pytorch, Use original pytorch processing code from main

Since this PR is a TF port, it shouldn't make any modifications
to the pytorch IDEFICS code. This change undoes the pytorch processing
modifications I made and restores the original code from main.

* Update tests/models/idefics/test_modeling_idefics.py

* Update tests/models/idefics/test_modeling_tf_idefics.py

* Add missing imports for is_pt_tf_cross_test

* [DO NOT MERGE]: This is a commit for debugging and will be reverted

The cross test `test_pt_tf_model_equivalence` passes locally but
fails when running on the CI. This commit is to help debug that
and will be reverted.

* Revert "[DO NOT MERGE]: This is a commit for debugging and will be reverted"

This reverts commit 8f0d709ec5bd46685fb0b4259d914ffee794875b.

* [DO NOT MERGE]: This commit is for debugging a CI failure and will be reverted

* [DO NOT MERGE]: This commit is for debugging a CI failure and will be reverted

* Revert "[DO NOT MERGE]: This commit is for debugging a CI failure and will be reverted"

This reverts commit 998cc38b8c3d313bf5e5eb55a7f5b7b881897b89.

* Revert "[DO NOT MERGE]: This commit is for debugging a CI failure and will be reverted"

This reverts commit 1c695ac4219c4ae4d39b330b01744dc27deb7dd4.

* Don't skip test_save_load

IIRC test_save_load was also failing on the CI but not on my local
box; it might be easier to debug that on the CI first than the cross tests.

* Debugging commit, will be reverted

* Revert "Debugging commit, will be reverted"

This reverts commit 8eafc8e41e20c4e95a3a90834f06a6e9f445e2d5.

* Override `test_save_load` and push model to save

Maybe this will help me repro this weird bug

* pass my repo_id

* add endpoint

* Pass a temp (write) token just for this CI

* Undo last few commits, still pushing to hub for model debugging

The issue seems to be with save_pretrained(): when I looked at the model saved
from the CI test failure it was basically empty and had no weights.
`self.save_weights(..)` seems to be failing in save_pretrained but needs
more debugging.

* Add logging to modeling tf utils, will be reverted just for debugging

* Debugging, will revert

* Revert "Debugging, will revert"

This reverts commit 9d0d3075fb7c82d8cde3a5c76bc8f3876c5c55d3.

* Revert "Add logging to modeling tf utils, will be reverted just for debugging"

This reverts commit 774b6b7b1c17b3ce5d7634ade768f2f686cee617.

* Remove `test_save_load`

The CI failures are gone after my latest rebase, no idea why,
but I was still saving the model to my hub on HF, and the tf_model.h5
file now has everything.

* Run make fix-copies

* Run ruff format tests src utils

* Debugging commit, will be reverted

* Run ruff, also trigger CI run

* Run ruff again

* Undo debugging commit

---------
Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
parent de2f7221
@@ -160,7 +160,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [HerBERT](model_doc/herbert) | ✅ | ✅ | ✅ |
 | [Hubert](model_doc/hubert) | ✅ | ✅ | ❌ |
 | [I-BERT](model_doc/ibert) | ✅ | ❌ | ❌ |
-| [IDEFICS](model_doc/idefics) | ✅ | ❌ | ❌ |
+| [IDEFICS](model_doc/idefics) | ✅ | ✅ | ❌ |
 | [Idefics2](model_doc/idefics2) | ✅ | ❌ | ❌ |
 | [ImageGPT](model_doc/imagegpt) | ✅ | ❌ | ❌ |
 | [Informer](model_doc/informer) | ✅ | ❌ | ❌ |
@@ -52,6 +52,16 @@ To train a new IDEFICS model from scratch use the m4 codebase (a link will be pr
 [[autodoc]] IdeficsForVisionText2Text
     - forward
 
+## TFIdeficsModel
+
+[[autodoc]] TFIdeficsModel
+    - call
+
+## TFIdeficsForVisionText2Text
+
+[[autodoc]] TFIdeficsForVisionText2Text
+    - call
+
 ## IdeficsImageProcessor
 
 [[autodoc]] IdeficsImageProcessor
@@ -3862,6 +3862,15 @@ else:
             "TFHubertPreTrainedModel",
         ]
     )
+    _import_structure["models.idefics"].extend(
+        [
+            "TFIdeficsForVisionText2Text",
+            "TFIdeficsModel",
+            "TFIdeficsPreTrainedModel",
+        ]
+    )
     _import_structure["models.layoutlm"].extend(
         [
             "TFLayoutLMForMaskedLM",
@@ -7905,6 +7914,11 @@ if TYPE_CHECKING:
         TFHubertModel,
         TFHubertPreTrainedModel,
     )
+    from .models.idefics import (
+        TFIdeficsForVisionText2Text,
+        TFIdeficsModel,
+        TFIdeficsPreTrainedModel,
+    )
     from .models.layoutlm import (
         TFLayoutLMForMaskedLM,
         TFLayoutLMForQuestionAnswering,
@@ -58,6 +58,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
         ("gptj", "TFGPTJModel"),
         ("groupvit", "TFGroupViTModel"),
         ("hubert", "TFHubertModel"),
+        ("idefics", "TFIdeficsModel"),
         ("layoutlm", "TFLayoutLMModel"),
         ("layoutlmv3", "TFLayoutLMv3Model"),
         ("led", "TFLEDModel"),
@@ -112,6 +113,7 @@ TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
         ("funnel", "TFFunnelForPreTraining"),
         ("gpt-sw3", "TFGPT2LMHeadModel"),
         ("gpt2", "TFGPT2LMHeadModel"),
+        ("idefics", "TFIdeficsForVisionText2Text"),
         ("layoutlm", "TFLayoutLMForMaskedLM"),
         ("lxmert", "TFLxmertForPreTraining"),
         ("mobilebert", "TFMobileBertForPreTraining"),
@@ -13,7 +13,13 @@
 # limitations under the License.
 from typing import TYPE_CHECKING
 
-from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
+from ...utils import (
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    is_tf_available,
+    is_torch_available,
+    is_vision_available,
+)
 
 
 _import_structure = {"configuration_idefics": ["IdeficsConfig"]}
@@ -39,6 +45,17 @@ else:
     ]
     _import_structure["processing_idefics"] = ["IdeficsProcessor"]
 
+try:
+    if not is_tf_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["modeling_tf_idefics"] = [
+        "TFIdeficsForVisionText2Text",
+        "TFIdeficsModel",
+        "TFIdeficsPreTrainedModel",
+    ]
 
 if TYPE_CHECKING:
     from .configuration_idefics import IdeficsConfig
@@ -64,6 +81,17 @@ if TYPE_CHECKING:
     )
     from .processing_idefics import IdeficsProcessor
 
+    try:
+        if not is_tf_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .modeling_tf_idefics import (
+            TFIdeficsForVisionText2Text,
+            TFIdeficsModel,
+            TFIdeficsPreTrainedModel,
+        )
 
 else:
     import sys
@@ -92,8 +92,9 @@ class IdeficsImageProcessor(BaseImageProcessor):
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
         transform: Callable = None,
+        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
         **kwargs,
-    ) -> TensorType.PYTORCH:
+    ) -> TensorType:
         """
         Preprocess a batch of images.
@@ -162,7 +163,6 @@ class IdeficsImageProcessor(BaseImageProcessor):
         images = [self.rescale(image=image, scale=1 / 255) for image in images]
         images = [self.normalize(x, mean=image_mean, std=image_std) for x in images]
         images = [to_channel_dimension_format(x, ChannelDimension.FIRST) for x in images]
-        # TODO: this converts to torch tensors - switch to convert_to_tensors once it becomes available
-        images = BatchFeature(data={"pixel_values": images}, tensor_type=TensorType.PYTORCH)["pixel_values"]
+        images = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)["pixel_values"]
 
         return images
This diff is collapsed.
# This code was adapted from https://github.com/lucidrains/flamingo-pytorch licensed under the MIT License.
#
# MIT License
#
# Copyright (c) 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and github/lonePatient
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
Generic interface to various configurations of the Perceiver Resampler, that simply takes in a series of (potentially
time-indexed) contextual embeddings, and "resamples" (compresses) them down to a pre-specified number of latents! Note
that the Perceiver in general resamples based solely off the *long-range* context; there's a nice opportunity here to
prime the Perceiver Resampler with say a single layer's worth of language embeddings (the target domain), and use that
to softly "retrieve & compress" what we need --> this would be a novel contribution we should explore.
References:
- DeepMind's Flamingo: https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model
- Code borrowed w/ love from: https://github.com/lucidrains/flamingo-pytorch
"""
from typing import Optional, Tuple
import tensorflow as tf
from ...modeling_tf_utils import shape_list
from .configuration_idefics import IdeficsConfig
class TFIdeficsPerceiverResampler(tf.keras.layers.Layer):
def __init__(
self, config: IdeficsConfig, embed_dim: int, depth: int, n_heads: int, head_dim: int, n_latents: int, **kwargs
) -> None:
"""
Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or
MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then
returns a Tensor of shape [bsz, n_latents, embed_dim]. :param embed_dim: Dimensionality of embeddings being fed
to the Perceiver Resampler (also dimensionality of latent embeddings *returned* by the Perceiver Resampler.
Could be e.g., VIT embed_dim, ResNet pool dim, and so on.
Args:
config (`IdeficsConfig`): config object
embed_dim (`int`): The size of each embedding vector
depth (`int`): Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3).
n_heads (`int`): Number of heads in each Transformer block (for multi-headed self-attention).
head_dim (`int`): Dimensionality of each head projection in the Transformer block.
n_latents (`int`):
Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
"""
super().__init__(**kwargs)
self.embed_dim, self.n_heads, self.head_dim, self.n_latents = embed_dim, n_heads, head_dim, n_latents
self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver
self.intermediate_dim = (
self.embed_dim * 4
if not hasattr(config.vision_config, "embed_dim")
else config.vision_config.embed_dim * 4
)
# Create Transformer Blocks
self.blocks = []
for i in range(depth):
self.blocks.append(
[
TFIdeficsPerceiverAttention(
self.embed_dim, self.n_heads, self.head_dim, self.qk_layer_norms, name=f"blocks.{i}.0"
),
TFIdeficsMLP(self.intermediate_dim, config, name=f"blocks.{i}.1"),
]
)
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
def build(self, input_shape):
# Create Latents for Perceiver
self.latents = self.add_weight(
shape=(self.n_latents, self.embed_dim), initializer="random_normal", trainable=True, name="latents"
)
super().build(input_shape)
def call(self, context: tf.Tensor) -> tf.Tensor:
"""Resample arbitrary length context & *compress* down to self.n_latents latent embeddings"""
# tf.repeat(self.latents, "seq embed -> bsz seq embed", bsz=context.shape[0])
latents = tf.expand_dims(self.latents, axis=0)
latents = tf.tile(latents, [tf.shape(context)[0], 1, 1])
# Feed through Perceiver Attention blocks...
for attn, ff in self.blocks:
latents = attn(context, latents) + latents
latents = ff(latents) + latents
return self.layer_norm(latents)
class TFIdeficsPerceiverAttention(tf.keras.layers.Layer):
def __init__(self, embed_dim: int, n_heads: int, head_dim: int, qk_layer_norms: bool, **kwargs) -> None:
"""Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
super().__init__(**kwargs)
self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim
self.qk_layer_norms = qk_layer_norms
# Normalization & Scaling
self.context_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="context_layer_norm")
self.latents_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="latents_layer_norm")
if self.qk_layer_norms:
self.q_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="q_layer_norm")
self.k_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="k_layer_norm")
self.qk_scale = self.head_dim**-0.5
# Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers).
self.q_proj = tf.keras.layers.Dense(self.n_heads * self.head_dim, use_bias=False, name="q_proj")
self.k_proj = tf.keras.layers.Dense(self.n_heads * self.head_dim, use_bias=False, name="k_proj")
self.v_proj = tf.keras.layers.Dense(self.n_heads * self.head_dim, use_bias=False, name="v_proj")
self.output_proj = tf.keras.layers.Dense(embed_dim, use_bias=False, name="output_proj")
def call(self, context: tf.Tensor, latents: tf.Tensor) -> tf.Tensor:
"""
Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension!
Args:
context (`tf.Tensor`):
Tensor of shape `[bsz, seq, embed_dim]` representing long-form context to resample.
latents (`tf.Tensor`):
Tensor of shape `[bsz, n_latents, embed_dim]` representing fixed length latents to compress to.
Returns:
`tf.Tensor`: Tensor of shape `[bsz, n_latents, embed_dim]` representing attention over latents w/ cross
from context.
"""
context = self.context_layer_norm(context)
latents = self.latents_layer_norm(latents)
batch_size, seq_length, embed_dim = shape_list(context)
# Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn!
# Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents`
q = self.q_proj(latents)
k = self.k_proj(tf.concat([context, latents], axis=-2))
v = self.v_proj(tf.concat([context, latents], axis=-2))
# Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call)
# =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)]
q, k, v = [
tf.transpose(tf.reshape(x, (batch_size, x.shape[1], self.n_heads, self.head_dim)), perm=[0, 2, 1, 3])
for x in (q, k, v)
]
if self.qk_layer_norms:
q = self.q_layer_norm(q)
k = self.k_layer_norm(k)
scores = tf.einsum("... i d, ... j d -> ... i j", q * self.qk_scale, k)
stabilized_scores = scores - tf.reduce_max(scores, axis=-1, keepdims=True)
attn = tf.nn.softmax(stabilized_scores, axis=-1)
# Attend & project back to output...
resampled = tf.einsum("... i j, ... j d -> ... i d", attn, v)
return self.output_proj(
tf.reshape(tf.transpose(resampled, perm=[0, 2, 1, 3]), (batch_size, -1, self.n_heads * self.head_dim))
)
class TFIdeficsMLP(tf.keras.layers.Layer):
def __init__(self, intermediate_size, config: IdeficsConfig, **kwargs):
"""Simple MLP block with intermediate_size and embedding size"""
super().__init__(**kwargs)
self.embed_dim = config.vision_config.embed_dim
self.ln = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="ln")
self.fc = tf.keras.layers.Dense(intermediate_size, use_bias=False, name="fc")
self.act = tf.keras.layers.ReLU(name="act")
self.c_proj = tf.keras.layers.Dense(self.embed_dim, use_bias=False, name="c_proj")
def call(self, hidden_states: Optional[Tuple[tf.Tensor]]) -> tf.Tensor:
hidden_states = self.ln(hidden_states)
hidden_states = self.fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
return hidden_states
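A quick shape check for the resampler above; this is a sketch, and the module path plus the `perceiver_config`/`vision_config` attribute names are taken from the PyTorch IDEFICS config and assumed unchanged here:

```python
import tensorflow as tf
from transformers import IdeficsConfig
from transformers.models.idefics.perceiver_tf import TFIdeficsPerceiverResampler

config = IdeficsConfig()
resampler = TFIdeficsPerceiverResampler(
    config,
    embed_dim=config.vision_config.embed_dim,
    depth=config.perceiver_config.resampler_depth,
    n_heads=config.perceiver_config.resampler_n_heads,
    head_dim=config.perceiver_config.resampler_head_dim,
    n_latents=config.perceiver_config.resampler_n_latents,
)
# (bsz, seq, embed_dim) in -> (bsz, n_latents, embed_dim) out
context = tf.random.normal((1, 226, config.vision_config.embed_dim))
print(resampler(context).shape)  # (1, 64, 768) with the default config
```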
@@ -22,34 +22,53 @@ from urllib.parse import urlparse
 from ...feature_extraction_utils import BatchFeature
 from ...processing_utils import ProcessorMixin
 from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy
-from ...utils import TensorType, is_torch_available
+from ...utils import is_tf_available, is_torch_available
 
 
 if is_torch_available():
     import torch
 
+if is_tf_available():
+    import tensorflow as tf
+
 
 IMAGE_TOKEN = "<image>"
 
 
 # copied from m4.training.packing
-def incremental_to_binary_attention_mask(incremental_mask, num_classes=-1):
-    # This function converts: [-1, 0, 1] => [[0, 0], [1, 0], [0, 1]]
-
-    # If any of images index are more than num_classes, set them to -1.
-    # Words after the max number of images allowed have been seen don't attend on anything
+def incremental_to_binary_attention_mask(incremental_mask, return_tensors, num_classes=-1):
+    # Set elements >= num_classes to -1
     if num_classes != -1:
-        incremental_mask[incremental_mask >= num_classes] = -1
+        if return_tensors == "pt":
+            incremental_mask[incremental_mask >= num_classes] = -1
+        elif return_tensors == "tf":
+            incremental_mask = tf.where(incremental_mask >= num_classes, -1, incremental_mask)
 
-    negatives = incremental_mask == -1
-    incremental_mask[negatives] = 0
-    attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes)
-    attn_mask[negatives, :] = 0
+    # Create mask for negative values
+    if return_tensors == "pt":
+        negatives = incremental_mask == -1
+        incremental_mask[negatives] = 0
+        attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes)
+        attn_mask[negatives, :] = 0
+    elif return_tensors == "tf":
+        negatives = tf.equal(incremental_mask, -1)
+        incremental_mask = tf.where(negatives, 0, incremental_mask)
+        attn_mask = tf.one_hot(incremental_mask, depth=num_classes)
+        # Reshape 'negatives' to add an extra dimension, making it [batch_size, seq_length, 1]
+        negatives_expanded = tf.expand_dims(negatives, -1)
+        attn_mask = tf.where(negatives_expanded, tf.zeros_like(attn_mask), attn_mask)
 
     return attn_mask
 
 
 # copied from m4.training.packing
-def image_attention_mask_for_packed_input_ids(input_ids, tokenizer):
+def image_attention_mask_for_packed_input_ids(input_ids, tokenizer, return_tensors):
+    if return_tensors == "pt":
+        return image_attention_mask_for_packed_input_ids_pt(input_ids, tokenizer)
+    elif return_tensors == "tf":
+        return image_attention_mask_for_packed_input_ids_tf(input_ids, tokenizer)
+
+
+def image_attention_mask_for_packed_input_ids_pt(input_ids, tokenizer):
     image_attention_mask = torch.full_like(input_ids, fill_value=-1)
     next_image_attention_mask = torch.full_like(input_ids, fill_value=-1)
     image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
@@ -96,6 +115,39 @@ def image_attention_mask_for_packed_input_ids(input_ids, tokenizer):
     return image_attention_mask, next_image_attention_mask
 
 
+def image_attention_mask_for_packed_input_ids_tf(input_ids, tokenizer):
+    image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
+    eod_token_id = tokenizer.eos_token_id
+    batch_size = tf.shape(input_ids)[0]
+    image_attention_mask = tf.fill(tf.shape(input_ids), -1)
+    next_image_attention_mask = tf.fill(tf.shape(input_ids), -1)
+
+    for batch_idx in range(batch_size):
+        count = -1
+        seen_eod = False
+        seq_length = tf.shape(input_ids)[1]
+
+        for idx in range(seq_length - 1, -1, -1):
+            token_id = input_ids[batch_idx, idx].numpy()
+            if token_id == image_token_id:
+                count += 1
+                indices = [[batch_idx, idx]]
+                updates = [count]
+                image_attention_mask = tf.tensor_scatter_nd_update(image_attention_mask, indices, updates)
+                next_image_attention_mask = tf.tensor_scatter_nd_update(next_image_attention_mask, indices, updates)
+            elif token_id == eod_token_id and not seen_eod:
+                seen_eod = True
+                count = 0
+                indices = [[batch_idx, idx]]
+                updates = [count]
+                next_image_attention_mask = tf.tensor_scatter_nd_update(next_image_attention_mask, indices, updates)
+            if seen_eod and token_id != eod_token_id:
+                indices = [[batch_idx, idx]]
+                updates = [-1]
+                next_image_attention_mask = tf.tensor_scatter_nd_update(next_image_attention_mask, indices, updates)
+
+    return image_attention_mask, next_image_attention_mask
+
+
 def is_url(string):
     """Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately
     invalidated the url"""
@@ -156,7 +208,7 @@ class IdeficsProcessor(ProcessorMixin):
         add_eos_token=False,
         add_end_of_utterance_token=None,
         debug=False,
-        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+        return_tensors="pt",
     ) -> BatchEncoding:
         """This method takes batched or non-batched prompts made of text and images and converts them into prompts that
         the model was trained on and prepares the image pixel values for the model to process.
@@ -268,7 +320,6 @@ class IdeficsProcessor(ProcessorMixin):
         # if the value isn't overriden by the user, check if the tokenizer was trained with this token and then use it
         if add_end_of_utterance_token is None:
             add_end_of_utterance_token = self.tokenizer_was_trained_with_end_of_utterance_token
-
         # turn non-batched prompts into batched
         if not any(isinstance(i, list) for i in prompts):
             prompts = [prompts]
@@ -322,7 +373,7 @@ class IdeficsProcessor(ProcessorMixin):
             if debug is True:
                 print(f"{full_text=}")
 
-            image_objects = self.image_processor(image_objects, transform=transform)
+            image_objects = self.image_processor(image_objects, transform=transform, return_tensors=return_tensors)
 
             all_prompts.append(full_text)
             all_images.append(image_objects)
@@ -345,39 +396,72 @@ class IdeficsProcessor(ProcessorMixin):
         output_input_ids = []
         output_images = []
         output_attention_masks = []
 
         for text, attention_mask, images in zip(all_texts, all_attention_masks, all_images):
             padded_input_ids = text
 
             image_count = padded_input_ids.count(self.image_token_id)
             local_max_num_images = min(image_count, max_num_images)
 
             current_images = images[:local_max_num_images]
 
             if len(current_images) > 0:
-                padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
-                padded_image_tensor[: current_images.size(0)] = current_images
+                if return_tensors == "pt":
+                    padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
+                    padded_image_tensor[: current_images.size(0)] = current_images
+                elif return_tensors == "tf":
+                    # Assuming current_images is a TensorFlow tensor
+                    # Get the shape of current_images, excluding the first dimension
+                    image_shape = tf.shape(current_images)[1:]
+                    # Create a shape for the padded_image_tensor
+                    padded_shape = tf.concat([[max_num_images], image_shape], axis=0)
+                    # Create the padded_image_tensor of zeros
+                    padded_image_tensor = tf.zeros(padded_shape, dtype=current_images.dtype)
+                    # Get the number of images (assuming current_images has shape [num_images, height, width, channels])
+                    num_images = tf.shape(current_images)[0]
+                    # Update the padded_image_tensor with the values from current_images
+                    indices = tf.reshape(tf.range(num_images), (-1, 1))
+                    updates = current_images
+                    padded_image_tensor = tf.tensor_scatter_nd_update(padded_image_tensor, indices, updates)
             else:
-                padded_image_tensor = torch.zeros(max_num_images, *self.default_image_dims)
+                if return_tensors == "pt":
+                    padded_image_tensor = torch.zeros(max_num_images, *self.default_image_dims)
+                elif return_tensors == "tf":
+                    padded_image_tensor = tf.zeros((max_num_images, *self.default_image_dims))
 
             output_images.append(padded_image_tensor)
-            output_input_ids.append(torch.tensor(padded_input_ids))
-            output_attention_masks.append(torch.tensor(attention_mask))
+            if return_tensors == "pt":
+                output_input_ids.append(torch.tensor(padded_input_ids))
+                output_attention_masks.append(torch.tensor(attention_mask))
+            elif return_tensors == "tf":
+                output_input_ids.append(tf.convert_to_tensor(padded_input_ids, dtype=tf.int32))
+                output_attention_masks.append(attention_mask)
 
-        output_input_ids = torch.stack(output_input_ids)
-        output_images = torch.stack(output_images)
-        output_attention_masks = torch.stack(output_attention_masks)
+        if return_tensors == "pt":
+            output_input_ids = torch.stack(output_input_ids)
+            output_images = torch.stack(output_images)
+            output_attention_masks = torch.stack(output_attention_masks)
+        elif return_tensors == "tf":
+            output_input_ids = tf.stack(output_input_ids)
+            output_images = tf.stack(output_images)
+            output_attention_masks = tf.stack(output_attention_masks)
 
         if at_least_one_image:
-            image_attention_mask, _ = image_attention_mask_for_packed_input_ids(output_input_ids, self.tokenizer)
+            image_attention_mask, _ = image_attention_mask_for_packed_input_ids(
+                output_input_ids, self.tokenizer, return_tensors
+            )
             image_attention_mask = incremental_to_binary_attention_mask(
-                image_attention_mask, num_classes=max_num_images
+                image_attention_mask, return_tensors, num_classes=max_num_images
            )
         else:
             # in full language mode we set the image mask to all-0s
-            image_attention_mask = torch.zeros(
-                output_input_ids.shape[0], output_input_ids.shape[1], 1, dtype=torch.bool
-            )
+            if return_tensors == "pt":
+                image_attention_mask = torch.zeros(
+                    output_input_ids.shape[0], output_input_ids.shape[1], 1, dtype=torch.bool
+                )
+            elif return_tensors == "tf":
+                image_attention_mask = tf.zeros(
+                    (output_input_ids.shape[0], output_input_ids.shape[1], 1), dtype=tf.bool
+                )
 
         return BatchFeature(
             data={
                 "input_ids": output_input_ids,
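With the branches above in place, one processor call can emit either framework's tensors. A sketch (the tiny checkpoint name is taken from the commit log above; the prompt format is just an example):

```python
import numpy as np
from PIL import Image
from transformers import IdeficsProcessor

processor = IdeficsProcessor.from_pretrained("HuggingFaceM4/tiny-random-idefics")
image = Image.fromarray(np.zeros((30, 30, 3), dtype=np.uint8))
prompts = [["User: what is in the image?", image]]

pt_inputs = processor(prompts, return_tensors="pt")  # torch tensors (the default)
tf_inputs = processor(prompts, return_tensors="tf")  # tf tensors via the new branches
print(type(pt_inputs["input_ids"]), type(tf_inputs["input_ids"]))
```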
This diff is collapsed.
@@ -104,6 +104,33 @@ def functional_layernorm(inputs, weight, bias, epsilon=1e-5, axis=-1):
     return outputs
 
 
+def scaled_dot_product_attention(
+    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale: float = None
+):
+    """TF equivalent for torch's nn.functional.scaled_dot_product_attention"""
+    if dropout_p != 0.0:
+        raise ValueError(
+            "Dropout is not supported in this implementation - file an issue "
+            "with Transformers and ping @Rocketknight1 if you need it for a port!"
+        )
+    if is_causal and attn_mask is not None:
+        raise ValueError("You cannot specify an attn_mask and is_causal at the same time!")
+    if is_causal:
+        attn_mask = tf.ones((tf.shape(query)[-2], tf.shape(key)[-2]), dtype=tf.int32)
+        attn_mask = tf.experimental.numpy.tril(attn_mask, k=0)
+    if attn_mask is not None and (attn_mask.dtype.is_integer or attn_mask.dtype.is_bool):
+        # Convert boolean mask to a negative logit bias
+        attn_mask = tf.where(attn_mask > 0, tf.cast(0.0, query.dtype), tf.cast(-1000.0, query.dtype))
+    logits = tf.einsum("...qd, ...kd -> ...qk", query, key)
+    if scale is None:
+        scale = tf.cast(tf.shape(key)[-1], logits.dtype) ** -0.5
+    logits *= scale  # scale by 1/sqrt(key_dim)
+    if attn_mask is not None:
+        logits += attn_mask
+    probs = tf.nn.softmax(logits)
+    return probs @ value
+
+
 def flatten(input, start_dim=0, end_dim=-1):
     # Replicates the behavior of torch.flatten in TF
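A quick sanity check of the helper added above, assuming it is in scope (the shapes here are hypothetical): queries may attend over keys/values of a different length, and `is_causal=True` builds the lower-triangular mask internally.

```python
import tensorflow as tf

# batch=2, heads=4, query_len=5, key_len=7, head_dim=16
q = tf.random.normal((2, 4, 5, 16))
k = tf.random.normal((2, 4, 7, 16))
v = tf.random.normal((2, 4, 7, 16))

out = scaled_dot_product_attention(q, k, v)                     # (2, 4, 5, 16)
causal = scaled_dot_product_attention(q, q, q, is_causal=True)  # (2, 4, 5, 16)
print(out.shape, causal.shape)
```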
@@ -1542,6 +1542,27 @@ class TFHubertPreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["tf"])
 
 
+class TFIdeficsForVisionText2Text(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
+class TFIdeficsModel(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
+class TFIdeficsPreTrainedModel(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
 class TFLayoutLMForMaskedLM(metaclass=DummyObject):
     _backends = ["tf"]
@@ -152,7 +152,7 @@ class IdeficsImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
         # they both do the same
         image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
 
-        image_processor = self.image_processing_class(**self.image_processor_dict)
+        image_processor = self.image_processing_class(**self.image_processor_dict, return_tensors="pt")
 
         print(image_inputs)
 
@@ -181,8 +181,8 @@ class IdeficsImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
             ]
         )
 
-        pixel_values_transform_implied = image_processor(image_inputs, transform=None)
-        pixel_values_transform_supplied = image_processor(image_inputs, transform=transform)
+        pixel_values_transform_implied = image_processor(image_inputs, transform=None, return_tensors="pt")
+        pixel_values_transform_supplied = image_processor(image_inputs, transform=transform, return_tensors="pt")
 
         torch.testing.assert_close(pixel_values_transform_implied, pixel_values_transform_supplied, rtol=0.0, atol=0.0)
@@ -21,6 +21,7 @@ from parameterized import parameterized
 from transformers import BitsAndBytesConfig, IdeficsConfig, is_torch_available, is_vision_available
 from transformers.testing_utils import (
     TestCasePlus,
+    is_pt_tf_cross_test,
     require_bitsandbytes,
     require_torch,
     require_torch_sdpa,
@@ -559,6 +560,11 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
         check_hidden_states_output(inputs_dict, config, model_class)
 
+    @is_pt_tf_cross_test
+    def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
+        self.has_attentions = False
+        super().test_pt_tf_model_equivalence(allow_missing_keys=allow_missing_keys)
+
     @slow
     def test_model_from_pretrained(self):
         model_name = "HuggingFaceM4/idefics-9b"
This diff is collapsed.
@@ -41,7 +41,7 @@ class IdeficsProcessorTest(TestCasePlus):
         self.checkpoint_path = self.get_auto_remove_tmp_dir()
 
-        image_processor = IdeficsImageProcessor()
+        image_processor = IdeficsImageProcessor(return_tensors="pt")
         tokenizer = LlamaTokenizerFast.from_pretrained("HuggingFaceM4/tiny-random-idefics")
 
         processor = IdeficsProcessor(image_processor, tokenizer)
 
@@ -132,7 +132,7 @@ class IdeficsProcessorTest(TestCasePlus):
         image_processor = self.get_image_processor()
         tokenizer = self.get_tokenizer()
 
-        processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor)
+        processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor, return_tensors="pt")
 
         predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
 
@@ -145,7 +145,7 @@ class IdeficsProcessorTest(TestCasePlus):
         image_processor = self.get_image_processor()
         tokenizer = self.get_tokenizer(padding_side="right")
 
-        processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor)
+        processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor, return_tensors="pt")
 
         predicted_tokens = [
             "<s> Describe this image.\nAssistant:<unk><unk><unk><unk><unk><unk><unk><unk><unk>",
@@ -156,8 +156,9 @@ class IdeficsProcessorTest(TestCasePlus):
             ([1] * 10) + ([0] * 10),
         ]
         prompts = [[prompt] for prompt in self.prepare_prompts()[2]]
-        max_length = processor(prompts, padding="max_length", truncation=True, max_length=20)
-        longest = processor(prompts, padding="longest", truncation=True, max_length=30)
+
+        max_length = processor(prompts, padding="max_length", truncation=True, max_length=20, return_tensors="pt")
+        longest = processor(prompts, padding="longest", truncation=True, max_length=30, return_tensors="pt")
 
         decoded_max_length = processor.tokenizer.decode(max_length["input_ids"][-1])
         decoded_longest = processor.tokenizer.decode(longest["input_ids"][-1])
 
@@ -203,7 +204,7 @@ class IdeficsProcessorTest(TestCasePlus):
         processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor)
 
         prompts = self.prepare_prompts()
-        inputs = processor(prompts, padding="longest")
+        inputs = processor(prompts, padding="longest", return_tensors="pt")
 
         # For now the processor supports only ['pixel_values', 'input_ids', 'attention_mask']
         self.assertSetEqual(set(inputs.keys()), set(self.input_keys))
@@ -380,7 +380,9 @@ class TFModelTesterMixin:
             main_layer = main_layer_class(config)
 
             symbolic_inputs = {
-                name: keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
+                name: keras.Input(tensor.shape[1:], dtype=tensor.dtype)
+                for name, tensor in inputs_dict.items()
+                if tf.is_tensor(tensor)
             }
 
             model = keras.Model(symbolic_inputs, outputs=main_layer(symbolic_inputs))
@@ -1689,7 +1691,11 @@ class TFModelTesterMixin:
             tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
             if "labels" not in tf_inputs_dict:
                 return  # This model isn't giving us labels after all, don't try training with it
-            tf_inputs_dict = {key: val for key, val in tf_inputs_dict.items() if "head_mask" not in key}
+            tf_inputs_dict = {
+                key: val
+                for key, val in tf_inputs_dict.items()
+                if "head_mask" not in key and isinstance(val, tf.Tensor)
+            }
             tf_inputs_dict["extra_unwanted_column"] = list(tf_inputs_dict.values())[0]  # Use a random other tensor
             input_dataset = Dataset.from_dict(tf_inputs_dict)
             tf_dataset = model.prepare_tf_dataset(
@@ -1853,8 +1859,8 @@ def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
 
 def random_attention_mask(shape, rng=None, name=None, dtype=None):
     attn_mask = ids_tensor(shape, vocab_size=2, rng=None, name=None, dtype=dtype)
-    # make sure that at least one token is attended to for each batch
-    attn_mask = tf.concat([attn_mask[:, :-1], tf.ones_like(attn_mask[:, -1:], dtype=dtype)], axis=-1)
+    # Mark the first token as 1 (matches behaviour of PyTorch/Flax function)
+    attn_mask = tf.concat([tf.ones_like(attn_mask[:, :1]), attn_mask[:, 1:]], axis=1)
 
     return attn_mask
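The behavioural change in `random_attention_mask` is easiest to see on a concrete mask: the first column is now forced to 1, as the torch/flax helpers do, instead of the last column.

```python
import tensorflow as tf

attn_mask = tf.constant([[0, 1, 0], [0, 0, 1]])
fixed = tf.concat([tf.ones_like(attn_mask[:, :1]), attn_mask[:, 1:]], axis=1)
print(fixed.numpy())  # [[1 1 0]
                      #  [1 0 1]] -> position 0 is always attended to
```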