Commit e25b8ba0 authored by Yoach Lacombe

fix position embedding computation

parent cb44e483
@@ -136,7 +136,7 @@ class StableSpeechSinusoidalPositionalEmbedding(nn.Module):
     @torch.no_grad()
     def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
-        bsz, codebooks, seq_len = input_ids.size()
+        bsz, seq_len, _ = input_ids.size()
         # Create the position ids from the input token ids.
         position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device)
         # expand embeddings if needed
@@ -754,7 +754,7 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
         )
         # embed positions
-        positions = self.embed_positions(input, past_key_values_length)
+        positions = self.embed_positions(inputs_embeds, past_key_values_length)
         hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
...
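For context, the shape change matters because the decoder now passes inputs_embeds, of shape (bsz, seq_len, hidden_size), to embed_positions instead of the (bsz, codebooks, seq_len) token grid, so the embedding module must unpack the sequence length from the second dimension. Below is a minimal, self-contained sketch of a sinusoidal positional-embedding module using the fixed convention; the class name, weight construction, and sizes are simplified assumptions for illustration, not the actual StableSpeech implementation.

import math

import torch
import torch.nn as nn


class SinusoidalPositionalEmbedding(nn.Module):
    """Hypothetical, simplified stand-in for StableSpeechSinusoidalPositionalEmbedding."""

    def __init__(self, num_positions: int, embedding_dim: int):
        super().__init__()
        half_dim = embedding_dim // 2
        # Standard sinusoidal table: sin on the first half of dims, cos on the second half.
        freqs = torch.exp(torch.arange(half_dim) * -(math.log(10000.0) / (half_dim - 1)))
        angles = torch.arange(num_positions).unsqueeze(1) * freqs.unsqueeze(0)
        weights = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        self.register_buffer("weights", weights, persistent=False)

    @torch.no_grad()
    def forward(self, inputs_embeds: torch.Tensor, past_key_values_length: int = 0):
        # inputs_embeds has shape (bsz, seq_len, hidden_size): the summed codebook
        # embeddings, not the (bsz, codebooks, seq_len) token grid the old code unpacked.
        bsz, seq_len, _ = inputs_embeds.size()
        position_ids = torch.arange(seq_len, device=inputs_embeds.device) + past_key_values_length
        return self.weights[position_ids]


# Usage mirroring the decoder-side change in this commit (toy sizes, assumed for illustration):
embed_positions = SinusoidalPositionalEmbedding(num_positions=2048, embedding_dim=64)
inputs_embeds = torch.randn(2, 10, 64)  # (bsz, seq_len, hidden_size)
positions = embed_positions(inputs_embeds, past_key_values_length=0)
hidden_states = inputs_embeds + positions.to(inputs_embeds.device)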