"vscode:/vscode.git/clone" did not exist on "68fa1e855bef7d77e227686543787d8e2c4463fc"
Unverified Commit 8881f38a authored by Dong-Yong Lee and committed by GitHub

Fix beam search when using model parallel (#24969)



* Fix GPTNeoX beam search when using parallelize

* Fix beam search idx device when using model parallel

* remove onnx related stuff
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix: move test_beam_search_on_multi_gpu to GenerationTesterMixin

* fix: add right item to _no_split_modules of MegaPreTrainedModel

* fix: add num_beams within parallelized beam_search test
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 0dd06c3f
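
The root cause: when a model is split across several GPUs, the cached key/value tensors of different layers live on different devices, while the beam_idx produced by the beam scorer lives on a single device. torch.index_select requires its index tensor to be on the same device as the input tensor, so reordering the cache crashed during beam search. A minimal sketch of the failure and the fix, with hypothetical tensors and assuming two CUDA devices:

    import torch

    # Hypothetical cache tensor on the second GPU, beam indices on the first.
    past_state = torch.randn(4, 8, 16, device="cuda:1")     # (beams, heads, dim)
    beam_idx = torch.tensor([0, 0, 1, 3], device="cuda:0")  # parent beam of each new beam

    # Before this commit: RuntimeError, the index must be on the same
    # device as the input tensor.
    # past_state.index_select(0, beam_idx)

    # The fix applied throughout this diff: move the index to the cache
    # tensor's own device first.
    reordered = past_state.index_select(0, beam_idx.to(past_state.device))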
@@ -976,5 +976,7 @@ class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
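
Each hunk below applies the same change to a model's _reorder_cache. For orientation, a rough sketch of where this hook fires (illustrative names, not the actual generation loop in transformers): after every beam-search step, generation reorders the per-layer cache so that each surviving beam continues from its parent beam's past keys and values.

    import torch

    # Illustrative sketch only; the real loop lives in GenerationMixin's beam search.
    def step_and_reorder(model, past_key_values, beam_scores, num_beams):
        # Pick, for each new beam, the index of the old beam it continues from.
        beam_idx = torch.topk(beam_scores, num_beams, dim=-1).indices.view(-1)
        # With model parallelism the tensors inside past_key_values can sit on
        # different GPUs; _reorder_cache therefore moves beam_idx onto each
        # tensor's own device before index_select (the change in these hunks).
        return model._reorder_cache(past_key_values, beam_idx)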
@@ -2525,7 +2525,9 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
...
@@ -1016,5 +1016,7 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1331,7 +1331,9 @@ class UMT5ForConditionalGeneration(UMT5PreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
...
@@ -1794,7 +1794,9 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

     def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02):
...
@@ -881,5 +881,7 @@ class XGLMForCausalLM(XGLMPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -2095,7 +2095,8 @@ class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -2340,7 +2341,9 @@ class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
...
@@ -1020,7 +1020,9 @@ class XLMRobertaForCausalLM(XLMRobertaPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
...
@@ -979,7 +979,9 @@ class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
...
@@ -1128,7 +1128,9 @@ class XmodForCausalLM(XmodPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
...
@@ -1180,7 +1180,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],)
+            reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) + layer_past[2:],)
         return reordered_past

 class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
@@ -2898,7 +2898,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
         return reordered_past
@@ -3335,6 +3335,6 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
         return reordered_past
 {% endif -%}
@@ -15,13 +15,14 @@

 import inspect
+import tempfile
 import unittest
 import warnings

 import numpy as np

 from transformers import is_torch_available, pipeline
-from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.testing_utils import require_accelerate, require_torch, require_torch_multi_gpu, slow, torch_device

 from ..test_modeling_common import floats_tensor, ids_tensor
 from .test_framework_agnostic import GenerationIntegrationTestsMixin
@@ -1017,6 +1018,27 @@ class GenerationTesterMixin:
                 output, input_ids, model.config, use_cache=True, num_return_sequences=beam_scorer.num_beams
             )

+    @require_accelerate
+    @require_torch_multi_gpu
+    def test_model_parallel_beam_search(self):
+        for model_class in self.all_generative_model_classes:
+            if model_class._no_split_modules is None:
+                continue
+
+            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+
+            model = model_class(config).eval()
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                model.cpu().save_pretrained(tmp_dir)
+                new_model = model_class.from_pretrained(tmp_dir, device_map="auto")
+
+                new_model.generate(
+                    input_ids,
+                    attention_mask=attention_mask,
+                    max_length=max_length,
+                    num_beams=2,
+                )
+
     def test_beam_sample_generate(self):
         for model_class in self.all_generative_model_classes:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
...
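
The new test shards each model across the available GPUs with accelerate's device_map="auto" and runs a two-beam generate. Outside the test suite, the same setup looks roughly like this (the checkpoint name is illustrative; any model whose class defines _no_split_modules can be sharded):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    checkpoint = "EleutherAI/gpt-neox-20b"  # illustrative; pick any shardable checkpoint
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # device_map="auto" (requires accelerate) spreads the layers over the GPUs.
    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")

    inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda:0")
    # With this commit, num_beams > 1 no longer raises a device-mismatch error.
    outputs = model.generate(**inputs, num_beams=2, max_new_tokens=20)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))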
@@ -2482,34 +2482,6 @@ class ModelTesterMixin:
                 for value_, parallel_value_ in zip(value, parallel_value):
                     self.assertTrue(torch.allclose(value_, parallel_value_.to("cpu"), atol=1e-7))

-    @require_torch_multi_gpu
-    def test_model_parallel_beam_search(self):
-        if not self.test_model_parallel:
-            return
-
-        all_generative_and_parallelizable_model_classes = tuple(
-            set(self.all_generative_model_classes).intersection(self.all_parallelizable_model_classes)
-        )
-
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-        for model_class in all_generative_and_parallelizable_model_classes:
-            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
-            model = model_class(config)
-
-            def cast_to_device(dictionary, device):
-                output = {}
-                for k, v in dictionary.items():
-                    if isinstance(v, torch.Tensor):
-                        output[k] = v.to(device)
-                    else:
-                        output[k] = v
-
-                return output
-
-            model.parallelize()
-            model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)
-
     def check_device_map_is_respected(self, model, device_map):
         for param_name, param in model.named_parameters():
             # Find device in device_map
...