Unverified Commit 255257f3 authored by Arthur, committed by GitHub

[Whisper] Refactor whisper (#21252)

* update whisper logit processor

* add generate for whisper

* remove part of the whisper specific code from pipeline

* update logit processes

* major update

* enforce first timestamp

* update generate

* add more tests

* update new decoding strategy

* Apply suggestions from code review

* update docstring

* fixup

* default config will not have multilingual ar

* update expected tokenizer size, see pull on the hub for whisper-tiny
parent f83135eb
@@ -917,35 +917,34 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
    probs to `inf` so that they are sampled at their corresponding index.

    Args:
-        begin_index (`int`, *optional*, defaults to 5):
-            This indicates to the processor where the first tokens are generated. This is used to differentiate between
-            the `prompt` tokens and the `generated` tokens. When generating with `WhisperForConditionalGeneration` the
-            `prompt` tokens are the first 4 tokens.
-        eos_token_id (`int`, *optional*, defaults to 50257):
-            The id of the *end-of-sequence* token.
-        no_timestamps_token_id (`int`, *optional*, defaults to 50363):
-            The id of the `"<|notimestamps|>"` token.
-        max_initial_timestamp (`int`, *optional*, defaults to 1):
-            Used to set the maximum value of the initial timestamp. This is used to prevent the model from predicting
-            timestamps that are too far in the future.
+        generate_config (`GenerateConfig`):
+            The generate config used to generate the output. The following parameters are required:
+                eos_token_id (`int`, *optional*, defaults to 50257):
+                    The id of the *end-of-sequence* token.
+                no_timestamps_token_id (`int`, *optional*, defaults to 50363):
+                    The id of the `"<|notimestamps|>"` token.
+                max_initial_timestamp_index (`int`, *optional*, defaults to 1):
+                    Used to set the maximum value of the initial timestamp. This is used to prevent the model from
+                    predicting timestamps that are too far in the future.
    """

-    def __init__(
-        self,
-        begin_index=5,
-        eos_token_id=50257,
-        no_timestamps_token_id=50363,
-        max_initial_timestamp=1,
-    ):
-        self.eos_token_id = eos_token_id
-        self.no_timestamps_token_id = no_timestamps_token_id
-        self.timestamp_begin = no_timestamps_token_id + 1
-        self.begin_index = begin_index
-        self.max_initial_timestamp_index = max_initial_timestamp
+    def __init__(self, generate_config):  # support for the kwargs
+        self.eos_token_id = generate_config.eos_token_id
+        self.no_timestamps_token_id = generate_config.no_timestamps_token_id
+        self.timestamp_begin = generate_config.no_timestamps_token_id + 1
+
+        self.begin_index = len(generate_config.forced_decoder_ids) + 1
+        if generate_config.forced_decoder_ids[-1][1] == self.no_timestamps_token_id:
+            self.begin_index -= 1
+        if generate_config.is_multilingual:
+            self.begin_index += 1
+        self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index

    def __call__(self, input_ids, scores):
        # suppress <|notimestamps|> which is handled by without_timestamps
        scores[:, self.no_timestamps_token_id] = -float("inf")

+        if input_ids.shape[1] == self.begin_index:
+            scores[:, self.timestamp_begin] = 0
+
        # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
        for k in range(input_ids.shape[0]):
......
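Note: after this refactor the processor reads everything it needs from a generation config object instead of loose keyword arguments. A minimal sketch of how it could be instantiated (the token ids and the extra config attributes below are illustrative assumptions, not values taken from this diff):

    # Sketch: building the refactored processor from a generation config.
    # The ids below (<|en|> = 50259, <|transcribe|> = 50359) are illustrative assumptions.
    from transformers import GenerationConfig
    from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor

    generation_config = GenerationConfig(
        eos_token_id=50257,
        no_timestamps_token_id=50363,
        max_initial_timestamp_index=1,
        forced_decoder_ids=[(1, 50259), (2, 50359)],
        is_multilingual=True,
    )
    processor = WhisperTimeStampLogitsProcessor(generation_config)
    # begin_index is derived from the config: len(forced_decoder_ids) + 1, plus 1 because is_multilingual.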
@@ -25,6 +25,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
+from ...generation.logits_process import WhisperTimeStampLogitsProcessor
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
@@ -1231,6 +1232,150 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
            encoder_attentions=outputs.encoder_attentions,
        )
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        generation_config=None,
+        logits_processor=None,
+        stopping_criteria=None,
+        prefix_allowed_tokens_fn=None,
+        synced_gpus=False,
+        return_timestamps=None,
+        task=None,
+        language=None,
+        is_multilingual=None,
+        **kwargs
+    ):
"""
Generates sequences of token ids for models with a language modeling head.
<Tip warning={true}>
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
model's default generation configuration. You can override any `generation_config` by passing the corresponding
parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
For an overview of generation strategies and code examples, check out the [following
guide](./generation_strategies).
</Tip>
Parameters:
inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
`input_ids`, `input_values`, `input_features`, or `pixel_values`.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which had the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
logits_processor (`LogitsProcessorList`, *optional*):
Custom logits processors that complement the default logits processors built from arguments and
generation config. If a logit processor is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
If provided, this function constraints the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
`input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
Retrieval](https://arxiv.org/abs/2010.00904).
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
return_timestamps (`bool`, *optional*):
Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`.
task (`bool`, *optional*):
Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids`
will be updated accordingly.
language (`bool`, *optional*):
Language token to use for generation, should be in the form `<|en|>`. You can find all the possible
language tokens in the `model.generation_config.lang_to_id` dictionary.
is_multilingual (`bool`, *optional*):
Whether or not the model is multilingual.
kwargs:
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
Return:
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
[`~utils.ModelOutput`] types are:
- [`~generation.GreedySearchDecoderOnlyOutput`],
- [`~generation.SampleDecoderOnlyOutput`],
- [`~generation.BeamSearchDecoderOnlyOutput`],
- [`~generation.BeamSampleDecoderOnlyOutput`]
If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
[`~utils.ModelOutput`] types are:
- [`~generation.GreedySearchEncoderDecoderOutput`],
- [`~generation.SampleEncoderDecoderOutput`],
- [`~generation.BeamSearchEncoderDecoderOutput`],
- [`~generation.BeamSampleEncoderDecoderOutput`]
"""
+        if generation_config is None:
+            generation_config = self.generation_config
+
+        if return_timestamps is not None:
+            generation_config.return_timestamps = return_timestamps
+
+        if task is not None:
+            generation_config.task = task
+
+        if is_multilingual is not None:
+            generation_config.is_multilingual = is_multilingual
+
+        if language is not None:
+            generation_config.language = language
+
+        forced_decoder_ids = []
+        if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual:
+            if hasattr(generation_config, "language"):
+                forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language]))
+            else:
+                forced_decoder_ids.append((1, None))
+            if hasattr(generation_config, "task"):
+                forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
+            else:
+                forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))
+
+        if (
+            hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps
+        ) or return_timestamps:
+            logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
+        else:
+            if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id:
+                idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
+                forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
+
+        if len(forced_decoder_ids) > 0:
+            generation_config.forced_decoder_ids = forced_decoder_ids
+
+        return super().generate(
+            inputs,
+            generation_config,
+            logits_processor,
+            stopping_criteria,
+            prefix_allowed_tokens_fn,
+            synced_gpus,
+            **kwargs,
+        )
    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
......
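Note: the new Whisper-specific `generate()` copies `task`, `language` and `return_timestamps` onto the generation config before delegating to the regular `generate()`. A hedged usage sketch (the checkpoint name is an assumption, the audio is a silent placeholder, and it assumes the checkpoint's generation config defines `lang_to_id`, `task_to_id` and `no_timestamps_token_id` as described in the commit message):

    # Sketch: calling the Whisper-specific generate() added in this commit.
    import numpy as np
    from transformers import WhisperForConditionalGeneration, WhisperProcessor

    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")  # assumed checkpoint
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

    raw_audio = np.zeros(16_000, dtype=np.float32)  # placeholder: 1 s of silence at 16 kHz
    inputs = processor(raw_audio, sampling_rate=16_000, return_tensors="pt")

    # task/language/return_timestamps are written to the generation config, which
    # builds forced_decoder_ids and enables WhisperTimeStampLogitsProcessor.
    predicted_ids = model.generate(
        inputs.input_features,
        return_timestamps=True,
        task="transcribe",
        language="<|en|>",
    )
    text = processor.tokenizer.decode(predicted_ids[0], decode_with_timestamps=True)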
@@ -493,6 +493,23 @@ class WhisperTokenizer(PreTrainedTokenizer):
        normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
        return normalizer(text)

+    def _decode_with_timestamps(self, token_ids, time_precision=0.02) -> str:
+        """
+        Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
+        given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
+        """
+        timestamp_begin = self.all_special_ids[-1] + 1
+        outputs = [[]]
+        for token in token_ids:
+            if token >= timestamp_begin:
+                timestamp = f"<|{(token - timestamp_begin) * time_precision:.2f}|>"
+                outputs.append(timestamp)
+                outputs.append([])
+            else:
+                outputs[-1].append(token)
+        outputs = [s if isinstance(s, str) else self.decode(s) for s in outputs]
+        return "".join(outputs)
+
    def _compute_offsets(self, token_ids, time_precision=0.02):
        """
        Compute offsets for a given tokenized input

@@ -544,6 +561,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
        clean_up_tokenization_spaces: bool = True,
        output_offsets: bool = False,
        time_precision=0.02,
+        decode_with_timestamps: bool = False,
        **kwargs
    ) -> str:
        """
@@ -561,7 +579,11 @@ class WhisperTokenizer(PreTrainedTokenizer):
                Whether or not to clean up the tokenization spaces.
            kwargs (additional keyword arguments, *optional*):
                Will be passed to the underlying model specific decode method.
+            output_offsets (`bool`, *optional*, defaults to `False`):
+                Whether or not to output the offsets of the tokens. This should only be set if the model predicted
+                timestamps.
+            decode_with_timestamps (`bool`, *optional*, defaults to `False`):
+                Whether or not to decode with timestamps included in the raw text.

        Returns:
            `str`: The decoded sentence.
        """
@@ -571,6 +593,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )
+        if decode_with_timestamps:
+            text = self._decode_with_timestamps(token_ids, time_precision=time_precision)
        # retrieve offsets
        if output_offsets:
            offsets = None
......
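Note: a short sketch of what the new timestamp-aware decoding produces (the checkpoint name and the non-timestamp token ids are illustrative assumptions):

    # Sketch: decoding with inline timestamp markers.
    from transformers import WhisperTokenizer

    tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")  # assumed checkpoint
    # Ids >= all_special_ids[-1] + 1 are timestamp tokens: 50364 is <|0.00|>, 50464 is <|2.00|>
    # (offset * time_precision, 0.02 s by default). The middle ids stand in for ordinary text tokens.
    predicted_ids = [50364, 2425, 11, 2425, 13, 50464]

    plain = tokenizer.decode(predicted_ids)  # timestamp tokens are dropped by the plain decode
    annotated = tokenizer.decode(predicted_ids, decode_with_timestamps=True)
    # `annotated` keeps the text but interleaves "<|0.00|>" and "<|2.00|>" around it.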
@@ -31,8 +31,6 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)

if is_torch_available():
-    from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
-
    from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING

@@ -413,13 +411,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
        if return_timestamps is not None:
            forward_params["return_timestamps"] = return_timestamps
            postprocess_params["return_timestamps"] = return_timestamps
-            if self.model.config.model_type == "whisper":
-                # Whisper is highly specific, if we want timestamps, we need to
-                # force whisper to output timestamp tokens, which means we need
-                # to set this variable to prevent `no_timestamp_token` to be
-                # used in the decoder.
-                if "forced_decoder_ids" not in forward_params.get("generate_kwargs", {}):
-                    forward_params["generate_kwargs"]["forced_decoder_ids"] = None

        return preprocess_params, forward_params, postprocess_params
@@ -529,10 +520,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
    def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
        if generate_kwargs is None:
            generate_kwargs = {}
+        if return_timestamps and self.type == "seq2seq_whisper":
+            generate_kwargs["return_timestamps"] = return_timestamps
        is_last = model_inputs.pop("is_last")
-        if self.type == "seq2seq":
+        if self.type in {"seq2seq", "seq2seq_whisper"}:
            encoder = self.model.get_encoder()
            # Consume values so we can let extra information flow freely through
            # the pipeline (important for `partial` in microphone)
@@ -557,16 +549,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                **generate_kwargs,
            )
            out = {"tokens": tokens}
-        elif self.type == "seq2seq_whisper":
-            stride = model_inputs.pop("stride", None)
-            tokens = self.model.generate(
-                input_features=model_inputs.pop("input_features"),
-                logits_processor=[WhisperTimeStampLogitsProcessor()] if return_timestamps else None,
-                **generate_kwargs,
-            )
-            out = {"tokens": tokens}
-            if stride is not None:
-                out["stride"] = stride
+            if self.type == "seq2seq_whisper":
+                stride = model_inputs.pop("stride", None)
+                if stride is not None:
+                    out["stride"] = stride
        else:
            stride = model_inputs.pop("stride", None)
......
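Note: with the Whisper-specific branch gone, the pipeline simply forwards `return_timestamps` to `model.generate()`. A hedged usage sketch (checkpoint name and audio file are assumptions):

    # Sketch: ASR pipeline with Whisper timestamps after this refactor.
    from transformers import pipeline

    asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")  # assumed checkpoint
    result = asr("sample.flac", return_timestamps=True)  # assumed local audio file
    print(result["text"])
    print(result["chunks"])  # e.g. [{"text": "...", "timestamp": (0.0, 5.44)}, ...]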
@@ -59,7 +59,7 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
        self.assertEqual(len(vocab_keys), 50364)

    def test_vocab_size(self):
-        self.assertEqual(self.get_tokenizer().vocab_size, 50257)
+        self.assertEqual(self.get_tokenizer().vocab_size, 50258)

    def test_full_tokenizer(self):
        tokenizer = WhisperTokenizer.from_pretrained(self.tmpdirname)

@@ -265,7 +265,15 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
                },
            ],
        )
+        # test `decode_with_timestamps`
+        output = multilingual_tokenizer.decode(INPUT_TOKENS, decode_with_timestamps=True)
+        self.assertEqual(
+            output,
+            "<|startoftranscript|><|en|><|transcribe|><|0.00|> Lennils, pictures are a sort of upguards and atom"
+            " paintings, and Mason's exquisite idles<|7.20|><|7.20|> are as national as a jingo poem. Mr. Birkut"
+            " Foster's landscapes smile at one much in the<|15.16|><|15.16|> same way that Mr. Carker used to flash"
+            " his teeth. And Mr. John Colier gives his<|21.70|><|21.70|><|endoftext|>",
+        )
        # test a single sequence with timestamps
        # fmt: off
        INPUT_TOKENS = [
......
@@ -28,7 +28,6 @@ from transformers import (
    Speech2TextForConditionalGeneration,
    Wav2Vec2ForCTC,
    WhisperForConditionalGeneration,
-    WhisperProcessor,
)
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
from transformers.pipelines.audio_utils import chunk_bytes_iter
@@ -523,10 +522,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
                "chunks": [{"text": " A man said to the universe, Sir, I exist.", "timestamp": (0.0, 4.26)}],
            },
        )

-        pipe = pipeline(
-            model="openai/whisper-small",
-            return_timestamps=True,
-        )
        output = pipe(array, chunk_length_s=10)
        self.assertDictEqual(
@@ -687,6 +682,21 @@
            output,
            {"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."},
        )
+        output = speech_recognizer(filename, return_timestamps=True)
+        self.assertEqual(
+            output,
+            {
+                "text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
+                "chunks": [
+                    {
+                        "text": (
+                            " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
+                        ),
+                        "timestamp": (0.0, 5.44),
+                    }
+                ],
+            },
+        )

    @slow
    @require_torch
@@ -712,10 +722,14 @@
        output_2 = speech_recognizer_2(filename)
        self.assertEqual(output, output_2)

-        processor = WhisperProcessor(feature_extractor, tokenizer)
-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="it")
+        # either use generate_kwargs or set the model's generation_config
+        # model.generation_config.task = "transcribe"
+        # model.generation_config.lang = "<|it|>"
        speech_translator = AutomaticSpeechRecognitionPipeline(
-            model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
+            model=model,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            generate_kwargs={"task": "transcribe", "language": "<|it|>"},
        )
        output_3 = speech_translator(filename)
        self.assertEqual(output_3, {"text": " Un uomo ha detto all'universo, Sir, esiste."})
......
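Note: the test above also shows the replacement for manually setting `forced_decoder_ids` through `WhisperProcessor`: task and language can now be passed as generate kwargs. A hedged sketch of the same idea for speech translation (the checkpoint name, audio file and `<|fr|>` language token are assumptions about a multilingual checkpoint):

    # Sketch: picking task/language via generate_kwargs instead of forced_decoder_ids.
    from transformers import pipeline

    translator = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-small",  # assumed multilingual checkpoint
        generate_kwargs={"task": "translate", "language": "<|fr|>"},
    )
    print(translator("french_sample.flac")["text"])  # English translation of French audio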