"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "b88e0e016db8cc9dc3e21db1bcc2615a1717ddcb"
Unverified Commit 2da82bb4 authored by ValeKnappich's avatar ValeKnappich Committed by GitHub
Browse files

fix past_key_values in GPTNeoXForCausalLM.prepare_inputs_for_generation (#20621)

* fix past_key_values in GPTNeoXForCausalLM.prepare_inputs_for_generation

* fix formatting
parent 852e7eba
......@@ -697,7 +697,11 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
if past and past[0] is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past or model_kwargs.get("past_key_values"),
}
def _reorder_cache(self, past, beam_idx):
reordered_past = ()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment