Unverified Commit 9e5c28c5 authored by Joao Gante, committed by GitHub

Generate: assisted decoding now uses `generate` for the assistant (#28030)

Refactor: the assistant model's draft tokens are now produced through a `generate` call instead of a manual greedy forward loop, and `get_candidates` returns the candidate logits alongside the candidate ids (a usage sketch follows the commit metadata).
parent dde6c427
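For context, assisted generation is driven entirely through the public `generate` API by passing an `assistant_model`; this commit only changes how the assistant produces its draft internally. A minimal usage sketch (checkpoint names and generation settings below are illustrative assumptions, not part of this commit):

# Illustrative usage of assisted generation; not part of this diff.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2-large")  # main model
assistant = AutoModelForCausalLM.from_pretrained("gpt2")    # smaller draft model, same tokenizer

inputs = tokenizer("The quick brown fox", return_tensors="pt")
# The assistant drafts candidate tokens; the main model verifies them in a single forward pass.
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))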
@@ -15,7 +15,7 @@
import copy
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
class CandidateGenerator:
"""Abstract base class for all candidate generators that can be applied during assisted generation."""
def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor:
def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""
Fetches the candidates to be tried for the current input.
@@ -37,8 +37,9 @@ class CandidateGenerator:
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
Return:
`torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be assessed by
the model.
`torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
assessed by the model and, optionally, a `torch.FloatTensor` of shape `(batch_size, candidate_length,
vocabulary_size)` containing the logits associated with each candidate.
"""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`."
@@ -152,7 +153,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
)
self.logits_processor = logits_processor
def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor:
def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""
Fetches the candidates to be tried for the current input.
@@ -161,7 +162,9 @@ class AssistedCandidateGenerator(CandidateGenerator):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
Return:
`torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
`torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
vocabulary_size)` containing the logits associated with each candidate.
"""
# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
# (which implicitly contains the number of accepted candidates from the previous round)
@@ -179,51 +182,24 @@
)
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)
# 2. Forecast next N tokens using the assistant model. This `for` block can be replaced with a `.generate()`
# call if we decide to add `past_key_values` as a possible output of generate, as we need access to the
# assistant cache to secure strong speedups.
candidate_input_ids = input_ids
for _ in range(int(self.num_assistant_tokens)):
# 2.1 prepare assistant model inputs
assistant_inputs = self.assistant_model.prepare_inputs_for_generation(
candidate_input_ids,
**self.assistant_kwargs,
)
# 2.2. check if the input ids length is correct
has_past_key_values = assistant_inputs.get("past_key_values", None) is not None
if has_past_key_values and assistant_inputs[self.input_ids_key].shape[-1] not in (1, 2):
raise ValueError("The length of the input ids in assistant inputs should be 1 or 2")
# 2.3. use the assistant model to obtain the next candidate logits
assistant_model_outputs = self.assistant_model(**assistant_inputs)
# 2.4. greedily select the next candidate token
if len(self.logits_processor) > 0:
assistant_model_outputs.logits[:, -1, :] = self.logits_processor(
candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
)
new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
# 2.5. update assistant model inputs
if self.assistant_kwargs.get(self.attention_key, None) is not None:
mask = self.assistant_kwargs[self.attention_key]
self.assistant_kwargs[self.attention_key] = torch.cat(
[mask, mask.new_ones((mask.shape[0], 1))], dim=-1
)
self.assistant_kwargs["past_key_values"] = assistant_model_outputs.past_key_values
# 2.6. stop assistant generation on EOS
if self.eos_token_id_tensor is not None:
last_assistant_token_is_eos = new_token.tile(self.eos_token_id_tensor.shape[0], 1)
last_assistant_token_is_eos = (
~last_assistant_token_is_eos.ne(self.eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool()
)
if last_assistant_token_is_eos:
break
return candidate_input_ids
# 2. Forecast next N tokens using the assistant model.
assistant_generation_kwargs = {
self.input_ids_key: input_ids,
"do_sample": False,
"num_beams": 1,
"max_new_tokens": int(self.num_assistant_tokens),
"return_dict_in_generate": True,
"output_scores": True,
}
assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)
# 3. Update variables for the next round of candidate generation
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
# 4. Prepare variables for output
candidate_logits = torch.stack(assistant_output.scores, dim=1)
candidate_ids = assistant_output.sequences
return candidate_ids, candidate_logits
def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
"""
......
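A note on the shapes involved above: with `return_dict_in_generate=True` and `output_scores=True`, `generate` returns `scores` as a tuple holding one `(batch_size, vocab_size)` tensor per generated token, so stacking along dim 1 produces the `(batch_size, candidate_length, vocabulary_size)` logits promised by the docstring. A toy illustration with made-up shapes:

import torch

# Toy stand-in for assistant_model.generate(...).scores; shapes are invented.
batch_size, vocab_size, num_new_tokens = 2, 50, 5
scores = tuple(torch.randn(batch_size, vocab_size) for _ in range(num_new_tokens))

candidate_logits = torch.stack(scores, dim=1)
assert candidate_logits.shape == (batch_size, num_new_tokens, vocab_size)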
@@ -4585,7 +4585,7 @@ class GenerationMixin:
cur_len = input_ids.shape[-1]
# 1. Fetch candidate sequences from a `CandidateGenerator`
candidate_input_ids = candidate_generator.get_candidates(input_ids)
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
last_assistant_token_is_eos = (
~candidate_input_ids[:, -1]
......
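The hunk above is cut off by the viewer; for orientation, the two quantities it computes can be shown on toy tensors. The token values below are invented, and the continuation of the truncated EOS expression is reconstructed from the identical check in the removed assistant loop above, so treat it as a sketch rather than the verbatim source:

import torch

# Invented example values; the expressions mirror the surrounding _assisted_decoding code.
input_ids = torch.tensor([[5, 6, 7]])                     # prompt so far
candidate_input_ids = torch.tensor([[5, 6, 7, 8, 9, 2]])  # prompt + 3 drafted tokens
eos_token_id_tensor = torch.tensor([2])

candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]  # -> 3
last_assistant_token_is_eos = (
    ~candidate_input_ids[:, -1]
    .tile(eos_token_id_tensor.shape[0], 1)
    .ne(eos_token_id_tensor.unsqueeze(1))
    .prod(dim=0)
    .bool()
)  # -> tensor([True]) because the last drafted token equals the EOS id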
@@ -3128,21 +3128,26 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist())
def test_model_kwarg_assisted_decoding_encoder_decoder(self):
"""
Tests that the following scenario is compatible with assisted generation:
1. encoder-decoder main model
2. encoder-decoder assistant model
3. both have a custom input
(e.g. Whisper)
"""
# PT-only test: TF doesn't support assisted decoding yet.
# Bart subclass with a kwarg that distorts the output
class FakeBart(BartForConditionalGeneration):
def forward(self, input_ids, foo=False, **kwargs):
outs = super().forward(input_ids, **kwargs)
def forward(self, input_ids, past_key_values, foo=False, **kwargs):
outs = super().forward(input_ids, past_key_values=past_key_values, **kwargs)
if foo:
outs["logits"][:, :, :] = 0.0
return outs
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
kwargs["encoder_outputs"] = encoder_outputs
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
inputs["foo"] = foo
return inputs
@@ -3160,17 +3165,14 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertEqual(outputs_normal.shape, (1, 20))
# Should be different with foo
outputs_foo = model.generate(
input_ids,
foo=True,
)
outputs_foo = model.generate(input_ids, foo=True)
with self.assertRaises(AssertionError):
self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
# Assistant model
assistant = AutoModelForSeq2SeqLM.from_pretrained(
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
).to(torch_device)
assistant = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
torch_device
)
# If assisted generation passes model_kwargs correctly, should be same as previous
outputs_assisted = model.generate(
@@ -3192,25 +3194,43 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
def test_assisted_decoding_encoder_decoder_shared_encoder(self):
"""
Tests that the following scenario is compatible with assisted generation:
1. encoder-decoder main model
2. decoder-only assistant model
3. both have a custom input
(e.g. DistilWhisper)
"""
# PT-only test: TF doesn't support assisted decoding yet.
# Bart subclass with a kwarg called foo that distorts the output
class FakeBart(BartForConditionalGeneration):
class FakeBartSeq2Seq(BartForConditionalGeneration):
def forward(self, input_ids, foo=False, **kwargs):
outs = super().forward(input_ids, **kwargs)
if foo:
outs["logits"][:, :, :] = 0.0
return outs
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
kwargs["encoder_outputs"] = encoder_outputs
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
inputs["foo"] = foo
return inputs
class FakeBartCausalLM(BartForCausalLM):
def forward(self, input_ids, attention_mask, past_key_values, foo=False, **kwargs):
outs = super().forward(input_ids, attention_mask, past_key_values=past_key_values, **kwargs)
if foo:
outs["logits"][:, :, :] = 0.0
return outs
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
kwargs["encoder_outputs"] = encoder_outputs
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
inputs["foo"] = foo
return inputs
model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
model = FakeBartSeq2Seq.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
torch_device
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
@@ -3229,9 +3249,9 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
# Assistant model
assistant = BartForCausalLM.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
torch_device
)
assistant = FakeBartCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
).to(torch_device)
# If assisted generation passes model_kwargs correctly, should be same as previous
outputs_assisted = model.generate(
......
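The truncated call above is where the test finally compares assisted and unassisted outputs. For orientation, the scenario exercised here, an encoder-decoder main model assisted by a decoder-only draft model, is driven through the same public entry point; a rough sketch mirroring the tiny checkpoints used in these tests (the prompt and generation arguments are assumptions):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BartForCausalLM

checkpoint = "hf-internal-testing/tiny-random-BartForConditionalGeneration"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)  # encoder-decoder main model
assistant = BartForCausalLM.from_pretrained(checkpoint)    # decoder-only assistant sharing the vocabulary

input_ids = tokenizer("Hello world", return_tensors="pt").input_ids
outputs = model.generate(input_ids, assistant_model=assistant, max_new_tokens=5)
print(outputs.shape)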