"vscode:/vscode.git/clone" did not exist on "320a622ec4d098f2da5d097930f4031517e7327b"
Commit e25b8ba0 authored by Yoach Lacombe's avatar Yoach Lacombe
Browse files

fix position embedding computation

parent cb44e483
......@@ -136,7 +136,7 @@ class StableSpeechSinusoidalPositionalEmbedding(nn.Module):
@torch.no_grad()
def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
bsz, codebooks, seq_len = input_ids.size()
bsz, seq_len, _ = input_ids.size()
# Create the position ids from the input token ids.
position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device)
# expand embeddings if needed
......@@ -754,7 +754,7 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
)
# embed positions
positions = self.embed_positions(input, past_key_values_length)
positions = self.embed_positions(inputs_embeds, past_key_values_length)
hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment