Unverified commit b722a6be, authored by Arthur, committed via GitHub
Browse files

Fix whisper for `pipeline` (#19482)

* update feature extractor params

* update attention mask handling

* fix doc and pipeline test

* add warning when skipping test

* add whisper translation and transcription test

* fix build doc test
parent df8faba4
...@@ -218,6 +218,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): ...@@ -218,6 +218,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
return_attention_mask: Optional[bool] = None, return_attention_mask: Optional[bool] = None,
padding: Optional[str] = "max_length", padding: Optional[str] = "max_length",
max_length: Optional[int] = None, max_length: Optional[int] = None,
sampling_rate: Optional[int] = None,
**kwargs **kwargs
) -> BatchFeature: ) -> BatchFeature:
""" """
...@@ -261,6 +262,19 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): ...@@ -261,6 +262,19 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
The value that is used to fill the padding values / vectors. The value that is used to fill the padding values / vectors.
""" """
if sampling_rate is not None:
if sampling_rate != self.sampling_rate:
raise ValueError(
f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
f" {self.sampling_rate} and not {sampling_rate}."
)
else:
logger.warning(
"It is strongly recommended to pass the `sampling_rate` argument to this function. "
"Failing to do so can result in silent errors that might be hard to debug."
)
is_batched = bool( is_batched = bool(
isinstance(raw_speech, (list, tuple)) isinstance(raw_speech, (list, tuple))
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list))) and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
......
...@@ -31,13 +31,22 @@ from ...modeling_outputs import ( ...@@ -31,13 +31,22 @@ from ...modeling_outputs import (
Seq2SeqModelOutput, Seq2SeqModelOutput,
) )
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_whisper import WhisperConfig from .configuration_whisper import WhisperConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "WhisperConfig" _CONFIG_FOR_DOC = "WhisperConfig"
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
_PROCESSOR_FOR_DOC = "openai/whisper-tiny"
_EXPECTED_OUTPUT_SHAPE = [1, 2, 512]
WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [ WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
...@@ -982,7 +991,14 @@ class WhisperModel(WhisperPreTrainedModel): ...@@ -982,7 +991,14 @@ class WhisperModel(WhisperPreTrainedModel):
return self.decoder return self.decoder
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) @add_code_sample_docstrings(
processor_class=_PROCESSOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Seq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_EXPECTED_OUTPUT_SHAPE,
modality="audio",
)
def forward( def forward(
self, self,
input_features=None, input_features=None,
...@@ -999,26 +1015,6 @@ class WhisperModel(WhisperPreTrainedModel): ...@@ -999,26 +1015,6 @@ class WhisperModel(WhisperPreTrainedModel):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
): ):
r"""
Returns:
Example:
```python
>>> import torch
>>> from transformers import WhisperModel, WhisperFeatureExtractor
>>> from datasets import load_dataset
>>> model = WhisperModel.from_pretrained("openai/whisper-base")
>>> feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
>>> input_features = inputs.input_features
>>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
>>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
>>> list(last_hidden_state.shape)
[1, 2, 512]
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
......
...@@ -26,6 +26,8 @@ from transformers import ( ...@@ -26,6 +26,8 @@ from transformers import (
AutoTokenizer, AutoTokenizer,
Speech2TextForConditionalGeneration, Speech2TextForConditionalGeneration,
Wav2Vec2ForCTC, Wav2Vec2ForCTC,
WhisperForConditionalGeneration,
WhisperProcessor,
) )
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
from transformers.pipelines.audio_utils import chunk_bytes_iter from transformers.pipelines.audio_utils import chunk_bytes_iter
...@@ -308,6 +310,52 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel ...@@ -308,6 +310,52 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
output = asr(data) output = asr(data)
self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."}) self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})
@slow
@require_torch
@require_torchaudio
def test_simple_whisper_asr(self):
    """Smoke test: an English-only Whisper checkpoint transcribes a LibriSpeech
    sample end-to-end through the `automatic-speech-recognition` pipeline.

    Marked `@slow` because it downloads `openai/whisper-tiny.en` and the
    `hf-internal-testing/librispeech_asr_dummy` dataset from the Hub.
    """
    speech_recognizer = pipeline(
        task="automatic-speech-recognition",
        model="openai/whisper-tiny.en",
        framework="pt",
    )
    ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    # Pass the audio *file path* (not the decoded array) so the pipeline's own
    # loading/resampling path is exercised.
    filename = ds[0]["file"]
    output = speech_recognizer(filename)
    # Exact expected transcription for sample 0 of the dummy LibriSpeech split.
    # NOTE(review): leading space comes from the Whisper tokenizer's detokenization.
    self.assertEqual(output, {"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to"})
@slow
@require_torch
@require_torchaudio
def test_simple_whisper_translation(self):
    """Check the multilingual Whisper checkpoint through the ASR pipeline:

    1. default pipeline output on an English sample,
    2. equivalence with a manually-constructed `AutomaticSpeechRecognitionPipeline`,
    3. output after forcing decoder prompt ids via `WhisperProcessor`.

    Marked `@slow`: downloads `openai/whisper-large` (several GB) plus a dataset.
    """
    speech_recognizer = pipeline(
        task="automatic-speech-recognition",
        model="openai/whisper-large",
        framework="pt",
    )
    # Sort by "id" so ds[40] refers to a deterministic sample across runs.
    ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
    filename = ds[40]["file"]
    output = speech_recognizer(filename)
    self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."})
    # Rebuild the same pipeline from its parts; output must match the
    # `pipeline(...)`-constructed one above.
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
    tokenizer = AutoTokenizer.from_pretrained("openai/whisper-large")
    feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large")
    speech_recognizer_2 = AutomaticSpeechRecognitionPipeline(
        model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
    )
    output_2 = speech_recognizer_2(filename)
    self.assertEqual(output, output_2)
    # Force Italian decoding by setting `forced_decoder_ids` on the model config.
    # NOTE(review): the task is "transcribe" with language="it" (i.e. transcribe
    # *as if* the audio were Italian), not task="translate" — despite the
    # variable name `speech_translator`; confirm this is the intended behavior.
    processor = WhisperProcessor(feature_extractor, tokenizer)
    model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="it")
    speech_translator = AutomaticSpeechRecognitionPipeline(
        model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
    )
    output_3 = speech_translator(filename)
    self.assertEqual(output_3, {"text": " Un uomo ha detto allo universo, Sir, esiste."})
@slow @slow
@require_torch @require_torch
@require_torchaudio @require_torchaudio
......
...@@ -178,8 +178,16 @@ class ANY: ...@@ -178,8 +178,16 @@ class ANY:
class PipelineTestCaseMeta(type): class PipelineTestCaseMeta(type):
def __new__(mcs, name, bases, dct): def __new__(mcs, name, bases, dct):
def gen_test(ModelClass, checkpoint, tiny_config, tokenizer_class, feature_extractor_class): def gen_test(ModelClass, checkpoint, tiny_config, tokenizer_class, feature_extractor_class):
@skipIf(tiny_config is None, "TinyConfig does not exist") @skipIf(
@skipIf(checkpoint is None, "checkpoint does not exist") tiny_config is None,
"TinyConfig does not exist, make sure that you defined a `_CONFIG_FOR_DOC` variable in the modeling"
" file",
)
@skipIf(
checkpoint is None,
"checkpoint does not exist, make sure that you defined a `_CHECKPOINT_FOR_DOC` variable in the"
" modeling file",
)
def test(self): def test(self):
if ModelClass.__name__.endswith("ForCausalLM"): if ModelClass.__name__.endswith("ForCausalLM"):
tiny_config.is_encoder_decoder = False tiny_config.is_encoder_decoder = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment