"awq/git@developer.sourcefind.cn:OpenDAS/autoawq.git" did not exist on "48be2ee2531e5cdd8d9b2fe5c825cdaf95601908"
Unverified Commit f25a9332 authored by Patrick von Platen, committed by GitHub

[Generation] Allow `inputs_embeds` as an input (#14443)

* up

* finalize

* finalize

* finish

* Update src/transformers/generation_utils.py

* apply feedback
parent 0490b988
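As a quick illustration of the feature this commit adds, here is a minimal usage sketch (not part of the commit; the checkpoint name and max_length below are placeholders): with a transformers version that includes this change, an encoder-decoder model can generate from pre-computed `inputs_embeds` instead of `input_ids`.

from transformers import BartForConditionalGeneration, BartTokenizer

# Placeholder checkpoint; any BART-style encoder-decoder model is expected to behave the same way.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

text = "Justin Timberlake and Jessica Biel, welcome to parenthood."
input_ids = tokenizer(text, return_tensors="pt").input_ids

# Compute the word embeddings ourselves and pass them to generate() directly;
# no attention_mask is given, so a full mask is assumed (see the diff below).
inputs_embeds = model.get_input_embeddings()(input_ids)
output_ids = model.generate(inputs_embeds=inputs_embeds, max_length=20)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))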
......@@ -392,15 +392,26 @@ class GenerationMixin:
return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id
def _prepare_attention_mask_for_generation(
- self, input_ids: torch.Tensor, pad_token_id: int, eos_token_id: int
+ self,
+ input_ids: torch.Tensor,
+ pad_token_id: int,
+ eos_token_id: int,
+ inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.LongTensor:
+ # If `inputs_embeds` is given but no `attention_mask`, assume a full attention mask is used
+ if inputs_embeds is not None:
+ return torch.ones((inputs_embeds.shape[0], inputs_embeds.shape[1]), dtype=torch.long)
+ # Otherwise, use `input_ids`
is_pad_token_in_inputs_ids = (pad_token_id is not None) and (pad_token_id in input_ids)
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
(eos_token_id is not None) and (pad_token_id != eos_token_id)
)
if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
return input_ids.ne(pad_token_id).long()
- return input_ids.new_ones(input_ids.shape, dtype=torch.long)
+ else:
+ return input_ids.new_ones(input_ids.shape, dtype=torch.long)
def _prepare_encoder_decoder_kwargs_for_generation(
self, input_ids: torch.LongTensor, model_kwargs
......@@ -417,12 +428,11 @@ class GenerationMixin:
return model_kwargs
def _prepare_decoder_input_ids_for_generation(
- self, input_ids: torch.LongTensor, decoder_start_token_id: int = None, bos_token_id: int = None
+ self, batch_size: int, decoder_start_token_id: int = None, bos_token_id: int = None
) -> torch.LongTensor:
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
- decoder_input_ids = (
- torch.ones((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device) * decoder_start_token_id
- )
+ decoder_input_ids = torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id
return decoder_input_ids
def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int:
......@@ -890,8 +900,9 @@ class GenerationMixin:
if model_kwargs.get("attention_mask", None) is None:
# init `attention_mask` depending on `pad_token_id`
inputs_embeds = model_kwargs.get("inputs_embeds", None)
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
- input_ids, pad_token_id, eos_token_id
+ input_ids, pad_token_id, eos_token_id, inputs_embeds
)
# special case if pad_token_id is not defined
......@@ -910,12 +921,17 @@ class GenerationMixin:
if "decoder_input_ids" in model_kwargs:
input_ids = model_kwargs.pop("decoder_input_ids")
else:
+ # if word embeddings are provided directly, infer the batch size from them
+ batch_size = input_ids.shape[0] if input_ids is not None else model_kwargs["inputs_embeds"].shape[0]
input_ids = self._prepare_decoder_input_ids_for_generation(
- input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id
+ batch_size, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id
)
if "encoder_outputs" not in model_kwargs or not isinstance(model_kwargs["encoder_outputs"], ModelOutput):
raise ValueError("Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`.")
else:
if "inputs_embeds" in model_kwargs and input_ids is None:
raise ValueError("For decoder-only generation, one must pass `input_ids`.")
# if `max_new_tokens` is passed, but not `max_length` -> set `max_length = max_new_tokens`
if max_length is None and max_new_tokens is not None:
......
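The behavioral core of the `_prepare_attention_mask_for_generation` change above: when only embeddings are available there are no token ids to compare against `pad_token_id`, so a full attention mask over the embedded sequence is assumed. A standalone sketch of that default (illustrative only; the helper name is made up and is not the library function itself):

import torch

def full_attention_mask(inputs_embeds: torch.Tensor) -> torch.LongTensor:
    # inputs_embeds has shape (batch_size, sequence_length, hidden_size);
    # the assumed mask is all ones with shape (batch_size, sequence_length).
    batch_size, sequence_length, _ = inputs_embeds.shape
    return torch.ones((batch_size, sequence_length), dtype=torch.long)

dummy_embeds = torch.randn(2, 7, 16)  # batch of 2 sequences, 7 tokens each, hidden size 16
print(full_attention_mask(dummy_embeds).shape)  # torch.Size([2, 7])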
......@@ -1388,15 +1388,17 @@ class GenerationIntegrationTests(unittest.TestCase):
def test_max_length_backward_compat_greedy(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
torch_device
)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
max_length = 20
input_ids = input_ids.expand(2, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
- input_ids,
+ input_ids.shape[0],
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
......@@ -1412,15 +1414,17 @@ class GenerationIntegrationTests(unittest.TestCase):
def test_max_length_backward_compat_sample(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
torch_device
)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
max_length = 20
input_ids = input_ids.expand(2, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
- input_ids,
+ input_ids.shape[0],
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
......@@ -1436,8 +1440,10 @@ class GenerationIntegrationTests(unittest.TestCase):
def test_max_length_backward_compat_beam_search(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
torch_device
)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
batch_size = 1
......@@ -1447,7 +1453,7 @@ class GenerationIntegrationTests(unittest.TestCase):
input_ids = input_ids.expand(2, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
- input_ids,
+ input_ids.shape[0],
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
......@@ -1464,8 +1470,10 @@ class GenerationIntegrationTests(unittest.TestCase):
def test_max_length_backward_compat_group_beam_search(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
torch_device
)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
batch_size = 1
......@@ -1477,7 +1485,7 @@ class GenerationIntegrationTests(unittest.TestCase):
input_ids = input_ids.expand(6, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
- input_ids,
+ input_ids.shape[0],
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
......@@ -1496,8 +1504,10 @@ class GenerationIntegrationTests(unittest.TestCase):
def test_max_length_warning_if_different(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
torch_device
)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
batch_size = 1
......@@ -1513,7 +1523,7 @@ class GenerationIntegrationTests(unittest.TestCase):
input_ids = input_ids.expand(6, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
- input_ids,
+ input_ids.shape[0],
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
......@@ -1577,8 +1587,10 @@ class GenerationIntegrationTests(unittest.TestCase):
def test_beam_search_warning_if_max_length_is_passed(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
torch_device
)
batch_size = 1
num_beams = 3
......@@ -1587,6 +1599,9 @@ class GenerationIntegrationTests(unittest.TestCase):
input_ids = input_ids.expand(num_beams, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
+ # pretend decoder_input_ids correspond to first encoder input id
+ decoder_input_ids = input_ids[:, :1]
stopping_criteria_max_length = 18
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])
......@@ -1599,7 +1614,7 @@ class GenerationIntegrationTests(unittest.TestCase):
)
generated_ids = bart_model.beam_search(
- input_ids,
+ decoder_input_ids,
num_beams=num_beams,
stopping_criteria=stopping_criteria,
beam_scorer=beam_scorer,
......@@ -1613,7 +1628,7 @@ class GenerationIntegrationTests(unittest.TestCase):
)
generated_ids_no_max_len = bart_model.beam_search(
- input_ids,
+ decoder_input_ids,
num_beams=num_beams,
stopping_criteria=stopping_criteria,
beam_scorer=beam_scorer_no_max_len,
......@@ -1625,14 +1640,17 @@ class GenerationIntegrationTests(unittest.TestCase):
def test_max_new_tokens_encoder_decoder(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
torch_device
)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
- self.assertEqual(list(input_ids.shape), [1, 15])
+ self.assertEqual(list(input_ids.shape), [1, 29])
max_new_tokens = 3
bart_model.config.max_length = 20
+ bart_model.config.eos_token_id = None
# Encoder decoder call
outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens)
......@@ -1641,8 +1659,8 @@ class GenerationIntegrationTests(unittest.TestCase):
# Decoder only call
outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
- # 15 + 3 new tokens
- self.assertEqual(list(outputs.shape), [1, 18])
+ # 29 + 3 new tokens
+ self.assertEqual(list(outputs.shape), [1, 32])
# Encoder decoder call > 20
outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20)
......@@ -1680,3 +1698,29 @@ class GenerationIntegrationTests(unittest.TestCase):
# max_new_tokens and max_length serve the same purpose and should not be used together.
with self.assertWarns(UserWarning):
gpt2_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
+ def test_encoder_decoder_generate_with_inputs_embeds(self):
+ article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
+ tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+ model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5).to(
+ torch_device
+ )
+ model.config.eos_token_id = None
+ input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+ inputs_embeds = model.get_input_embeddings()(input_ids)
+ output_sequences = model.generate(inputs_embeds=inputs_embeds)
+ # make sure model generated correctly until `max_length`
+ self.assertEqual(output_sequences.shape, (1, 5))
+ def test_decoder_generate_with_inputs_embeds(self):
+ article = """I need input_ids to generate"""
+ tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+ model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=5).to(torch_device)
+ input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+ inputs_embeds = model.get_input_embeddings()(input_ids)
+ # cannot generate from `inputs_embeds` for decoder only
+ with self.assertRaises(ValueError):
+ model.generate(inputs_embeds=inputs_embeds)
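For completeness, the other encoder-decoder piece touched above is that `_prepare_decoder_input_ids_for_generation` now takes a `batch_size` (inferred from `inputs_embeds` when no `input_ids` are given) rather than the `input_ids` tensor. A minimal sketch of that preparation step, with a made-up helper name and an illustrative start-token id:

import torch

def decoder_start_ids(batch_size: int, decoder_start_token_id: int, device: str = "cpu") -> torch.LongTensor:
    # one decoder start token per example in the batch, shape (batch_size, 1)
    return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id

print(decoder_start_ids(batch_size=2, decoder_start_token_id=2))  # tensor([[2], [2]])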