Unverified Commit 24588c67 authored by Suraj Patil, committed by GitHub

[M2M100, XGLM] fix create_position_ids_from_inputs_embeds (#15751)

parent f9582c20
@@ -167,7 +167,7 @@ class M2M100SinusoidalPositionalEmbedding(nn.Module):
             )
         else:
             bsz, seq_len = inputs_embeds.size()[:-1]
-            position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+            position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length)

         # expand embeddings if needed
         max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
@@ -176,7 +176,7 @@ class M2M100SinusoidalPositionalEmbedding(nn.Module):
         return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()

-    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+    def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length):
         """
         We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
@@ -191,7 +191,7 @@ class M2M100SinusoidalPositionalEmbedding(nn.Module):
         position_ids = torch.arange(
             self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
         )
-        return position_ids.unsqueeze(0).expand(input_shape).contiguous()
+        return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length

 # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->M2M100
@@ -211,7 +211,7 @@ class XGLMSinusoidalPositionalEmbedding(nn.Module):
             )
         else:
             bsz, seq_len = inputs_embeds.size()[:-1]
-            position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+            position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length)

         # expand embeddings if needed
         max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
@@ -220,7 +220,7 @@ class XGLMSinusoidalPositionalEmbedding(nn.Module):
         return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()

-    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+    def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length):
         """
         We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
@@ -235,7 +235,7 @@ class XGLMSinusoidalPositionalEmbedding(nn.Module):
         position_ids = torch.arange(
             self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
-        return position_ids.unsqueeze(0).expand(input_shape).contiguous()
+        return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length

 # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->XGLM
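For context, a minimal standalone sketch (not part of the commit) of what the patched helper computes during cached generation. The free function below, the dummy embeddings, and the padding_idx=1 default are illustrative assumptions, not the models' actual API; the point is only that the new past_key_values_length offset keeps the single decoded token's position id consecutive with the cached prefix instead of restarting at padding_idx + 1.

import torch

def create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length, padding_idx=1):
    # Standalone copy of the patched logic: sequential position ids,
    # shifted past the padding index and offset by the cache length.
    input_shape = inputs_embeds.size()[:-1]
    sequence_length = input_shape[1]
    position_ids = torch.arange(
        padding_idx + 1, sequence_length + padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
    )
    return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length

# Prompt of 5 tokens, no cache yet: positions start at padding_idx + 1.
prompt_embeds = torch.zeros(1, 5, 16)
print(create_position_ids_from_inputs_embeds(prompt_embeds, past_key_values_length=0))
# tensor([[2, 3, 4, 5, 6]])

# Next decoding step: one new embedding, 5 cached key/value positions.
# Before the fix this returned [[2]]; with the offset it returns [[7]].
step_embeds = torch.zeros(1, 1, 16)
print(create_position_ids_from_inputs_embeds(step_embeds, past_key_values_length=5))
# tensor([[7]])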