Unverified Commit 8881f38a authored by Dong-Yong Lee, committed by GitHub

Fix beam search when using model parallel (#24969)



* Fix GPTNeoX beam search when using parallelize

* Fix beam search idx device when using model parallel

* remove onnx related stuff
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix: move test_beam_search_on_multi_gpu to GenerationTesterMixin

* fix: add right item to _no_split_modules of MegaPreTrainedModel

* fix: add num_beams within parallelized beam_search test
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 0dd06c3f
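For context, the failure this commit addresses shows up when a model is sharded across several GPUs: the cached key/value tensors in `past_key_values` live on different devices, while `beam_idx` is produced on a single device, so `_reorder_cache` fails on `index_select(0, beam_idx)` with a device-mismatch error. The diff moves `beam_idx` to each `past_state`'s device before indexing. A minimal sketch of the triggering scenario follows; the checkpoint name, prompt, and generation settings are illustrative and not taken from the commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Shard the model across all visible GPUs. With more than one device,
# different decoder layers (and their cached key/value states) end up
# on different GPUs.
model_name = "EleutherAI/gpt-neox-20b"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", torch_dtype=torch.float16
)

inputs = tokenizer("Beam search on a sharded model", return_tensors="pt").to("cuda:0")

# Beam search reorders the cache every step via _reorder_cache. Before this
# fix, beam_idx stayed on one device while some past_state tensors lived on
# another, so index_select(0, beam_idx) raised a device-mismatch error.
outputs = model.generate(**inputs, num_beams=4, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))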
@@ -1467,7 +1467,8 @@ class BartForConditionalGeneration(BartPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -1946,5 +1947,7 @@ class BartForCausalLM(BartPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1294,7 +1294,9 @@ class BertLMHeadModel(BertPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
...
@@ -1002,5 +1002,7 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -2638,7 +2638,8 @@ class BigBirdForCausalLM(BigBirdPreTrainedModel):
         reordered_past = ()
         for layer_past in past_key_values:
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
...
@@ -2651,7 +2651,8 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -3121,5 +3122,7 @@ class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -752,7 +752,9 @@ class BioGptForCausalLM(BioGptPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
...
@@ -1412,7 +1412,8 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -1634,5 +1635,7 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1379,7 +1379,8 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -1601,5 +1602,7 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -934,5 +934,7 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1551,7 +1551,9 @@ class CamembertForCausalLM(CamembertPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
...
@@ -1018,7 +1018,9 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
...
@@ -881,7 +881,9 @@ class OpenLlamaForCausalLM(OpenLlamaPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
...
@@ -1677,5 +1677,7 @@ class ElectraForCausalLM(ElectraPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1236,7 +1236,9 @@ class ErnieForCausalLM(ErniePreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
...
@@ -1568,5 +1568,7 @@ class GitForCausalLM(GitPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -839,7 +839,8 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
         reordered_past = ()
         for layer_past in past_key_values:
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
...
@@ -723,6 +723,7 @@ class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
         reordered_past = ()
         for layer_past in past_key_values:
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -2509,7 +2509,8 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
...
@@ -1385,5 +1385,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1532,7 +1532,8 @@ class MarianMTModel(MarianPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -1752,5 +1753,7 @@ class MarianForCausalLM(MarianPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past