Unverified Commit 1b5ab39c authored by Joao Gante, committed by GitHub

TF: check embeddings range (#19102)

parent cf6308ef
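
For context: every hunk below adds the same defensive check before a TF embedding lookup. On GPU, tf.gather silently returns zeros for positive out-of-bound indices instead of raising an error, so an oversized token id would corrupt results without warning. A minimal standalone sketch of the failure mode and the guard (illustrative only, not part of the diff; the vocab size of 10 is made up):

import tensorflow as tf

vocab_size = 10
embedding = tf.random.normal((vocab_size, 4))  # toy embedding matrix
input_ids = tf.constant([[3, 7, 12]])          # 12 is out of bounds for a vocab of 10

# Without the guard, tf.gather on GPU would silently return a zero vector
# for index 12 (on CPU it raises InvalidArgumentError instead).
tf.debugging.assert_less(
    input_ids,
    tf.cast(vocab_size, dtype=input_ids.dtype),
    message=(
        "input_ids must be smaller than the embedding layer's input dimension (got"
        f" {tf.math.reduce_max(input_ids)} >= {vocab_size})"
    ),
)  # raises InvalidArgumentError on any device, before the lookup

inputs_embeds = tf.gather(params=embedding, indices=input_ids)  # not reached with bad ids
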
@@ -227,6 +227,16 @@ class TFLxmertEmbeddings(tf.keras.layers.Layer):
        assert not (input_ids is None and inputs_embeds is None)

        if input_ids is not None:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.vocab_size})"
                ),
            )
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        input_shape = shape_list(inputs_embeds)[:-1]
...
@@ -772,6 +772,16 @@ class TFMarianEncoder(tf.keras.layers.Layer):
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
                ),
            )
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        embed_pos = self.embed_positions(input_shape)
@@ -967,6 +977,16 @@ class TFMarianDecoder(tf.keras.layers.Layer):
        positions = self.embed_positions(input_shape, position_ids=position_ids)

        if inputs_embeds is None:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
                ),
            )
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        hidden_states = inputs_embeds
...
@@ -757,6 +757,16 @@ class TFMBartEncoder(tf.keras.layers.Layer):
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
                ),
            )
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        embed_pos = self.embed_positions(input_shape)
@@ -959,6 +969,16 @@ class TFMBartDecoder(tf.keras.layers.Layer):
        positions = self.embed_positions(input_shape, position_ids=position_ids)

        if inputs_embeds is None:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
                ),
            )
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        hidden_states = inputs_embeds
...
@@ -214,6 +214,16 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer):
        assert not (input_ids is None and inputs_embeds is None)

        if input_ids is not None:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.vocab_size})"
                ),
            )
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        input_shape = shape_list(inputs_embeds)[:-1]
...
@@ -145,6 +145,16 @@ class TFMPNetEmbeddings(tf.keras.layers.Layer):
        assert not (input_ids is None and inputs_embeds is None)

        if input_ids is not None:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.vocab_size})"
                ),
            )
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        input_shape = shape_list(inputs_embeds)[:-1]
...
@@ -298,10 +298,30 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
        position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])

        if inputs_embeds is None:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.vocab_size})"
                ),
            )
            inputs_embeds = self.tokens_embed(input_ids, mode="embedding")

        position_embeds = tf.gather(self.positions_embed, position_ids)

        if token_type_ids is not None:
            token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                token_type_ids,
                tf.cast(self.vocab_size, dtype=token_type_ids.dtype),
                message=(
                    "token_type_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(token_type_ids)} >= {self.vocab_size})"
                ),
            )
            token_type_embeds = self.tokens_embed(token_type_ids, mode="embedding")
        else:
            token_type_embeds = 0
...
@@ -632,6 +632,16 @@ class TFOPTDecoder(tf.keras.layers.Layer):
        past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0

        if inputs_embeds is None:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
                ),
            )
            inputs_embeds = self.embed_tokens(input_ids)

        if attention_mask is None:
...
@@ -775,6 +775,16 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
                ),
            )
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        embed_pos = self.embed_positions(input_shape)
@@ -973,6 +983,16 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
        positions = self.embed_positions(input_shape, position_ids=position_ids)

        if inputs_embeds is None:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
                ),
            )
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        hidden_states = inputs_embeds
...
@@ -124,6 +124,16 @@ class TFRemBertEmbeddings(tf.keras.layers.Layer):
        assert not (input_ids is None and inputs_embeds is None)

        if input_ids is not None:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.vocab_size})"
                ),
            )
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        input_shape = shape_list(inputs_embeds)[:-1]
...
@@ -146,6 +146,16 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer):
        assert not (input_ids is None and inputs_embeds is None)

        if input_ids is not None:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.vocab_size})"
                ),
            )
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        input_shape = shape_list(inputs_embeds)[:-1]
...
@@ -177,6 +177,16 @@ class TFRoFormerEmbeddings(tf.keras.layers.Layer):
        assert not (input_ids is None and inputs_embeds is None)

        if input_ids is not None:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.vocab_size})"
                ),
            )
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        input_shape = shape_list(inputs_embeds)[:-1]
...
@@ -1022,6 +1022,16 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer):
        past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0

        if inputs_embeds is None:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
                ),
            )
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
        else:
            inputs_embeds = inputs_embeds
...
@@ -681,6 +681,16 @@ class TFT5MainLayer(tf.keras.layers.Layer):
        if inputs_embeds is None:
            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
                ),
            )
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape
...
@@ -533,6 +533,16 @@ class TFXGLMMainLayer(tf.keras.layers.Layer):
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if inputs_embeds is None:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
                ),
            )
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length)
...
@@ -440,6 +440,16 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
        # embeddings
        if inputs_embeds is None:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.embeddings.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.embeddings.vocab_size})"
                ),
            )
            inputs_embeds = self.embeddings(input_ids)

        tensor = inputs_embeds + tf.gather(self.position_embeddings, position_ids)
...
@@ -680,6 +680,16 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
        if inputs_embeds is not None:
            word_emb_k = inputs_embeds
        else:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.word_embedding.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.word_embedding.vocab_size})"
                ),
            )
            word_emb_k = self.word_embedding(input_ids)
        output_h = self.dropout(word_emb_k, training=training)

        if target_mapping is not None:
...
@@ -127,6 +127,16 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
        assert not (input_ids is None and inputs_embeds is None)

        if input_ids is not None:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.vocab_size})"
                ),
            )
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        input_shape = shape_list(inputs_embeds)[:-1]
@@ -2305,6 +2315,16 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
                ),
            )
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        embed_pos = self.embed_positions(input_shape)
@@ -2494,6 +2514,16 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
        positions = self.embed_positions(input_shape, past_key_values_length)

        if inputs_embeds is None:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})"
                ),
            )
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds
...