Unverified Commit 90ca8d36 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[TF Led] Fix wrong decoder attention mask behavior (#9601)

* fix tf led

* remove loop file
parent 85788bae
......@@ -1862,7 +1862,6 @@ class TFLEDDecoder(tf.keras.layers.Layer):
hidden_states = inputs["inputs_embeds"]
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
else:
......@@ -1870,20 +1869,9 @@ class TFLEDDecoder(tf.keras.layers.Layer):
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
)
if inputs["attention_mask"] is None and inputs["input_ids"] is not None and input_shape[-1] > 1:
inputs["attention_mask"] = tf.cast(
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), inputs["input_ids"].dtype
)
inputs["attention_mask"] = tf.concat(
[
tf.ones((input_shape[0], past_key_values_length), dtype=inputs["attention_mask"].dtype),
inputs["attention_mask"],
],
axis=-1,
)
else:
inputs["attention_mask"] = tf.ones(
(input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32
if inputs["attention_mask"] is not None and input_shape[-1] > 1:
combined_attention_mask = combined_attention_mask + _expand_mask(
inputs["attention_mask"], tgt_len=input_shape[-1]
)
if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment