"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "38580455dea435acd4a261e788d237d3421d65b2"
Unverified Commit 77412343 authored by Andrei Filatov's avatar Andrei Filatov Committed by GitHub
Browse files

fixed whisper positional encoding (#23167)

parent 1b9c352e
......@@ -128,7 +128,7 @@ class TFWhisperPositionalEmbedding(tf.keras.layers.Layer):
def call(self, input_ids, past_key_values_length=0):
past_key_values_length = tf.cast(past_key_values_length, tf.int32)
gather_indices = tf.range(tf.shape(input_ids)[-1], delta=1) + past_key_values_length
gather_indices = tf.range(tf.shape(input_ids)[1], delta=1) + past_key_values_length
return tf.gather(self.weight, gather_indices)
......
......@@ -226,7 +226,7 @@ class WhisperPositionalEmbedding(nn.Embedding):
super().__init__(num_positions, embedding_dim)
def forward(self, input_ids, past_key_values_length=0):
return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[-1]]
return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]]
class WhisperAttention(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment