Unverified Commit 08abdabd authored by TobiasNorlund's avatar TobiasNorlund Committed by GitHub
Browse files

Fixed beam search generation for GPT2 and T5 (#9219)

parent 161a6461
......@@ -156,7 +156,7 @@ class GenerationMixin:
if is_encoder_decoder:
assert encoder_outputs is not None
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
0, expanded_return_idx
0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
)
model_kwargs["encoder_outputs"] = encoder_outputs
return input_ids, model_kwargs
......@@ -226,7 +226,7 @@ class GenerationMixin:
For custom re-ordering of :obj:`past_key_values` or :obj:`mems`, the function should be implemented in
subclasses of :class:`~transformers.PreTrainedModel`.
"""
return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
return tuple(layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in past)
def _get_logits_warper(
self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
......
......@@ -1166,6 +1166,34 @@ class ModelTesterMixin:
for value_, parallel_value_ in zip(value, parallel_value):
self.assertTrue(torch.allclose(value_, parallel_value_.to("cpu"), atol=1e-7))
@require_torch_multi_gpu
def test_model_parallel_beam_search(self):
if not self.test_model_parallel:
return
all_generative_and_parallelizable_model_classes = tuple(
set(self.all_generative_model_classes).intersection(self.all_parallelizable_model_classes)
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in all_generative_and_parallelizable_model_classes:
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
def cast_to_device(dictionary, device):
output = {}
for k, v in dictionary.items():
if isinstance(v, torch.Tensor):
output[k] = v.to(device)
else:
output[k] = v
return output
model.parallelize()
model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)
global_rng = random.Random()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment