Unverified Commit 3c2383b1 authored by Joao Gante, committed by GitHub

Generate: general test for decoder-only generation from `inputs_embeds` (#25687)


Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 656e17f6
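
For context, the behaviour the new mixin test generalises can be sketched with the tiny GPT-2 checkpoint used by the integration test removed further down. All API calls below appear in that test; the variable names and the 10-token budget in the last call are illustrative only.

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model.config.pad_token_id = tokenizer.eos_token_id

input_ids = tokenizer(["Hello world", "Hello world"], return_tensors="pt").input_ids

# Embed the prompt manually and let `generate` continue from the embeddings;
# passing `input_ids` as well keeps the prompt in the returned sequences.
inputs_embeds = model.get_input_embeddings()(input_ids)
outputs_with_prompt = model.generate(input_ids, inputs_embeds=inputs_embeds)

# `input_ids` is optional: without it, the returned sequences do not repeat the prompt tokens.
outputs_new_tokens_only = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=10)
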
@@ -1750,6 +1750,56 @@ class GenerationTesterMixin:
                past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
            )

    def test_generate_from_inputs_embeds_decoder_only(self):
        # When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
        # if it fails, you should probably update the `prepare_inputs_for_generation` function
        for model_class in self.all_generative_model_classes:
            config, input_ids, _, _ = self._get_input_ids_and_config()

            # Ignore:
            # a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the
            #    input_ids, which would cause a mismatch),
            config.pad_token_id = config.eos_token_id = -1
            # b) embedding scaling, the scaling factor applied after embedding from input_ids (requires knowledge
            #    of the variable that holds the scaling factor, which is model-dependent)
            if hasattr(config, "scale_embedding"):
                config.scale_embedding = False

            # This test is for decoder-only models (encoder-decoder models have native input embeddings support
            # in the decoder)
            if config.is_encoder_decoder:
                continue

            # Skip models without explicit support
            model = model_class(config).to(torch_device).eval()
            if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
                continue

            # Traditional way of generating text
            outputs_from_ids = model.generate(input_ids)
            self.assertEqual(outputs_from_ids.shape, (2, 20))

            # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
            inputs_embeds = model.get_input_embeddings()(input_ids)
            outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
            self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())

            # But if we pass different inputs_embeds, we should get different outputs
            torch.manual_seed(0)
            random_embeds = torch.rand_like(inputs_embeds)
            outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
            with self.assertRaises(AssertionError):
                self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())

            # input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
            outputs_from_embeds_wo_ids = model.generate(
                inputs_embeds=inputs_embeds, max_new_tokens=20 - inputs_embeds.shape[1]
            )
            self.assertListEqual(
                outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(),
                outputs_from_embeds_wo_ids[:, 1:].tolist(),
            )
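
The new test skips models whose `prepare_inputs_for_generation` does not list `inputs_embeds`; when a model fails here, the usual fix is to make that method forward the embeddings on the first generation step only. A hedged sketch of that pattern, loosely based on the GPT-2 implementation in the library (real models also handle attention masks, position ids, and other kwargs, and the details vary per model):

def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
    # `inputs_embeds` can only drive the first forward pass; once a cache exists,
    # generation continues from the ids of the freshly sampled tokens.
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs["past_key_values"] = past_key_values
    model_inputs["use_cache"] = kwargs.get("use_cache")
    return model_inputs
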
    def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
        batch_size, seq_length = input_ids.shape
        num_sequences_in_output = batch_size * num_return_sequences
@@ -2773,42 +2823,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
        generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
        self.assertTrue(expectation == len(generated_tokens[0]))

    def test_generate_from_inputs_embeds_decoder_only(self):
        # PT-only test: TF doesn't have a model with support to generate from input embeds (yet ;))
        # Note: the model must support generation from input embeddings
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        model.config.pad_token_id = tokenizer.eos_token_id

        text = "Hello world"
        tokenized_inputs = tokenizer([text, text], return_tensors="pt")
        input_ids = tokenized_inputs.input_ids.to(torch_device)

        # Traditional way of generating text
        outputs_from_ids = model.generate(input_ids)
        self.assertEqual(outputs_from_ids.shape, (2, 20))

        # Same thing, but from input embeddings
        inputs_embeds = model.transformer.wte(input_ids)
        outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
        self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())

        # But if we pass different inputs_embeds, we should get different outputs
        torch.manual_seed(0)
        random_embeds = torch.rand_like(inputs_embeds)
        outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
        with self.assertRaises(AssertionError):
            self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())

        # input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
        outputs_from_embeds_wo_ids = model.generate(
            inputs_embeds=inputs_embeds, max_new_tokens=20 - inputs_embeds.shape[1]
        )
        self.assertListEqual(
            outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(),
            outputs_from_embeds_wo_ids[:, 1:].tolist(),
        )

    def test_model_kwarg_encoder_signature_filtering(self):
        # Has TF equivalent: ample use of framework-specific code
        bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
...