Unverified Commit 04428160 authored by Pablo Montalvo, committed by GitHub

Fix generate with `inputs_embeds` as input (#32493)

* I think inputs_embeds has ndim == 3

* fix sequence length catch

* add generate test

* [run-slow]olmo, persimmon, gemma, gemma2, qwen2, llama

* skip whisper

* fix bart test

* more fixes
parent b01f9c48
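
The first two commit-message bullets come down to a shape distinction: `input_ids` is 2-D `(batch, seq_len)` while `inputs_embeds` is 3-D `(batch, seq_len, hidden_size)`, so a generation loop must read the sequence length from dim 1 in both cases, never from the last dim. A minimal sketch of that invariant (the helper below is hypothetical, not the library's actual code):

```python
import torch

def infer_seq_length(input_ids=None, inputs_embeds=None):
    # Hypothetical helper for illustration: pick the sequence length from
    # whichever input is present, accounting for the differing ndim.
    if inputs_embeds is not None:
        assert inputs_embeds.ndim == 3  # (batch, seq_len, hidden_size)
        return inputs_embeds.shape[1]
    assert input_ids.ndim == 2  # (batch, seq_len)
    return input_ids.shape[1]

batch, seq_len, hidden = 2, 5, 8
assert infer_seq_length(input_ids=torch.ones(batch, seq_len, dtype=torch.long)) == seq_len
assert infer_seq_length(inputs_embeds=torch.randn(batch, seq_len, hidden)) == seq_len
```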
@@ -502,6 +502,11 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
    @unittest.skip(reason="Generate needs input ids")
    def test_inputs_embeds_matches_input_ids_with_generate(self):
        # generate only works with input ids for BertForCausalLM
        pass
    def test_model_as_decoder_with_default_input_mask(self):
        # This regression test was failing with PyTorch < 1.3
        (
...
@@ -4058,6 +4058,11 @@ class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
        # generate only works with input ids for whisper
        pass
    @unittest.skip(reason="Generate needs input ids")
    def test_inputs_embeds_matches_input_ids_with_generate(self):
        # generate only works with input ids for whisper
        pass
@unittest.skip(reason="Decoder can't keep attention grads") @unittest.skip(reason="Decoder can't keep attention grads")
def test_retain_grad_hidden_states_attentions(self): def test_retain_grad_hidden_states_attentions(self):
return return
......
@@ -2819,6 +2819,53 @@ class ModelTesterMixin:
        )[0]
        self.assertTrue(torch.allclose(out_embeds, out_ids))
    def test_inputs_embeds_matches_input_ids_with_generate(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            if model_class.__name__ not in get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES):
                continue
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            model_forward_args = inspect.signature(model.forward).parameters
            if "inputs_embeds" not in model_forward_args:
                self.skipTest(reason="This model doesn't use `inputs_embeds`")

            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
            pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1

            wte = model.get_input_embeddings()
            if not self.is_encoder_decoder:
                input_ids = inputs["input_ids"]
                # some models infer position ids/attention mask differently when input
                # ids are passed, by checking for the pad_token, so make sure no
                # padding remains in the input ids
                not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1
                input_ids[input_ids == pad_token_id] = not_pad_token_id
                del inputs["input_ids"]
                inputs_embeds = wte(input_ids)
                out_ids = model.generate(input_ids=input_ids, **inputs, max_new_tokens=2)[:, -2:]
                out_embeds = model.generate(inputs_embeds=inputs_embeds, **inputs, max_new_tokens=2)
            else:
                encoder_input_ids = inputs["input_ids"]
                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
                encoder_input_ids[encoder_input_ids == pad_token_id] = max(0, pad_token_id + 1)
                decoder_input_ids[decoder_input_ids == pad_token_id] = max(0, pad_token_id + 1)
                del inputs["input_ids"]
                inputs.pop("decoder_input_ids", None)
                inputs_embeds = wte(encoder_input_ids)
                decoder_inputs_embeds = wte(decoder_input_ids)
                out_ids = model.generate(
                    input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids, **inputs, max_new_tokens=2
                )[:, -2:]
                out_embeds = model.generate(
                    inputs_embeds=inputs_embeds,
                    decoder_inputs_embeds=decoder_inputs_embeds,
                    **inputs,
                    max_new_tokens=2,
                )
            self.assertTrue(torch.allclose(out_embeds, out_ids))
    @require_torch_multi_gpu
    def test_multi_gpu_data_parallel_forward(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
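
As a usage-level illustration of what the new mixin test verifies: for a decoder-only causal LM, generating from `inputs_embeds` should yield the same continuation as generating from the matching `input_ids`. A hedged sketch (the checkpoint is illustrative; note that when only `inputs_embeds` is passed, `generate` returns just the newly generated tokens, which is why the test slices `[:, -2:]` on the id-based output):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any decoder-only causal LM should behave the same.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

input_ids = tokenizer("Hello world", return_tensors="pt").input_ids
# The embeddings the model would compute internally from input_ids.
inputs_embeds = model.get_input_embeddings()(input_ids)

# With input_ids, generate() returns prompt + new tokens; keep only the new ones.
out_ids = model.generate(input_ids=input_ids, max_new_tokens=2)[:, -2:]
# With only inputs_embeds, generate() returns just the new tokens.
out_embeds = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=2)

assert torch.equal(out_ids, out_embeds)
```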