Unverified Commit cf9e7cb0 authored by Joao Gante, committed by GitHub

TF: embeddings out of bounds check factored into function (#23427)

parent 45e3d649
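For context, here is a minimal standalone sketch (not part of this diff; the layer sizes are made up) of the silent failure mode these checks guard against: on GPU, the `tf.gather` behind TF embedding layers returns zeros for positive out-of-bounds indices instead of raising an error.

```python
import tensorflow as tf

# Toy embedding table: valid token ids are 0..9.
embedding = tf.keras.layers.Embedding(input_dim=10, output_dim=4)

# Token id 10 is out of bounds for this table.
input_ids = tf.constant([[1, 2, 10]])

# On GPU, the underlying tf.gather writes zeros for the out-of-bounds id
# instead of raising, so the model silently runs on a zero embedding vector.
# (On CPU the same lookup typically raises an InvalidArgumentError.)
outputs = embedding(input_ids)
```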
@@ -36,7 +36,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list, stable_softmax
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_whisper import WhisperConfig

@@ -882,16 +882,7 @@ class TFWhisperDecoder(tf.keras.layers.Layer):
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

         if inputs_embeds is None:
-            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
-            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
-            tf.debugging.assert_less(
-                input_ids,
-                tf.cast(self.embed_tokens.input_dim, dtype=input_ids.dtype),
-                message=(
-                    "input_ids must be smaller than the embedding layer's input dimension (got"
-                    f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.input_dim})"
-                ),
-            )
+            check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
             inputs_embeds = self.embed_tokens(input_ids)

         attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length)
@@ -42,7 +42,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list, stable_softmax
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
 from ...utils import logging
 from .configuration_xglm import XGLMConfig

@@ -527,16 +527,7 @@ class TFXGLMMainLayer(tf.keras.layers.Layer):
         position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])

         if inputs_embeds is None:
-            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
-            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
-            tf.debugging.assert_less(
-                input_ids,
-                tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
-                message=(
-                    "input_ids must be smaller than the embedding layer's input dimension (got"
-                    f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
-                ),
-            )
+            check_embeddings_within_bounds(input_ids, self.embed_tokens.vocab_size)
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

         attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length)
@@ -45,7 +45,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list, stable_softmax
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
 from ...utils import (
     MULTIPLE_CHOICE_DUMMY_INPUTS,
     ModelOutput,
@@ -440,16 +440,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):

         # embeddings
         if inputs_embeds is None:
-            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
-            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
-            tf.debugging.assert_less(
-                input_ids,
-                tf.cast(self.embeddings.vocab_size, dtype=input_ids.dtype),
-                message=(
-                    "input_ids must be smaller than the embedding layer's input dimension (got"
-                    f" {tf.math.reduce_max(input_ids)} >= {self.embeddings.vocab_size})"
-                ),
-            )
+            check_embeddings_within_bounds(input_ids, self.embeddings.vocab_size)
             inputs_embeds = self.embeddings(input_ids)

         tensor = inputs_embeds + tf.gather(self.position_embeddings, position_ids)
@@ -46,7 +46,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list, stable_softmax
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
 from ...utils import (
     DUMMY_INPUTS,
     MULTIPLE_CHOICE_DUMMY_INPUTS,
@@ -233,16 +233,7 @@ class TFXLMRobertaEmbeddings(tf.keras.layers.Layer):
         assert not (input_ids is None and inputs_embeds is None)

         if input_ids is not None:
-            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
-            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
-            tf.debugging.assert_less(
-                input_ids,
-                tf.cast(self.config.vocab_size, dtype=input_ids.dtype),
-                message=(
-                    "input_ids must be smaller than the embedding layer's input dimension (got"
-                    f" {tf.math.reduce_max(input_ids)} >= {self.config.vocab_size})"
-                ),
-            )
+            check_embeddings_within_bounds(input_ids, self.config.vocab_size)
             inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

         input_shape = shape_list(inputs_embeds)[:-1]
@@ -39,7 +39,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list, stable_softmax
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
 from ...utils import (
     MULTIPLE_CHOICE_DUMMY_INPUTS,
     ModelOutput,
@@ -678,16 +678,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         if inputs_embeds is not None:
             word_emb_k = inputs_embeds
         else:
-            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
-            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
-            tf.debugging.assert_less(
-                input_ids,
-                tf.cast(self.word_embedding.vocab_size, dtype=input_ids.dtype),
-                message=(
-                    "input_ids must be smaller than the embedding layer's input dimension (got"
-                    f" {tf.math.reduce_max(input_ids)} >= {self.word_embedding.vocab_size})"
-                ),
-            )
+            check_embeddings_within_bounds(input_ids, self.word_embedding.vocab_size)
             word_emb_k = self.word_embedding(input_ids)
         output_h = self.dropout(word_emb_k, training=training)
         if target_mapping is not None:
@@ -96,3 +96,23 @@ def invert_attention_mask(encoder_attention_mask: tf.Tensor) -> tf.Tensor:
     ) * encoder_extended_attention_mask.dtype.min

     return encoder_extended_attention_mask
+
+
+def check_embeddings_within_bounds(tensor: tf.Tensor, embed_dim: int, tensor_name: str = "input_ids") -> None:
+    """
+    `tf.gather`, on which TF embedding layers are based, won't check positive out of bound indices on GPU, returning
+    zeros instead. This function adds a check against that dangerous silent behavior.
+
+    Args:
+        tensor (`tf.Tensor`): The tensor of indices to check.
+        embed_dim (`int`): The embedding dimension.
+        tensor_name (`str`, *optional*): The name of the tensor to use in the error message.
+    """
+    tf.debugging.assert_less(
+        tensor,
+        tf.cast(embed_dim, dtype=tensor.dtype),
+        message=(
+            f"The maximum value of {tensor_name} ({tf.math.reduce_max(tensor)}) must be smaller than the embedding "
+            f"layer's input dimension ({embed_dim}). The likely cause is some problem at tokenization time."
+        ),
+    )
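A quick usage sketch of the new helper (assuming it ends up importable as `transformers.tf_utils.check_embeddings_within_bounds`, matching the relative imports above):

```python
import tensorflow as tf

from transformers.tf_utils import check_embeddings_within_bounds

vocab_size = 10

# In-bounds ids pass silently.
check_embeddings_within_bounds(tf.constant([[1, 2, 9]]), vocab_size)

# An out-of-bounds id (10 >= vocab_size) fails the assertion immediately in
# eager mode, with a message pointing at tokenization, instead of letting
# tf.gather silently return zeros on GPU.
try:
    check_embeddings_within_bounds(tf.constant([[1, 2, 10]]), vocab_size)
except tf.errors.InvalidArgumentError as err:
    print(err.message)
```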
@@ -53,7 +53,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list, stable_softmax
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
 from ...utils import logging
 from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config

@@ -126,16 +126,7 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
         assert not (input_ids is None and inputs_embeds is None)

         if input_ids is not None:
-            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
-            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
-            tf.debugging.assert_less(
-                input_ids,
-                tf.cast(self.vocab_size, dtype=input_ids.dtype),
-                message=(
-                    "input_ids must be smaller than the embedding layer's input dimension (got"
-                    f" {tf.math.reduce_max(input_ids)} >= {self.vocab_size})"
-                ),
-            )
+            check_embeddings_within_bounds(input_ids, self.vocab_size)
             inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

         input_shape = shape_list(inputs_embeds)[:-1]
@@ -1670,7 +1661,7 @@ from ...modeling_tf_utils import (
     keras_serializable,
     unpack_inputs,
 )
-from ...tf_utils import shape_list, stable_softmax
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
 from ...utils import ContextManagers, logging
 from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config

@@ -2311,16 +2302,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
             if hasattr(self.embed_tokens, "load_weight_prefix"):
                 context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
             with ContextManagers(context):
-                # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
-                # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
-                tf.debugging.assert_less(
-                    input_ids,
-                    tf.cast(self.embed_tokens.input_dim, dtype=input_ids.dtype),
-                    message=(
-                        "input_ids must be smaller than the embedding layer's input dimension (got"
-                        f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.input_dim})"
-                    ),
-                )
+                check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
                 inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

         embed_pos = self.embed_positions(input_shape)
@@ -2518,16 +2500,7 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
             if hasattr(self.embed_tokens, "load_weight_prefix"):
                 context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
             with ContextManagers(context):
-                # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
-                # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
-                tf.debugging.assert_less(
-                    input_ids,
-                    tf.cast(self.embed_tokens.input_dim, dtype=input_ids.dtype),
-                    message=(
-                        "input_ids must be smaller than the embedding layer's input dimension (got"
-                        f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.input_dim})"
-                    ),
-                )
+                check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
                 inputs_embeds = self.embed_tokens(input_ids)

         hidden_states = inputs_embeds