Unverified Commit 89087597 authored by Joel Lamy-Poirier's avatar Joel Lamy-Poirier Committed by GitHub
Browse files

Indexing fix for gpt_bigcode (#22737)

Fix indexing
parent 7ade6ef7
......@@ -538,7 +538,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[Union[List[torch.Tensor], int]] = None,
past_key_values: Optional[List[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
......@@ -584,7 +584,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
past_length = past_key_values[0].size(-2)
if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None:
# create position_ids on the fly for batch generation
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment