Unverified Commit 77412343 authored by Andrei Filatov's avatar Andrei Filatov Committed by GitHub
Browse files

fixed whisper positional encoding (#23167)

parent 1b9c352e
...@@ -128,7 +128,7 @@ class TFWhisperPositionalEmbedding(tf.keras.layers.Layer): ...@@ -128,7 +128,7 @@ class TFWhisperPositionalEmbedding(tf.keras.layers.Layer):
def call(self, input_ids, past_key_values_length=0): def call(self, input_ids, past_key_values_length=0):
past_key_values_length = tf.cast(past_key_values_length, tf.int32) past_key_values_length = tf.cast(past_key_values_length, tf.int32)
gather_indices = tf.range(tf.shape(input_ids)[-1], delta=1) + past_key_values_length gather_indices = tf.range(tf.shape(input_ids)[1], delta=1) + past_key_values_length
return tf.gather(self.weight, gather_indices) return tf.gather(self.weight, gather_indices)
......
...@@ -226,7 +226,7 @@ class WhisperPositionalEmbedding(nn.Embedding): ...@@ -226,7 +226,7 @@ class WhisperPositionalEmbedding(nn.Embedding):
super().__init__(num_positions, embedding_dim) super().__init__(num_positions, embedding_dim)
def forward(self, input_ids, past_key_values_length=0): def forward(self, input_ids, past_key_values_length=0):
return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[-1]] return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]]
class WhisperAttention(nn.Module): class WhisperAttention(nn.Module):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment