Unverified Commit bd3b599c authored by Oyvind Tafjord's avatar Oyvind Tafjord Committed by GitHub
Browse files

Fix T5 beam search using parallelize (#11717)

parent 218d552f
......@@ -1682,7 +1682,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
for layer_past_state in layer_past_states:
# need to set correct `past` for each of the four key / value states
reordered_layer_past_states = reordered_layer_past_states + (
layer_past_state.index_select(0, beam_idx),
layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
)
assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment