Unverified Commit 576e2823 authored by JP's avatar JP Committed by GitHub
Browse files

Add descriptive docstring to WhisperTimeStampLogitsProcessor (#25642)



* adding in logit examples for Whisper processor

* adding in updated logits processor for Whisper

* adding in cleaned version of  logits processor for Whisper

* adding docstrings for whisper processor

* making sure the formatting is correct

* adding logits after doc builder

* Update src/transformers/generation/logits_process.py

Adding in suggested fix to the LogitProcessor description.
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/logits_process.py
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/logits_process.py

Removing tip per suggestion.
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/logits_process.py

Removing redundant code per suggestion.
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>

* adding in revised version

* adding in version with timestamp examples

* Update src/transformers/generation/logits_process.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* enhanced paragraph on behavior of processor

* fixing doc quality issue

* removing the word poem from example

* adding in updated docstring

* adding in new version of file after doc-builder

---------
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent fc142bd7
......@@ -1457,8 +1457,15 @@ class ForceTokensLogitsProcessor(LogitsProcessor):
class WhisperTimeStampLogitsProcessor(LogitsProcessor):
r"""
Whisper specific Processor. This processor can be used to force a list of tokens. The processor will set their log
probs to `inf` so that they are sampled at their corresponding index.
[`LogitsProcessor`] that modifies the logits for the generation of timestamps in the transcription. When the input
tokens are at a specific threshold, the processor sets the scores to negative infinity. The processor makes sure
that timestamp tokens appear in pairs, by masking out the logits that would break this pairing pattern. This is
done to maintain the consistency and structure of generated timestamps. It also ensures that when the predicted
probability of sampling any of the timestamp token is greater than any individual non-timestamp token, those
non-timestamp logits are set to negative infinity. This is done to ensure the generation of timestamps over other
potential tokens.
See [the paper](https://arxiv.org/abs/2212.04356) for more information.
......@@ -1472,6 +1479,34 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
max_initial_timestamp_index (`int`, *optional*, defaults to 1):
Used to set the maximum value of the initial timestamp. This is used to prevent the model from
predicting timestamps that are too far in the future.
Examples:
``` python
>>> import torch
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration,GenerationConfig
>>> from datasets import load_dataset
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[3]["audio"]["array"], return_tensors="pt")
>>> input_features = inputs.input_features
>>> #Displaying timestamps
>>> generated_ids = model.generate(inputs=input_features, return_timestamps=True)
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
>>> print("Transcription:", transcription)
Transcription: <|startoftranscript|><|0.00|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all, and can<|6.44|><|6.44|> discover in it but little of rocky Ithaca.<|9.44|><|endoftext|>
>>> #No timestamps & change EOS:
>>> #This allows the user to select a specific token to terminate the sequence on, in this case it's the word "can"(460)
>>> model.generation_config.eos_token_id = 460
>>> generated_ids = model.generate(inputs=input_features,return_timestamps=False)
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> print("Transcription:", transcription)
Transcription: He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can
```
"""
def __init__(self, generate_config): # support for the kwargs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment