Unverified commit d0c1ded5, authored by Yih-Dar and committed by GitHub

remove `attention_mask` truncation in whisper (#20488)



* remove truncation

* For TFWhisper
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent de6d19ea
@@ -772,11 +772,6 @@ class TFWhisperDecoder(tf.keras.layers.Layer):
             )
         if attention_mask is not None:
-            attention_mask = tf.cond(
-                tf.greater(tf.shape(attention_mask)[-1], seq_len) & tf.greater(seq_len, 0),
-                lambda: attention_mask[:, : seq_len + past_key_values_length],
-                lambda: attention_mask,
-            )
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
             expanded_attn_mask = _expand_mask(attention_mask, tgt_len=input_shape[-1])
             combined_attention_mask = (
...
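The removed `tf.cond` sliced the mask to `seq_len + past_key_values_length` whenever it was longer than the current `seq_len`. The standalone sketch below (not library code; the batch size, lengths, and variable names are illustrative assumptions) shows the shapes involved during incremental decoding, where the mask covers the cached past tokens plus the token currently being decoded:

import tensorflow as tf

# Illustrative values: 4 cached (past) tokens, 1 new token being decoded.
batch_size, past_key_values_length, seq_len = 2, 4, 1

# The decoder receives a mask over past + current positions: [bsz, past + seq_len].
attention_mask = tf.ones((batch_size, past_key_values_length + seq_len), dtype=tf.int32)

# The removed branch: slice the mask when it is longer than the current seq_len.
truncated = tf.cond(
    tf.greater(tf.shape(attention_mask)[-1], seq_len) & tf.greater(seq_len, 0),
    lambda: attention_mask[:, : seq_len + past_key_values_length],
    lambda: attention_mask,
)

print(attention_mask.shape, truncated.shape)  # (2, 5) (2, 5): the slice keeps every position in this case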
@@ -756,8 +756,6 @@ class WhisperDecoder(WhisperPreTrainedModel):
             ).to(inputs_embeds.device)
         if attention_mask is not None:
-            if attention_mask.shape[-1] > input_shape[-1] > 0:
-                attention_mask = attention_mask[:, : input_shape[-1] + past_key_values_length]
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
             expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
             combined_attention_mask = (
...
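After the removal, the provided `attention_mask` is passed straight to `_expand_mask`, which turns the 2-D padding mask into the additive 4-D mask added to the attention scores (the `[bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]` comment above). The sketch below is a rough, self-contained approximation of that expansion; the helper name and exact masking value are assumptions, not the library source:

import torch

def expand_mask_sketch(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int) -> torch.Tensor:
    """Expand [bsz, src_len] keep/ignore flags into an additive [bsz, 1, tgt_len, src_len] mask."""
    bsz, src_len = mask.shape
    # Broadcast the per-key flags across the head dimension and all query positions.
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    # Attendable positions contribute 0.0; masked positions a very large negative value.
    return inverted.masked_fill(inverted.bool(), torch.finfo(dtype).min)

# Incremental decoding example: 3 cached tokens plus the 1 token being decoded,
# so the mask covers past + current positions (src_len = 4) while tgt_len = 1.
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 0, 1, 1]])
additive_mask = expand_mask_sketch(attention_mask, torch.float32, tgt_len=1)
print(additive_mask.shape)  # torch.Size([2, 1, 1, 4])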