Unverified Commit 36f183eb authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[ASR Pipeline] Fix init with timestamps (#25438)

* [ASR Pipeline] Fix init

* refactor test

* change default kwarg setting

* only perform checks if we have to

* override init

* move pre/forward/post checks to sanitize
parent 6bca43bb
......@@ -17,19 +17,24 @@ from typing import TYPE_CHECKING, Dict, Optional, Union
import numpy as np
import requests
from ..modelcard import ModelCard
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import is_torch_available, is_torchaudio_available, logging
from .audio_utils import ffmpeg_read
from .base import ChunkPipeline
from .base import ArgumentHandler, ChunkPipeline, infer_framework_load_model
if TYPE_CHECKING:
from pyctcdecode import BeamSearchDecoderCTC
from ..feature_extraction_sequence_utils import SequenceFeatureExtractor
from ..modeling_utils import PreTrainedModel
logger = logging.get_logger(__name__)
if is_torch_available():
import torch
from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
......@@ -194,14 +199,78 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
def __init__(
self,
feature_extractor: Union["SequenceFeatureExtractor", str],
*,
model: "PreTrainedModel",
feature_extractor: Union["SequenceFeatureExtractor", str] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
decoder: Optional[Union["BeamSearchDecoderCTC", str]] = None,
modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None,
task: str = "",
args_parser: ArgumentHandler = None,
device: Union[int, "torch.device"] = None,
torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
binary_output: bool = False,
**kwargs,
):
super().__init__(**kwargs)
if framework is None:
framework, model = infer_framework_load_model(model, config=model.config)
self.task = task
self.model = model
self.tokenizer = tokenizer
self.feature_extractor = feature_extractor
self.modelcard = modelcard
self.framework = framework
# `accelerate` device map
hf_device_map = getattr(self.model, "hf_device_map", None)
if hf_device_map is not None and device is not None:
raise ValueError(
"The model has been loaded with `accelerate` and therefore cannot be moved to a specific device. Please "
"discard the `device` argument when creating your pipeline object."
)
if self.framework == "tf":
raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")
# We shouldn't call `model.to()` for models loaded with accelerate
if device is not None and not (isinstance(device, int) and device < 0):
self.model.to(device)
if device is None:
if hf_device_map is not None:
# Take the first device used by `accelerate`.
device = next(iter(hf_device_map.values()))
else:
device = -1
if is_torch_available() and self.framework == "pt":
if isinstance(device, torch.device):
self.device = device
elif isinstance(device, str):
self.device = torch.device(device)
elif device < 0:
self.device = torch.device("cpu")
else:
self.device = torch.device(f"cuda:{device}")
else:
self.device = device if device is not None else -1
self.torch_dtype = torch_dtype
self.binary_output = binary_output
# Update config and generation_config with task specific parameters
task_specific_params = self.model.config.task_specific_params
if task_specific_params is not None and task in task_specific_params:
self.model.config.update(task_specific_params.get(task))
if self.model.can_generate():
self.model.generation_config.update(**task_specific_params.get(task))
self.call_count = 0
self._batch_size = kwargs.pop("batch_size", None)
self._num_workers = kwargs.pop("num_workers", None)
# set the model type so we can check we have the right pre- and post-processing parameters
if self.model.config.model_type == "whisper":
self.type = "seq2seq_whisper"
elif self.model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values():
......@@ -216,8 +285,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
else:
self.type = "ctc"
if self.framework == "tf":
raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.copy()
mapping.update(MODEL_FOR_CTC_MAPPING_NAMES)
......@@ -301,11 +369,16 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
# No parameters on this pipeline right now
preprocess_params = {}
if chunk_length_s is not None:
if self.type == "seq2seq" and not ignore_warning:
logger.warning(
"Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily"
" be entirely accurate and will have caveats. More information:"
" https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
" ignore_warning=True)"
)
preprocess_params["chunk_length_s"] = chunk_length_s
if stride_length_s is not None:
preprocess_params["stride_length_s"] = stride_length_s
if ignore_warning is not None:
preprocess_params["ignore_warning"] = ignore_warning
forward_params = defaultdict(dict)
if max_new_tokens is not None:
......@@ -322,6 +395,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if decoder_kwargs is not None:
postprocess_params["decoder_kwargs"] = decoder_kwargs
if return_timestamps is not None:
# Check whether we have a valid setting for return_timestamps and throw an error before we perform a forward pass
if self.type == "seq2seq" and return_timestamps:
raise ValueError("We cannot return_timestamps yet on non-CTC models apart from Whisper!")
if self.type == "ctc_with_lm" and return_timestamps != "word":
......@@ -339,11 +413,13 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
forward_params["return_timestamps"] = return_timestamps
postprocess_params["return_timestamps"] = return_timestamps
if return_language is not None:
if self.type != "seq2seq_whisper":
raise ValueError("Only Whisper can return language for now.")
postprocess_params["return_language"] = return_language
return preprocess_params, forward_params, postprocess_params
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False):
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
if isinstance(inputs, str):
if inputs.startswith("http://") or inputs.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
......@@ -378,8 +454,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
extra = inputs
inputs = _inputs
if in_sampling_rate != self.feature_extractor.sampling_rate:
import torch
if is_torchaudio_available():
from torchaudio import functional as F
else:
......@@ -409,14 +483,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
if chunk_length_s:
if self.type == "seq2seq" and not ignore_warning:
logger.warning(
"Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily"
" be entirely accurate and will have caveats. More information:"
" https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
" ignore_warning=True)"
)
self._preprocess_params["ignore_warning"] = True
if stride_length_s is None:
stride_length_s = chunk_length_s / 6
......@@ -456,6 +522,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
if generate_kwargs is None:
generate_kwargs = {}
if return_timestamps and self.type == "seq2seq_whisper":
generate_kwargs["return_timestamps"] = return_timestamps
if return_timestamps == "word":
......@@ -525,9 +592,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
# Optional return types
optional = {}
if return_language is not None and self.type != "seq2seq_whisper":
raise ValueError("Only whisper can return language for now.")
final_items = []
key = "logits" if self.type == "ctc_with_lm" else "tokens"
stride = None
......
......@@ -343,6 +343,58 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
)
# fmt: on
@require_torch
def test_return_timestamps_in_init(self):
# segment-level timestamps are accepted
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
tokenizer = AutoTokenizer.from_pretrained("openai/whisper-tiny")
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")
dummy_speech = np.ones(100)
pipe = pipeline(
task="automatic-speech-recognition",
model=model,
feature_extractor=feature_extractor,
tokenizer=tokenizer,
chunk_length_s=8,
stride_length_s=1,
return_timestamps=True,
)
_ = pipe(dummy_speech)
# word-level timestamps are accepted
pipe = pipeline(
task="automatic-speech-recognition",
model=model,
feature_extractor=feature_extractor,
tokenizer=tokenizer,
chunk_length_s=8,
stride_length_s=1,
return_timestamps="word",
)
_ = pipe(dummy_speech)
# char-level timestamps are not accepted
with self.assertRaisesRegex(
ValueError,
"^Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
"Use `return_timestamps='word'` or `return_timestamps=True` respectively.$",
):
pipe = pipeline(
task="automatic-speech-recognition",
model=model,
feature_extractor=feature_extractor,
tokenizer=tokenizer,
chunk_length_s=8,
stride_length_s=1,
return_timestamps="char",
)
_ = pipe(dummy_speech)
@require_torch
@slow
def test_torch_whisper(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment