"tests/mobilebert/__init__.py" did not exist on "31c23bd5ee26425a67f92fc170789656379252a6"
Unverified Commit 9e5c28c5 authored by Joao Gante, committed by GitHub

Generate: assisted decoding now uses `generate` for the assistant (#28030)

generate refactor
parent dde6c427
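For context: assisted generation (drafting tokens with a smaller assistant model and verifying them with the main model) is driven entirely from the main model's `generate` call, and this commit only changes how the assistant's candidate tokens are produced internally. A minimal usage sketch of the unchanged user-facing API (the checkpoints here are illustrative, not part of the commit; any main/assistant pair sharing a tokenizer works):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Illustrative checkpoints: a larger main model and a smaller draft model with the same tokenizer.
    tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
    model = AutoModelForCausalLM.from_pretrained("gpt2-large")
    assistant = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer("Assisted decoding lets a small model draft tokens that", return_tensors="pt")
    # Passing `assistant_model` switches `generate` to assisted decoding.
    outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=20)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))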
@@ -15,7 +15,7 @@
 import copy
 import warnings
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 import torch
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
 class CandidateGenerator:
     """Abstract base class for all candidate generators that can be applied during assisted generation."""

-    def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor:
+    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
         """
         Fetches the candidates to be tried for the current input.
@@ -37,8 +37,9 @@ class CandidateGenerator:
                 Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

         Return:
-            `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be assessed by
-            the model.
+            `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
+            assessed by the model and, optionally, a `torch.FloatTensor` of shape `(batch_size, candidate_length,
+            vocabulary_size)` containing the logits associated to each candidate.
         """
         raise NotImplementedError(
             f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`."
@@ -152,7 +153,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
         )
         self.logits_processor = logits_processor

-    def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor:
+    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
         """
         Fetches the candidates to be tried for the current input.
@@ -161,7 +162,9 @@ class AssistedCandidateGenerator(CandidateGenerator):
                 Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

         Return:
-            `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
+            `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
+            assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
+            vocabulary_size)` containing the logits associated to each candidate.
         """
         # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
         # (which implicitly contains the number of accepted candidates from the previous round)
@@ -179,51 +182,24 @@ class AssistedCandidateGenerator(CandidateGenerator):
             )
         self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)

-        # 2. Forecast next N tokens using the assistant model. This `for` block can be replaced with a `.generate()`
-        # call if we decide to add `past_key_values` as a possible output of generate, as we need access to the
-        # assistant cache to secure strong speedups.
-        candidate_input_ids = input_ids
-        for _ in range(int(self.num_assistant_tokens)):
-            # 2.1 prepare assistant model inputs
-            assistant_inputs = self.assistant_model.prepare_inputs_for_generation(
-                candidate_input_ids,
-                **self.assistant_kwargs,
-            )
-
-            # 2.2. check if the input ids length is correct
-            has_past_key_values = assistant_inputs.get("past_key_values", None) is not None
-            if has_past_key_values and assistant_inputs[self.input_ids_key].shape[-1] not in (1, 2):
-                raise ValueError("The length of the input ids in assistant inputs should be 1 or 2")
-
-            # 2.3. use the assistant model to obtain the next candidate logits
-            assistant_model_outputs = self.assistant_model(**assistant_inputs)
-
-            # 2.4. greedily select the next candidate token
-            if len(self.logits_processor) > 0:
-                assistant_model_outputs.logits[:, -1, :] = self.logits_processor(
-                    candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
-                )
-            new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
-            candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
-
-            # 2.5. update assistant model inputs
-            if self.assistant_kwargs.get(self.attention_key, None) is not None:
-                mask = self.assistant_kwargs[self.attention_key]
-                self.assistant_kwargs[self.attention_key] = torch.cat(
-                    [mask, mask.new_ones((mask.shape[0], 1))], dim=-1
-                )
-            self.assistant_kwargs["past_key_values"] = assistant_model_outputs.past_key_values
-
-            # 2.6. stop assistant generation on EOS
-            if self.eos_token_id_tensor is not None:
-                last_assistant_token_is_eos = new_token.tile(self.eos_token_id_tensor.shape[0], 1)
-                last_assistant_token_is_eos = (
-                    ~last_assistant_token_is_eos.ne(self.eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool()
-                )
-                if last_assistant_token_is_eos:
-                    break
-
-        return candidate_input_ids
+        # 2. Forecast next N tokens using the assistant model.
+        assistant_generation_kwargs = {
+            self.input_ids_key: input_ids,
+            "do_sample": False,
+            "num_beams": 1,
+            "max_new_tokens": int(self.num_assistant_tokens),
+            "return_dict_in_generate": True,
+            "output_scores": True,
+        }
+        assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)
+
+        # 3. Update variables for the next round of candidate generation
+        self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
+
+        # 4. Prepare variables for output
+        candidate_logits = torch.stack(assistant_output.scores, dim=1)
+        candidate_ids = assistant_output.sequences
+        return candidate_ids, candidate_logits

     def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
         """
...
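The refactored body relies on the fact that, with `return_dict_in_generate=True` and `output_scores=True`, `generate` returns one score tensor of shape `(batch_size, vocab_size)` per generated step, so stacking them along dim 1 yields the `(batch_size, candidate_length, vocabulary_size)` logits the new interface expects. A quick standalone check of that shape arithmetic (the tensors below are dummies, not real model outputs):

    import torch

    batch_size, vocab_size, new_tokens = 1, 32, 5
    # `generate` returns the scores as a tuple with one entry per generated token.
    scores = tuple(torch.randn(batch_size, vocab_size) for _ in range(new_tokens))
    candidate_logits = torch.stack(scores, dim=1)
    assert candidate_logits.shape == (batch_size, new_tokens, vocab_size)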
@@ -4585,7 +4585,7 @@ class GenerationMixin:
         cur_len = input_ids.shape[-1]

         # 1. Fetch candidate sequences from a `CandidateGenerator`
-        candidate_input_ids = candidate_generator.get_candidates(input_ids)
+        candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
         candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
         last_assistant_token_is_eos = (
             ~candidate_input_ids[:, -1]
...
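On the caller side, the candidate sequence always begins with the current `input_ids`, so the number of drafted tokens is simply the length difference. A toy illustration of that bookkeeping (the values are made up):

    import torch

    input_ids = torch.arange(10).unsqueeze(0)            # current context: shape (1, 10)
    drafted = torch.tensor([[3, 7, 7, 2, 9]])            # 5 tokens proposed by the assistant
    candidate_input_ids = torch.cat([input_ids, drafted], dim=-1)  # shape (1, 15)

    candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
    assert candidate_length == 5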
@@ -3128,21 +3128,26 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
         self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist())

     def test_model_kwarg_assisted_decoding_encoder_decoder(self):
+        """
+        Tests that the following scenario is compatible with assisted generation:
+        1. encoder-decoder main model
+        2. encoder-decoder assistant model
+        3. both have a custom input
+        (e.g. Whisper)
+        """
         # PT-only test: TF doesn't support assisted decoding yet.
         # Bart subclass with a kwarg that distorts the output
         class FakeBart(BartForConditionalGeneration):
-            def forward(self, input_ids, foo=False, **kwargs):
-                outs = super().forward(input_ids, **kwargs)
+            def forward(self, input_ids, past_key_values, foo=False, **kwargs):
+                outs = super().forward(input_ids, past_key_values=past_key_values, **kwargs)
                 if foo:
                     outs["logits"][:, :, :] = 0.0
                 return outs

             def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
                 kwargs["encoder_outputs"] = encoder_outputs
                 inputs = super().prepare_inputs_for_generation(*args, **kwargs)
                 inputs["foo"] = foo
                 return inputs
@@ -3160,17 +3165,14 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
         self.assertEqual(outputs_normal.shape, (1, 20))

         # Should be different with foo
-        outputs_foo = model.generate(
-            input_ids,
-            foo=True,
-        )
+        outputs_foo = model.generate(input_ids, foo=True)
         with self.assertRaises(AssertionError):
             self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())

         # Assistant model
-        assistant = AutoModelForSeq2SeqLM.from_pretrained(
-            "hf-internal-testing/tiny-random-BartForConditionalGeneration"
-        ).to(torch_device)
+        assistant = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
+            torch_device
+        )

         # If assisted generation passes model_kwargs correctly, should be same as previous
         outputs_assisted = model.generate(
@@ -3192,25 +3194,43 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
         self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())

     def test_assisted_decoding_encoder_decoder_shared_encoder(self):
+        """
+        Tests that the following scenario is compatible with assisted generation:
+        1. encoder-decoder main model
+        2. decoder-only assistant model
+        3. both have a custom input
+        (e.g. DistilWhisper)
+        """
         # PT-only test: TF doesn't support assisted decoding yet.
         # Bart subclass with a kwarg called foo that distorts the output
-        class FakeBart(BartForConditionalGeneration):
+        class FakeBartSeq2Seq(BartForConditionalGeneration):
             def forward(self, input_ids, foo=False, **kwargs):
                 outs = super().forward(input_ids, **kwargs)
                 if foo:
                     outs["logits"][:, :, :] = 0.0
                 return outs

             def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
                 kwargs["encoder_outputs"] = encoder_outputs
                 inputs = super().prepare_inputs_for_generation(*args, **kwargs)
+                inputs["foo"] = foo
+                return inputs
+
+        class FakeBartCausalLM(BartForCausalLM):
+            def forward(self, input_ids, attention_mask, past_key_values, foo=False, **kwargs):
+                outs = super().forward(input_ids, attention_mask, past_key_values=past_key_values, **kwargs)
+                if foo:
+                    outs["logits"][:, :, :] = 0.0
+                return outs
+
+            def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
+                kwargs["encoder_outputs"] = encoder_outputs
+                inputs = super().prepare_inputs_for_generation(*args, **kwargs)
                 inputs["foo"] = foo
                 return inputs

-        model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
+        model = FakeBartSeq2Seq.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
             torch_device
         )
         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
@@ -3229,9 +3249,9 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
         self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())

         # Assistant model
-        assistant = BartForCausalLM.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
-            torch_device
-        )
+        assistant = FakeBartCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-BartForConditionalGeneration"
+        ).to(torch_device)

         # If assisted generation passes model_kwargs correctly, should be same as previous
         outputs_assisted = model.generate(
...
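The second test mirrors the DistilWhisper-style setup, where the assistant is a decoder-only model and only the main model runs the encoder. From the user's side this is still the same `assistant_model` argument; a sketch based on the tiny test checkpoint used in the diff above (behaviour of the tiny random checkpoint outside the test suite is an assumption):

    from transformers import AutoTokenizer, BartForCausalLM, BartForConditionalGeneration

    checkpoint = "hf-internal-testing/tiny-random-BartForConditionalGeneration"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = BartForConditionalGeneration.from_pretrained(checkpoint)  # encoder-decoder main model
    assistant = BartForCausalLM.from_pretrained(checkpoint)           # decoder-only assistant

    inputs = tokenizer("Hello world", return_tensors="pt")
    # The decoder-only assistant drafts tokens; the encoder is only run by the main model.
    outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=10)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))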