"docs/source/vscode:/vscode.git/clone" did not exist on "d583f1317be422128b7bb56984720387b3bbcb35"
Unverified Commit 83b26dd7 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[`generate`] fix breaking change for patch (#29976)

* fix bug and add tests

* nit

* otherway to get the cur len instead of attention mask

* more places where this might have been broken

* nit

* oups

* inputs_embeds vs input_embeds

* test generated outptus

* style

* nit

* fix

* skip failing biogpt
parent 096f3046
...@@ -3034,6 +3034,8 @@ class GenerationMixin: ...@@ -3034,6 +3034,8 @@ class GenerationMixin:
num_beams = beam_scorer.num_beams num_beams = beam_scorer.num_beams
batch_beam_size, cur_len = input_ids.shape batch_beam_size, cur_len = input_ids.shape
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
if num_beams * batch_size != batch_beam_size: if num_beams * batch_size != batch_beam_size:
...@@ -3437,6 +3439,8 @@ class GenerationMixin: ...@@ -3437,6 +3439,8 @@ class GenerationMixin:
num_beams = beam_scorer.num_beams num_beams = beam_scorer.num_beams
batch_beam_size, cur_len = input_ids.shape batch_beam_size, cur_len = input_ids.shape
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
# init attention / hidden states / scores tuples # init attention / hidden states / scores tuples
...@@ -3795,6 +3799,8 @@ class GenerationMixin: ...@@ -3795,6 +3799,8 @@ class GenerationMixin:
device = input_ids.device device = input_ids.device
batch_beam_size, cur_len = input_ids.shape batch_beam_size, cur_len = input_ids.shape
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
if return_dict_in_generate and output_scores: if return_dict_in_generate and output_scores:
...@@ -4211,6 +4217,8 @@ class GenerationMixin: ...@@ -4211,6 +4217,8 @@ class GenerationMixin:
num_beams = constrained_beam_scorer.num_beams num_beams = constrained_beam_scorer.num_beams
batch_beam_size, cur_len = input_ids.shape batch_beam_size, cur_len = input_ids.shape
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
if num_beams * batch_size != batch_beam_size: if num_beams * batch_size != batch_beam_size:
......
...@@ -717,6 +717,19 @@ class GenerationTesterMixin: ...@@ -717,6 +717,19 @@ class GenerationTesterMixin:
) )
self.assertTrue(output_generate.shape[-1] == max_length) self.assertTrue(output_generate.shape[-1] == max_length)
if "inputs_embeds" in set(inspect.signature(model.prepare_inputs_for_generation).parameters):
input_embeds = model.get_input_embeddings()(input_ids)
beam_kwargs.update({"inputs_embeds": input_embeds})
output_generate2 = self._beam_sample_generate(
model=model,
input_ids=None,
attention_mask=attention_mask,
max_length=max_length,
beam_kwargs=beam_kwargs,
logits_warper_kwargs=logits_warper_kwargs,
)
torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2)
def test_beam_sample_generate_dict_output(self): def test_beam_sample_generate_dict_output(self):
for model_class in self.all_generative_model_classes: for model_class in self.all_generative_model_classes:
......
...@@ -414,6 +414,10 @@ class BioGptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix ...@@ -414,6 +414,10 @@ class BioGptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
@unittest.skip("The `input_embeds` when fed don't produce the same results.")
def test_beam_sample_generate(self):
pass
@require_torch @require_torch
class BioGptModelIntegrationTest(unittest.TestCase): class BioGptModelIntegrationTest(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment