Unverified Commit 5603fad2 authored by Patrick von Platen, committed by GitHub

Revert "add attention_mask and position_ids in assisted model" (#27523)

* Revert "add attention_mask and position_ids in assisted model (#26892)"

This reverts commit 184f60dc.

* more debug
parent 4989e73e
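
For context, the code paths touched below implement assisted (speculative) decoding, which callers drive through `generate()` by passing an `assistant_model`. A minimal usage sketch, reusing the Whisper checkpoints from the tests added in this commit (illustrative only, not part of the diff itself):

import torch
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

# large model that verifies the drafted tokens
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2").to(device)
# smaller model from the same family that drafts candidate tokens
assistant_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").to(device)

sample = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")[0]["audio"]
input_features = processor(
    sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
).input_features.to(device)

# passing assistant_model routes generation through the assisted-decoding loop patched below
tokens = model.generate(input_features, assistant_model=assistant_model)
print(processor.batch_decode(tokens, skip_special_tokens=True))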
@@ -4504,6 +4504,11 @@ class GenerationMixin:
         else:
             num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
 
+        # check if assistant model accepts encoder_outputs
+        assistant_accepts_encoder_outputs = "encoder_outputs" in set(
+            inspect.signature(assistant_model.forward).parameters.keys()
+        )
+
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
@@ -4546,6 +4551,15 @@ class GenerationMixin:
         # other auxiliary variables
         max_len = stopping_criteria[0].max_length
+        assistant_kv_indexing = (
+            1
+            if "bloom" in assistant_model.__class__.__name__.lower()
+            or (
+                assistant_model.config.architectures is not None
+                and "bloom" in assistant_model.config.architectures[0].lower()
+            )
+            else 0
+        )
 
         this_peer_finished = False  # used by synced_gpus only
         while True:
@@ -4566,28 +4580,42 @@ class GenerationMixin:
             # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
             # need access to the assistant cache to secure strong speedups.
             candidate_input_ids = input_ids
-            assistant_attention_mask = model_kwargs.get("attention_mask", None)
-            assistant_decoder_attention_mask = model_kwargs.get("decoder_attention_mask", None)
-            assistant_encoder_outputs = (model_kwargs.get("assistant_encoder_outputs", None),)
             for _ in range(int(num_assistant_tokens)):
                 # 1.1. use the assistant model to obtain the next candidate logits
-                assistant_inputs = assistant_model.prepare_inputs_for_generation(
-                    candidate_input_ids,
-                    attention_mask=assistant_attention_mask,
-                    decoder_attention_mask=assistant_decoder_attention_mask,
-                    encoder_outputs=assistant_encoder_outputs,
-                    past_key_values=model_kwargs.get("assistant_past_key_values", None),
-                )
-                if assistant_inputs.get("past_key_values", None) is not None:
-                    if assistant_model.config.is_encoder_decoder:
-                        input_ids_len = assistant_inputs["decoder_input_ids"].shape[-1]
-                    else:
-                        input_ids_len = assistant_inputs["input_ids"].shape[-1]
-                    if input_ids_len not in (1, 2):
-                        raise ValueError("The length of the input ids in assistant inputs should be 1 or 2")
-                assistant_model_outputs = assistant_model(**assistant_inputs)
+                if "assistant_past_key_values" in model_kwargs:
+                    prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
+                    # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
+                    new_token_len = candidate_input_ids.shape[1] - prev_seq_len
+                    assist_inputs = candidate_input_ids[:, -new_token_len:]
+                    # TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
+                    if assistant_model.config.is_encoder_decoder:
+                        assistant_model_outputs = assistant_model(
+                            decoder_input_ids=assist_inputs,
+                            past_key_values=model_kwargs["assistant_past_key_values"],
+                            encoder_outputs=model_kwargs["assistant_encoder_outputs"],
+                        )
+                    else:
+                        encoder_kwargs = {}
+                        if assistant_accepts_encoder_outputs and "assistant_encoder_outputs" in model_kwargs:
+                            encoder_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
+                        assistant_model_outputs = assistant_model(
+                            assist_inputs, past_key_values=model_kwargs["assistant_past_key_values"], **encoder_kwargs
+                        )
+                else:
+                    if assistant_model.config.is_encoder_decoder:
+                        assistant_model_outputs = assistant_model(
+                            decoder_input_ids=candidate_input_ids,
+                            encoder_outputs=model_kwargs["assistant_encoder_outputs"],
+                        )
+                    else:
+                        encoder_kwargs = {}
+                        if assistant_accepts_encoder_outputs and "assistant_encoder_outputs" in model_kwargs:
+                            encoder_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
+                        assistant_model_outputs = assistant_model(candidate_input_ids, **encoder_kwargs)
 
                 # 1.2. greedily select the next candidate token
                 model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values
@@ -4595,31 +4623,8 @@ class GenerationMixin:
                 assistant_model_outputs.logits[:, -1, :] = logits_processor(
                     candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
                 )
                 new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
                 candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
-                if assistant_model.config.is_encoder_decoder and assistant_decoder_attention_mask is not None:
-                    assistant_decoder_attention_mask = torch.cat(
-                        (
-                            assistant_decoder_attention_mask,
-                            torch.ones(
-                                [1, 1],
-                                dtype=assistant_decoder_attention_mask.dtype,
-                                device=assistant_decoder_attention_mask.device,
-                            ),
-                        ),
-                        dim=-1,
-                    )
-                elif not assistant_model.config.is_encoder_decoder and assistant_attention_mask is not None:
-                    assistant_attention_mask = torch.cat(
-                        (
-                            assistant_attention_mask,
-                            torch.ones(
-                                [1, 1], dtype=assistant_attention_mask.dtype, device=assistant_attention_mask.device
-                            ),
-                        ),
-                        dim=-1,
-                    )
 
                 # 1.3. stop assistant generation on EOS
                 if eos_token_id_tensor is not None:
@@ -4755,13 +4760,6 @@ class GenerationMixin:
                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
 
-            # Update attention_mask for the assistant's next round of generations
-            if n_matches > 0 and model_kwargs.get("attention_mask", None) is not None:
-                attention_mask = model_kwargs["attention_mask"]
-                model_kwargs["attention_mask"] = torch.cat(
-                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], n_matches))], dim=-1
-                )
-
             # if eos_token was found in one sentence, set sentence to finished
             if eos_token_id_tensor is not None:
                 unfinished_sequences = unfinished_sequences.mul(
...
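The restored branch above derives how much of `candidate_input_ids` to feed the assistant from the length of its key/value cache rather than from an attention mask. A toy, shape-only sketch of that bookkeeping (an illustration under assumed sizes, not part of the commit):

import torch

candidate_input_ids = torch.arange(12).unsqueeze(0)  # (batch=1, seq_len=12)
prev_seq_len = 10  # assistant cache length, e.g. key.shape[-2] of one cached layer
# `new_token_len` is 1 or 2: the assistant's next token plus the last token picked by the larger model
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
assist_inputs = candidate_input_ids[:, -new_token_len:]  # only the uncached suffix is run through the assistant
print(assist_inputs.shape)  # torch.Size([1, 2])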
@@ -18,6 +18,7 @@ import copy
 import inspect
 import os
 import tempfile
+import time
 import unittest
 
 import numpy as np
@@ -1736,6 +1737,102 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         self.assertTrue(prompt in text)
 
+    @slow
+    @require_torch_gpu
+    def test_speculative_decoding_distil(self):
+        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+        model_id = "openai/whisper-large-v2"
+        model = WhisperForConditionalGeneration.from_pretrained(
+            model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
+        )
+        model.to(torch_device)
+
+        processor = WhisperProcessor.from_pretrained(model_id)
+
+        assistant_model_id = "distil-whisper/distil-large-v2"
+        assistant_model = WhisperForCausalLM.from_pretrained(
+            assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
+        )
+        assistant_model.to(torch_device)
+
+        dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+        sample = dataset[0]["audio"]
+
+        input_features = processor(sample["array"], return_tensors="pt").input_features.to("cuda").to(torch.float16)
+
+        # warm up assisted decoding
+        _ = model.generate(input_features, assistant_model=assistant_model)
+        # warm up non-assisted decoding
+        _ = model.generate(input_features)
+
+        # assisted decoding
+        start_time = time.time()
+        tokens = model.generate(input_features, assistant_model=assistant_model)
+        total_time_assist = time.time() - start_time
+
+        transcription_ass = processor.batch_decode(tokens, skip_special_tokens=True)
+
+        # non-assisted decoding
+        start_time = time.time()
+        tokens = model.generate(input_features)
+        total_time_non_assist = time.time() - start_time
+
+        transcription_non_ass = processor.batch_decode(tokens, skip_special_tokens=True)
+
+        assert transcription_ass == transcription_non_ass
+        assert transcription_ass == [
+            " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
+        ]
+        assert total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster"
+
+    @slow
+    @require_torch_gpu
+    def test_speculative_decoding_non_distil(self):
+        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+        model_id = "openai/whisper-large-v2"
+        model = WhisperForConditionalGeneration.from_pretrained(
+            model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
+        )
+        model.to(torch_device)
+
+        processor = WhisperProcessor.from_pretrained(model_id)
+
+        assistant_model_id = "openai/whisper-tiny"
+        assistant_model = WhisperForConditionalGeneration.from_pretrained(
+            assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
+        )
+        assistant_model.to(torch_device)
+
+        dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+        sample = dataset[0]["audio"]
+
+        input_features = processor(sample["array"], return_tensors="pt").input_features.to("cuda").to(torch.float16)
+
+        # warm up assisted decoding
+        _ = model.generate(input_features, assistant_model=assistant_model)
+        # warm up non-assisted decoding
+        _ = model.generate(input_features)
+
+        # assisted decoding
+        start_time = time.time()
+        tokens = model.generate(input_features, assistant_model=assistant_model)
+        total_time_assist = time.time() - start_time
+
+        transcription_ass = processor.batch_decode(tokens, skip_special_tokens=True)
+
+        # non-assisted decoding
+        start_time = time.time()
+        tokens = model.generate(input_features)
+        total_time_non_assist = time.time() - start_time
+
+        transcription_non_ass = processor.batch_decode(tokens, skip_special_tokens=True)
+
+        assert transcription_ass == transcription_non_ass
+        assert transcription_ass == [
+            " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
+        ]
+        assert total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster"
+
 
 def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
     if head_mask is None:
...