"docs/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "a9535377bf946dad4942712699dc3f702c581618"
Unverified Commit 842e99f1 authored by Matt, committed by GitHub

TF-OPT attention mask fixes (#25238)



* stash commit

* More OPT updates

* Update src/transformers/models/opt/modeling_tf_opt.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent f6301b9a
src/transformers/models/opt/modeling_tf_opt.py

@@ -61,17 +61,15 @@ _CAUSAL_LM_EXPECTED_OUTPUT = (
 LARGE_NEGATIVE = -1e8
 
-# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask
 def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz = input_ids_shape[0]
     tgt_len = input_ids_shape[1]
-    mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
-    mask_cond = tf.range(shape_list(mask)[-1])
-    mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
+    # We need triu with k = 1 but TF expects known compile-time dims for that, so we hack around it
+    mask = tf.fill((tgt_len, tgt_len), tf.cast(LARGE_NEGATIVE, tf.float32))
+    mask = tf.linalg.band_part(mask, 0, -1) - tf.linalg.band_part(mask, 0, 0)
 
     if past_key_values_length > 0:
         mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
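Note on the new construction: per the in-line comment, a triu-with-k=1 approach needs compile-time known dimensions in TF, so the causal bias is now built with `tf.fill` and `tf.linalg.band_part`. A minimal standalone sketch of the trick (illustrative, not library code):

```python
import tensorflow as tf

LARGE_NEGATIVE = -1e8  # same additive-bias constant as in modeling_tf_opt.py

# band_part(mask, 0, -1) keeps the diagonal and everything above it;
# subtracting band_part(mask, 0, 0) drops the diagonal, leaving a strictly
# upper-triangular matrix of LARGE_NEGATIVE: row i is 0 for columns j <= i
# (positions a token may attend to) and LARGE_NEGATIVE for columns j > i.
tgt_len = 4
mask = tf.fill((tgt_len, tgt_len), tf.cast(LARGE_NEGATIVE, tf.float32))
mask = tf.linalg.band_part(mask, 0, -1) - tf.linalg.band_part(mask, 0, 0)
print(mask.numpy())
```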
@@ -93,7 +91,7 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
     return (one_cst - expanded_mask) * LARGE_NEGATIVE
 
-class TFOPTLearnedPositionalEmbedding(TFSharedEmbeddings):
+class TFOPTLearnedPositionalEmbedding(tf.keras.layers.Embedding):
     """
     This module learns positional embeddings up to a fixed maximum size.
     """
@@ -516,16 +514,22 @@ class TFOPTDecoder(tf.keras.layers.Layer):
     def _prepare_decoder_attention_mask(self, attention_mask, input_shape, past_key_values_length):
         # create causal mask
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        combined_attention_mask = None
-        if input_shape[-1] > 1:
-            combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
-        else:
-            combined_attention_mask = _expand_mask(
-                tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
-            )
-
-        if attention_mask is not None:
-            combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])
+        _, seq_length = input_shape
+        tf.debugging.assert_equal(
+            seq_length + past_key_values_length,
+            shape_list(attention_mask)[1],
+            message="Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
+            f" but is {shape_list(attention_mask)[1]} with input_ids shape {input_shape} and past length"
+            f" {past_key_values_length}.",
+        )
+
+        expanded_attn_mask = _expand_mask(attention_mask, tgt_len=input_shape[-1])
+        if seq_length > 1:
+            combined_attention_mask = (
+                _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) + expanded_attn_mask
+            )
+        else:
+            combined_attention_mask = expanded_attn_mask
 
         return combined_attention_mask
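The rewritten `_prepare_decoder_attention_mask` now always expands the (required) padding mask and, when the current sequence is longer than one token, adds the causal bias to it; both are additive 0 / LARGE_NEGATIVE biases, so they combine by summation. A self-contained sketch of that combination, with `expand_mask` as a simplified stand-in for the library's `_expand_mask`:

```python
import tensorflow as tf

LARGE_NEGATIVE = -1e8


def expand_mask(mask: tf.Tensor, tgt_len: int) -> tf.Tensor:
    # [bsz, src_len] with 1 = attend, 0 = padded  ->  [bsz, 1, tgt_len, src_len] additive bias
    mask = tf.cast(mask, tf.float32)
    expanded = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
    return (1.0 - expanded) * LARGE_NEGATIVE


attention_mask = tf.constant([[1, 1, 1, 0]])  # last key position is padding
tgt_len = 4

padding_bias = expand_mask(attention_mask, tgt_len)  # blocks attention to padded keys

# causal part, built the same way as in _make_causal_mask (no cached tokens here)
causal = tf.fill((tgt_len, tgt_len), tf.cast(LARGE_NEGATIVE, tf.float32))
causal = tf.linalg.band_part(causal, 0, -1) - tf.linalg.band_part(causal, 0, 0)

combined = padding_bias + causal[None, None, :, :]
# combined[b, 0, i, j] is ~0 only when j <= i and key j is not padding;
# every other entry is a large negative bias that zeroes out the attention weight.
```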
@@ -615,17 +619,16 @@ class TFOPTDecoder(tf.keras.layers.Layer):
             inputs_embeds = self.embed_tokens(input_ids)
 
         if attention_mask is None:
-            attention_mask = tf.ones(inputs_embeds.shape[:2], dtype=tf.bool)
+            attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.bool)
         else:
             tf.debugging.assert_equal(
-                tf.shape(attention_mask)[1],
+                shape_list(attention_mask)[1],
                 past_key_values_length + input_shape[1],
                 message=(
-                    f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+                    f"The provided attention mask has length {tf.shape(attention_mask)[1]}, but its length should be "
                     f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)"
                 ),
             )
 
         pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
 
         attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length)
...
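The decoder-level change makes the default all-ones mask span the cached tokens as well as the current ones, and keeps an explicit length check for user-supplied masks. A small hedged sketch of the shapes involved during cached generation (the concrete sizes are arbitrary):

```python
import tensorflow as tf

batch_size, seq_length, past_key_values_length = 2, 1, 5  # one new token, five cached

# old default: mask only as long as the current input -> too short once a cache exists
old_default = tf.ones((batch_size, seq_length), dtype=tf.bool)                           # (2, 1)

# new default: mask covers past + current tokens, as the positional embedding
# and _prepare_decoder_attention_mask now expect
new_default = tf.ones((batch_size, seq_length + past_key_values_length), dtype=tf.bool)  # (2, 6)

# the same kind of check the decoder applies to user-supplied masks
tf.debugging.assert_equal(
    tf.shape(new_default)[1],
    past_key_values_length + seq_length,
    message="The provided attention mask should cover both past and current tokens.",
)
```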