Unverified Commit 8881f38a authored by Dong-Yong Lee, committed by GitHub

Fix beam search when using model parallel (#24969)



* Fix GPTNeoX beam search when using parallelize

* Fix beam search idx device when using model parallel

* remove onnx related stuff
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix: move test_beam_search_on_multi_gpu to GenerationTesterMixin

* fix: add right item to _no_split_modules of MegaPreTrainedModel

* fix: add num_beams within parallelized beam_search test
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 0dd06c3f
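
All of the hunks below apply the same one-line pattern. Under model parallelism (e.g. `device_map="auto"`), the cached key/value tensors of different layers can live on different GPUs, while `beam_idx` is created on a single device. `torch.index_select` requires its index tensor to be on the same device as the input, so reordering the cache during beam search raised a device-mismatch RuntimeError; moving the index with `beam_idx.to(past_state.device)` fixes it layer by layer. A minimal, self-contained sketch of the fixed pattern (consolidated from the diff; the toy cache shapes below are illustrative, not from the PR):

import torch

def _reorder_cache(past_key_values, beam_idx):
    # Reorder every layer's cached states along the batch/beam dimension,
    # moving the beam indices onto whichever device holds that layer's cache.
    reordered_past = ()
    for layer_past in past_key_values:
        reordered_past += (
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
        )
    return reordered_past

# Toy single-device check: two layers of (key, value) caches with batch dim 4.
past = tuple((torch.randn(4, 2, 8), torch.randn(4, 2, 8)) for _ in range(2))
beam_idx = torch.tensor([2, 2, 0, 1])
reordered = _reorder_cache(past, beam_idx)
assert torch.equal(reordered[0][0], past[0][0][beam_idx])
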
@@ -961,7 +961,9 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1431,7 +1431,8 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -1904,5 +1905,7 @@ class MBartForCausalLM(MBartPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1341,7 +1341,7 @@ class MegaPreTrainedModel(PreTrainedModel):
     config_class = MegaConfig
     base_model_prefix = "mega"
     supports_gradient_checkpointing = False
-    _no_split_modules = []
+    _no_split_modules = ["MegaMovingAverageGatedAttention"]

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -1802,7 +1802,9 @@ class MegaForCausalLM(MegaPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1260,7 +1260,9 @@ class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1595,7 +1595,8 @@ class MvpForConditionalGeneration(MvpPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -2066,5 +2067,7 @@ class MvpForCausalLM(MvpPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1826,5 +1826,7 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1003,7 +1003,9 @@ class OPTForCausalLM(OPTPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1489,7 +1489,8 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -1731,5 +1732,7 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1691,7 +1691,8 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -1402,7 +1402,8 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -1751,5 +1752,7 @@ class PLBartForCausalLM(PLBartPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -2068,7 +2068,8 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -2312,7 +2313,9 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1160,7 +1160,9 @@ class QDQBertLMHeadModel(QDQBertPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1213,7 +1213,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
         reordered_past = ()
         for layer_past in past_key_values:
             # get the correct batch idx from decoder layer's batch dim for cross and self-attn
-            reordered_past += (tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(_reorder_stacked(past_state, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -2300,12 +2300,12 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
         for layer_past in past_key_values:
             # buckets
             if layer_past[0] is not None:
-                reord_buckets = layer_past[0].index_select(0, beam_idx)
+                reord_buckets = layer_past[0].index_select(0, beam_idx.to(layer_past[0].device))
             else:
                 reord_buckets = None

             # hidden states
-            reord_hidden_states = layer_past[1].index_select(0, beam_idx)
+            reord_hidden_states = layer_past[1].index_select(0, beam_idx.to(layer_past[1].device))
             reord_past_buckets_states.append((reord_buckets, reord_hidden_states))
         return reord_past_buckets_states
@@ -1157,7 +1157,8 @@ class RemBertForCausalLM(RemBertPreTrainedModel):
         reordered_past = ()
         for layer_past in past_key_values:
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -1016,7 +1016,9 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1023,7 +1023,9 @@ class RobertaPreLayerNormForCausalLM(RobertaPreLayerNormPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1580,7 +1580,9 @@ class RoCBertForCausalLM(RoCBertPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
        reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1188,7 +1188,8 @@ class RoFormerForCausalLM(RoFormerPreTrainedModel):
         reordered_past = ()
         for layer_past in past_key_values:
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -1418,5 +1418,7 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
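
The commit message above also moves a `test_beam_search_on_multi_gpu` test into `GenerationTesterMixin` and sets `num_beams` there. The test body is not part of this diff; the sketch below only illustrates the scenario it exercises, with a stand-in checkpoint name:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Spread the model across all visible GPUs so the per-layer KV caches land on
# different devices -- the situation that used to break _reorder_cache.
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in for any patched model
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")

inputs = tokenizer("Hello, my dog is", return_tensors="pt").to(model.device)
# num_beams > 1 is required to hit the beam-index reordering path.
outputs = model.generate(**inputs, num_beams=4, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
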