"vscode:/vscode.git/clone" did not exist on "936ab7bae5e040ec58994cb722dd587b9ab26581"
Unverified Commit 8881f38a authored by Dong-Yong Lee's avatar Dong-Yong Lee Committed by GitHub
Browse files

Fix beam search when using model parallel (#24969)



* Fix GPTNeoX beam search when using parallelize

* Fix beam search idx device when using model parallel

* remove onnx related stuff
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* fix: move test_beam_search_on_multi_gpu to GenerationTesterMixin

* fix: add right item to _no_split_modules of MegaPreTrainedModel

* fix: add num_beams within parallelized beam_search test
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 0dd06c3f
......@@ -976,5 +976,7 @@ class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel):
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
......@@ -2525,7 +2525,9 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
......
......@@ -1016,5 +1016,7 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel):
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
......@@ -1331,7 +1331,9 @@ class UMT5ForConditionalGeneration(UMT5PreTrainedModel):
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
......
......@@ -1794,7 +1794,9 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02):
......
......@@ -881,5 +881,7 @@ class XGLMForCausalLM(XGLMPreTrainedModel):
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
......@@ -2095,7 +2095,8 @@ class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
for layer_past in past_key_values:
# cached cross_attention states don't have to be reordered -> they are always the same
reordered_past += (
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+ layer_past[2:],
)
return reordered_past
......@@ -2340,7 +2341,9 @@ class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel):
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
......
......@@ -1020,7 +1020,9 @@ class XLMRobertaForCausalLM(XLMRobertaPreTrainedModel):
def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
......
......@@ -979,7 +979,9 @@ class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel):
def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
......
......@@ -1128,7 +1128,9 @@ class XmodForCausalLM(XmodPreTrainedModel):
def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
......
......@@ -1180,7 +1180,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],)
reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) + layer_past[2:],)
return reordered_past
class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
......@@ -2898,7 +2898,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
return reordered_past
......@@ -3335,6 +3335,6 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
return reordered_past
{% endif -%}
......@@ -15,13 +15,14 @@
import inspect
import tempfile
import unittest
import warnings
import numpy as np
from transformers import is_torch_available, pipeline
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.testing_utils import require_accelerate, require_torch, require_torch_multi_gpu, slow, torch_device
from ..test_modeling_common import floats_tensor, ids_tensor
from .test_framework_agnostic import GenerationIntegrationTestsMixin
......@@ -1017,6 +1018,27 @@ class GenerationTesterMixin:
output, input_ids, model.config, use_cache=True, num_return_sequences=beam_scorer.num_beams
)
@require_accelerate
@require_torch_multi_gpu
def test_model_parallel_beam_search(self):
for model_class in self.all_generative_model_classes:
if model_class._no_split_modules is None:
continue
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
model = model_class(config).eval()
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)
new_model = model_class.from_pretrained(tmp_dir, device_map="auto")
new_model.generate(
input_ids,
attention_mask=attention_mask,
max_length=max_length,
num_beams=2,
)
def test_beam_sample_generate(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
......
......@@ -2482,34 +2482,6 @@ class ModelTesterMixin:
for value_, parallel_value_ in zip(value, parallel_value):
self.assertTrue(torch.allclose(value_, parallel_value_.to("cpu"), atol=1e-7))
@require_torch_multi_gpu
def test_model_parallel_beam_search(self):
if not self.test_model_parallel:
return
all_generative_and_parallelizable_model_classes = tuple(
set(self.all_generative_model_classes).intersection(self.all_parallelizable_model_classes)
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in all_generative_and_parallelizable_model_classes:
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
def cast_to_device(dictionary, device):
output = {}
for k, v in dictionary.items():
if isinstance(v, torch.Tensor):
output[k] = v.to(device)
else:
output[k] = v
return output
model.parallelize()
model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)
def check_device_map_is_respected(self, model, device_map):
for param_name, param in model.named_parameters():
# Find device in device_map
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment