Unverified Commit 554e7ada authored by jiqing-feng, committed by GitHub

check if position_ids exists before using it (#29306)


Co-authored-by: Joao Gante <joao@huggingface.co>
parent d3a4b475
@@ -1168,7 +1168,9 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
         # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
         # same goes for position ids. Could also help with continued generation.
-        cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
+        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+        cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+        position_ids = position_ids.contiguous() if position_ids is not None else None

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
@@ -1181,7 +1183,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
         model_inputs.update(
             {
-                "position_ids": position_ids.contiguous(),
+                "position_ids": position_ids,
                 "cache_position": cache_position,
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
...
@@ -1284,7 +1284,9 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
         # same goes for position ids. Could also help with continued generation.
-        cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
+        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+        cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+        position_ids = position_ids.contiguous() if position_ids is not None else None

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
@@ -1297,7 +1299,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         model_inputs.update(
             {
-                "position_ids": position_ids.contiguous(),
+                "position_ids": position_ids,
                 "cache_position": cache_position,
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
...
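
The change is identical in GemmaForCausalLM and LlamaForCausalLM: every use of position_ids is guarded against None before .shape or .contiguous() is touched. Below is a minimal standalone sketch of the guarded computation; the make_cache_position helper is hypothetical, lifted out of prepare_inputs_for_generation purely for illustration.

import torch

# Hypothetical standalone reproduction of the guard added in this commit
# (in transformers the logic lives inside prepare_inputs_for_generation).
def make_cache_position(input_ids, position_ids, past_length):
    # Fall back to input_ids for the length when position_ids was not supplied.
    input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
    cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
    # .contiguous() would raise AttributeError on None, so guard it as well.
    position_ids = position_ids.contiguous() if position_ids is not None else None
    return cache_position, position_ids

# Before this change, position_ids=None crashed on position_ids.shape[-1].
input_ids = torch.tensor([[5, 7, 9]])
cache_position, position_ids = make_cache_position(input_ids, None, past_length=0)
print(cache_position)  # tensor([0, 1, 2])
print(position_ids)    # None

Deriving the length from input_ids when position_ids is absent keeps cache_position correct for callers that never pass explicit position ids, e.g. a plain model.generate(input_ids).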