Unverified Commit 8878eb1b authored by Ramiro Leal-Cavazos, committed by GitHub

Remove unnecessary `view`s of `position_ids` (#26059)

* Remove unnecessary `view` of `position_ids` in `modeling_llama`

When `position_ids` is `None`, its value is generated using `torch.arange`,
which creates a tensor of size
`(seq_length + past_key_values_length) - past_key_values_length = seq_length`.
The tensor is then unsqueezed, resulting in a tensor of shape
`(1, seq_length)`. This means that the last `view` to a tensor of shape
`(-1, seq_length)` is a no-op.

This commit removes the unnecessary view; a short sketch after this
commit message illustrates the shape reasoning.

* Remove no-op `view` of `position_ids` in the rest of the transformer models
Parent: 75a33d60
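The shape argument above is easy to check directly. Below is a minimal PyTorch sketch (the concrete sizes are illustrative assumptions, not values from any model); it verifies that `view(-1, seq_length)` applied to the freshly built `(1, seq_length)` tensor returns a tensor with the same shape and the same underlying storage, i.e. a no-op:

import torch

# Illustrative sizes (assumptions for this sketch, not from any model config).
past_key_values_length = 3
seq_length = 5

# The `position_ids is None` branch: torch.arange yields a 1-D tensor of size
# (seq_length + past_key_values_length) - past_key_values_length = seq_length,
# and unsqueeze(0) turns it into shape (1, seq_length).
position_ids = torch.arange(
    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long
)
position_ids = position_ids.unsqueeze(0)
assert position_ids.shape == (1, seq_length)

# The removed view: same shape out, same storage underneath, hence a no-op.
viewed = position_ids.view(-1, seq_length)
assert viewed.shape == (1, seq_length)
assert viewed.data_ptr() == position_ids.data_ptr()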
@@ -475,9 +475,6 @@ class CodeGenModel(CodeGenPreTrainedModel):
 
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
-        if position_ids is not None:
-            position_ids = position_ids.view(-1, input_shape[-1]).long()
-
         if past_key_values is None:
             past_length = 0
             past_key_values = tuple([None] * len(self.h))
@@ -486,7 +483,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
 
         if position_ids is None:
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = position_ids.unsqueeze(0)
 
         # Attention mask.
         if attention_mask is not None:
...
@@ -416,7 +416,7 @@ class CTRLModel(CTRLPreTrainedModel):
             past_length = past_key_values[0][0].size(-2)
         if position_ids is None:
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = position_ids.unsqueeze(0)
 
         # Attention mask.
         if attention_mask is not None:
@@ -447,7 +447,6 @@ class CTRLModel(CTRLPreTrainedModel):
             token_type_embeds *= np.sqrt(self.d_model_size)
         else:
             token_type_embeds = 0
-        position_ids = position_ids.view(-1, input_shape[-1])
 
         if inputs_embeds is None:
             inputs_embeds = self.w(input_ids)
...
@@ -544,8 +544,6 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
 
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
-        if position_ids is not None:
-            position_ids = position_ids.view(-1, input_shape[-1])
 
         if past_key_values is None:
             past_length = 0
@@ -554,7 +552,7 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
             past_length = past_key_values[0][0].size(-2)
         if position_ids is None:
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = position_ids.unsqueeze(0)
 
         # GPT2Attention mask.
         if attention_mask is not None:
...
@@ -630,9 +630,7 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
             position_ids = torch.arange(
                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
             )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
+            position_ids = position_ids.unsqueeze(0)
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
...
@@ -1128,9 +1128,7 @@ class FalconModel(FalconPreTrainedModel):
             position_ids = torch.arange(
                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
             )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
+            position_ids = position_ids.unsqueeze(0)
 
         causal_mask = self._prepare_attn_mask(
             attention_mask,
...
@@ -790,8 +790,6 @@ class GPT2Model(GPT2PreTrainedModel):
 
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
-        if position_ids is not None:
-            position_ids = position_ids.view(-1, input_shape[-1])
 
         if past_key_values is None:
             past_length = 0
@@ -800,7 +798,7 @@ class GPT2Model(GPT2PreTrainedModel):
             past_length = past_key_values[0][0].size(-2)
         if position_ids is None:
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = position_ids.unsqueeze(0)
 
         # GPT2Attention mask.
         if attention_mask is not None:
...
@@ -577,8 +577,6 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
 
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
-        if position_ids is not None:
-            position_ids = position_ids.view(-1, input_shape[-1])
 
         if past_key_values is None:
             past_length = 0
@@ -594,7 +592,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
             position_ids = position_ids[:, past_length : input_shape[-1] + past_length :]
         elif position_ids is None:
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = position_ids.unsqueeze(0)
 
         # Self-attention mask.
         query_length = input_shape[-1]
...
@@ -539,8 +539,6 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
 
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
-        if position_ids is not None:
-            position_ids = position_ids.view(-1, input_shape[-1])
 
         if past_key_values is None:
             past_length = 0
@@ -550,7 +548,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
 
         if position_ids is None:
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = position_ids.unsqueeze(0)
 
         # Attention mask.
         if attention_mask is not None:
...
@@ -596,9 +596,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
+            position_ids = position_ids.unsqueeze(0)
 
         # Attention mask.
         if attention_mask is not None:
...
@@ -592,9 +592,6 @@ class GPTJModel(GPTJPreTrainedModel):
 
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
-        if position_ids is not None:
-            position_ids = position_ids.view(-1, input_shape[-1]).long()
-
         if past_key_values is None:
             past_length = 0
             past_key_values = tuple([None] * len(self.h))
@@ -603,7 +600,7 @@ class GPTJModel(GPTJPreTrainedModel):
 
         if position_ids is None:
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = position_ids.unsqueeze(0)
 
         # Attention mask.
         if attention_mask is not None:
...
@@ -1208,9 +1208,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
             position_ids = torch.arange(
                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
             )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
+            position_ids = position_ids.unsqueeze(0)
 
         no_images = False
         if (pixel_values, image_encoder_embeddings, perceiver_embeddings).count(None) != 2:
...
@@ -729,8 +729,6 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
 
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
-        if position_ids is not None:
-            position_ids = position_ids.view(-1, input_shape[-1])
 
         if past_key_values is None:
             past_length = 0
@@ -739,7 +737,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
             past_length = past_key_values[0][0].size(-2)
         if position_ids is None:
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = position_ids.unsqueeze(0)
 
         # ImageGPTAttention mask.
         if attention_mask is not None:
...
@@ -867,9 +867,7 @@ class LlamaModel(LlamaPreTrainedModel):
             position_ids = torch.arange(
                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
             )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
+            position_ids = position_ids.unsqueeze(0)
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
...
@@ -638,9 +638,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
             position_ids = torch.arange(
                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
             )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
+            position_ids = position_ids.unsqueeze(0)
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
...
@@ -623,9 +623,7 @@ class XGLMModel(XGLMPreTrainedModel):
                 dtype=torch.long,
                 device=input_ids.device if input_ids is not None else inputs_embeds.device,
             )
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
-        else:
-            position_ids = position_ids.view(-1, input_shape[-1])
+            position_ids = position_ids.unsqueeze(0)
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
...