Unverified commit 24588c67 authored by Suraj Patil, committed by GitHub

[M2M100, XGLM] fix create_position_ids_from_inputs_embeds (#15751)

parent f9582c20
@@ -167,7 +167,7 @@ class M2M100SinusoidalPositionalEmbedding(nn.Module):
             )
         else:
             bsz, seq_len = inputs_embeds.size()[:-1]
-            position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+            position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length)
         # expand embeddings if needed
         max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
@@ -176,7 +176,7 @@ class M2M100SinusoidalPositionalEmbedding(nn.Module):
         return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
-    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+    def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length):
         """
         We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
@@ -191,7 +191,7 @@ class M2M100SinusoidalPositionalEmbedding(nn.Module):
         position_ids = torch.arange(
             self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
         )
-        return position_ids.unsqueeze(0).expand(input_shape).contiguous()
+        return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length
 # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->M2M100
...
@@ -211,7 +211,7 @@ class XGLMSinusoidalPositionalEmbedding(nn.Module):
             )
         else:
             bsz, seq_len = inputs_embeds.size()[:-1]
-            position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+            position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length)
         # expand embeddings if needed
         max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
@@ -220,7 +220,7 @@ class XGLMSinusoidalPositionalEmbedding(nn.Module):
         return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
-    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+    def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length):
         """
         We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
@@ -235,7 +235,7 @@ class XGLMSinusoidalPositionalEmbedding(nn.Module):
         position_ids = torch.arange(
             self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
         )
-        return position_ids.unsqueeze(0).expand(input_shape).contiguous()
+        return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length
 # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->XGLM
...
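
Both files get the same two-part change: past_key_values_length is threaded into create_position_ids_from_inputs_embeds and added onto the generated ids, so embeddings passed in directly while decoding with a cache of past key/value states get positions that continue after the cached tokens instead of restarting at padding_idx + 1. Below is a minimal standalone sketch of the fixed helper (not the full embedding module), assuming padding_idx = 1 as the pad token id; the free function and the toy tensors are illustrative only.

import torch

# Sketch of the corrected position-id logic, mirroring the "+" lines of the diff.
def create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length, padding_idx=1):
    input_shape = inputs_embeds.size()[:-1]
    sequence_length = input_shape[1]
    position_ids = torch.arange(
        padding_idx + 1, sequence_length + padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
    )
    # The fix: offset by past_key_values_length so a newly fed embedding is
    # placed after the cached positions rather than at the sequence start.
    return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length

# One token of embeddings fed after 4 cached key/value states:
embeds = torch.zeros(1, 1, 16)
print(create_position_ids_from_inputs_embeds(embeds, past_key_values_length=0))  # tensor([[2]])
print(create_position_ids_from_inputs_embeds(embeds, past_key_values_length=4))  # tensor([[6]])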