Unverified Commit 33196b45 authored by Fei Wang's avatar Fei Wang Committed by GitHub
Browse files

Fix LLaMa beam search when using parallelize (#24224)

* Fix LLaMa beam search when using parallelize

same issue as T5 #11717

* fix code format in modeling_llama.py

* fix format of _reorder_cache in modeling_llama.py
parent 7504be35
......@@ -762,7 +762,9 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment