Unverified Commit 842e99f1 authored by Matt's avatar Matt Committed by GitHub
Browse files

TF-OPT attention mask fixes (#25238)



* stash commit

* More OPT updates

* Update src/transformers/models/opt/modeling_tf_opt.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent f6301b9a
......@@ -61,17 +61,15 @@ _CAUSAL_LM_EXPECTED_OUTPUT = (
LARGE_NEGATIVE = -1e8
# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz = input_ids_shape[0]
tgt_len = input_ids_shape[1]
mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
# We need triu with k = 1 but TF expects known compile-time dims for that, so we hack around it
mask = tf.fill((tgt_len, tgt_len), tf.cast(LARGE_NEGATIVE, tf.float32))
mask = tf.linalg.band_part(mask, 0, -1) - tf.linalg.band_part(mask, 0, 0)
if past_key_values_length > 0:
mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
......@@ -93,7 +91,7 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
return (one_cst - expanded_mask) * LARGE_NEGATIVE
class TFOPTLearnedPositionalEmbedding(TFSharedEmbeddings):
class TFOPTLearnedPositionalEmbedding(tf.keras.layers.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
......@@ -516,16 +514,22 @@ class TFOPTDecoder(tf.keras.layers.Layer):
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, past_key_values_length):
# create causal mask
# # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
else:
combined_attention_mask = _expand_mask(
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
)
_, seq_length = input_shape
tf.debugging.assert_equal(
seq_length + past_key_values_length,
shape_list(attention_mask)[1],
message="Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
f" but is {shape_list(attention_mask)[1]} with input_ids shape {input_shape} and past length"
f" {past_key_values_length}.",
)
if attention_mask is not None:
combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])
expanded_attn_mask = _expand_mask(attention_mask, tgt_len=input_shape[-1])
if seq_length > 1:
combined_attention_mask = (
_make_causal_mask(input_shape, past_key_values_length=past_key_values_length) + expanded_attn_mask
)
else:
combined_attention_mask = expanded_attn_mask
return combined_attention_mask
......@@ -615,17 +619,16 @@ class TFOPTDecoder(tf.keras.layers.Layer):
inputs_embeds = self.embed_tokens(input_ids)
if attention_mask is None:
attention_mask = tf.ones(inputs_embeds.shape[:2], dtype=tf.bool)
attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.bool)
else:
tf.debugging.assert_equal(
tf.shape(attention_mask)[1],
shape_list(attention_mask)[1],
past_key_values_length + input_shape[1],
message=(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
f"The provided attention mask has length {tf.shape(attention_mask)[1]}, but its length should be "
f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)"
),
)
pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment