Unverified Commit 466af1a3 authored by Joao Gante, committed by GitHub

OPT/BioGPT: Improved attention mask shape exception (#23270)

parent 21741e8c
@@ -546,6 +546,12 @@ class BioGptModel(BioGptPreTrainedModel):
        if attention_mask is None:
            attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
        elif attention_mask.shape[1] != past_key_values_length + input_shape[1]:
            raise ValueError(
                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
                f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)"
            )

        # embed positions
        positions = self.embed_positions(attention_mask, past_key_values_length)
...
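This hunk makes BioGptModel fail fast with an explicit ValueError when a user-supplied 2D attention mask does not cover both the cached (past) tokens and the current input. A minimal standalone sketch of the same check, using an illustrative helper name that is not part of the library:

import torch

def check_attention_mask(attention_mask, past_key_values_length, current_length):
    # Mirrors the new BioGPT check: the mask must span past + current tokens.
    expected = past_key_values_length + current_length
    if attention_mask.shape[1] != expected:
        raise ValueError(
            f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
            f"{expected} (sum of the lengths of current and past inputs)"
        )

# With 4 cached tokens and 3 new tokens, the mask must have length 7.
mask = torch.ones(1, 7, dtype=torch.bool)
check_attention_mask(mask, past_key_values_length=4, current_length=3)  # passes silently
try:
    check_attention_mask(mask[:, :3], past_key_values_length=4, current_length=3)
except ValueError as err:
    print(err)  # the new, explicit shape message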
@@ -642,6 +642,11 @@ class OPTDecoder(OPTPreTrainedModel):
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
        elif attention_mask.shape[1] != mask_seq_length:
            raise ValueError(
                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
                f"{mask_seq_length} (sum of the lengths of current and past inputs)"
            )
        causal_attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )
...
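The OPTDecoder change is equivalent, with mask_seq_length already holding the sum of past and current lengths. A hedged usage sketch of the call pattern the check guards, assuming the public facebook/opt-125m checkpoint (any OPT checkpoint behaves the same):

import torch
from transformers import AutoTokenizer, OPTForCausalLM

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = OPTForCausalLM.from_pretrained("facebook/opt-125m")

inputs = tokenizer("Attention masks must cover the cache", return_tensors="pt")
out = model(**inputs, use_cache=True)
past = out.past_key_values
next_token = out.logits[:, -1:].argmax(dim=-1)

# Correct: extend the original mask by one position for the new token,
# so its length equals mask_seq_length (past + current).
full_mask = torch.cat([inputs["attention_mask"], torch.ones((1, 1), dtype=torch.long)], dim=-1)
model(input_ids=next_token, attention_mask=full_mask, past_key_values=past)

# Incorrect: a mask covering only the new token now fails fast with the
# improved ValueError instead of an opaque downstream shape error.
try:
    model(input_ids=next_token, attention_mask=torch.ones((1, 1), dtype=torch.long), past_key_values=past)
except ValueError as err:
    print(err)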
@@ -645,6 +645,15 @@ class TFOPTDecoder(tf.keras.layers.Layer):
        if attention_mask is None:
            attention_mask = tf.ones(inputs_embeds.shape[:2], dtype=tf.bool)
        else:
            tf.debugging.assert_equal(
                attention_mask.shape[1],
                past_key_values_length + input_shape[1],
                message=(
                    f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
                    f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)"
                ),
            )

        pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
...
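On the TensorFlow side the same guard is expressed with tf.debugging.assert_equal, which raises InvalidArgumentError under eager execution. A small, self-contained sketch with illustrative lengths:

import tensorflow as tf

past_key_values_length, current_length = 4, 3    # mask should therefore have length 7
attention_mask = tf.ones((1, 6), dtype=tf.bool)  # deliberately one position short

try:
    tf.debugging.assert_equal(
        attention_mask.shape[1],
        past_key_values_length + current_length,
        message=(
            f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
            f"{past_key_values_length + current_length} (sum of the lengths of current and past inputs)"
        ),
    )
except tf.errors.InvalidArgumentError as err:
    print(err.message)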