Unverified commit 36f183eb authored by Sanchit Gandhi, committed by GitHub

[ASR Pipeline] Fix init with timestamps (#25438)

* [ASR Pipeline] Fix init

* refactor test

* change default kwarg setting

* only perform checks if we have to

* override init

* move pre/forward/post checks to sanitize
parent 6bca43bb
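
For reviewers who want to try the change locally: a minimal sketch of the behaviour this PR fixes, using the `openai/whisper-tiny` checkpoint and the dummy waveform from the new test below. After this patch, kwargs passed to `pipeline(...)` at construction time are routed through `_sanitize_parameters`, so invalid timestamp settings fail at init instead of on the first call.

```python
import numpy as np
from transformers import pipeline

# Segment-level timestamps, now validated and applied at init.
asr = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-tiny",
    chunk_length_s=8,
    stride_length_s=1,
    return_timestamps=True,
)
_ = asr(np.ones(100))  # dummy mono waveform, as in the new test

# An invalid setting now raises at construction time rather than at call time.
try:
    pipeline(
        task="automatic-speech-recognition",
        model="openai/whisper-tiny",
        return_timestamps="char",  # Whisper supports only "word" or True
    )
except ValueError as err:
    print(err)
```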
src/transformers/pipelines/automatic_speech_recognition.py
@@ -17,19 +17,24 @@ from typing import TYPE_CHECKING, Dict, Optional, Union
 import numpy as np
 import requests
 
+from ..modelcard import ModelCard
+from ..tokenization_utils import PreTrainedTokenizer
 from ..utils import is_torch_available, is_torchaudio_available, logging
 from .audio_utils import ffmpeg_read
-from .base import ChunkPipeline
+from .base import ArgumentHandler, ChunkPipeline, infer_framework_load_model
 
 
 if TYPE_CHECKING:
     from pyctcdecode import BeamSearchDecoderCTC
 
     from ..feature_extraction_sequence_utils import SequenceFeatureExtractor
+    from ..modeling_utils import PreTrainedModel
 
 logger = logging.get_logger(__name__)
 
 if is_torch_available():
+    import torch
+
     from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
@@ -194,14 +199,78 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
     def __init__(
         self,
-        feature_extractor: Union["SequenceFeatureExtractor", str],
-        *,
+        model: "PreTrainedModel",
+        feature_extractor: Union["SequenceFeatureExtractor", str] = None,
+        tokenizer: Optional[PreTrainedTokenizer] = None,
         decoder: Optional[Union["BeamSearchDecoderCTC", str]] = None,
+        modelcard: Optional[ModelCard] = None,
+        framework: Optional[str] = None,
+        task: str = "",
+        args_parser: ArgumentHandler = None,
+        device: Union[int, "torch.device"] = None,
+        torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
+        binary_output: bool = False,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        if framework is None:
+            framework, model = infer_framework_load_model(model, config=model.config)
+
+        self.task = task
+        self.model = model
+        self.tokenizer = tokenizer
         self.feature_extractor = feature_extractor
+        self.modelcard = modelcard
+        self.framework = framework
+
+        # `accelerate` device map
+        hf_device_map = getattr(self.model, "hf_device_map", None)
+
+        if hf_device_map is not None and device is not None:
+            raise ValueError(
+                "The model has been loaded with `accelerate` and therefore cannot be moved to a specific device. Please "
+                "discard the `device` argument when creating your pipeline object."
+            )
+
+        if self.framework == "tf":
+            raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")
+
+        # We shouldn't call `model.to()` for models loaded with accelerate
+        if device is not None and not (isinstance(device, int) and device < 0):
+            self.model.to(device)
+
+        if device is None:
+            if hf_device_map is not None:
+                # Take the first device used by `accelerate`.
+                device = next(iter(hf_device_map.values()))
+            else:
+                device = -1
+
+        if is_torch_available() and self.framework == "pt":
+            if isinstance(device, torch.device):
+                self.device = device
+            elif isinstance(device, str):
+                self.device = torch.device(device)
+            elif device < 0:
+                self.device = torch.device("cpu")
+            else:
+                self.device = torch.device(f"cuda:{device}")
+        else:
+            self.device = device if device is not None else -1
+
+        self.torch_dtype = torch_dtype
+        self.binary_output = binary_output
+
+        # Update config and generation_config with task specific parameters
+        task_specific_params = self.model.config.task_specific_params
+        if task_specific_params is not None and task in task_specific_params:
+            self.model.config.update(task_specific_params.get(task))
+            if self.model.can_generate():
+                self.model.generation_config.update(**task_specific_params.get(task))
+
+        self.call_count = 0
+        self._batch_size = kwargs.pop("batch_size", None)
+        self._num_workers = kwargs.pop("num_workers", None)
+
+        # set the model type so we can check we have the right pre- and post-processing parameters
         if self.model.config.model_type == "whisper":
             self.type = "seq2seq_whisper"
         elif self.model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values():
@@ -216,8 +285,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         else:
             self.type = "ctc"
 
-        if self.framework == "tf":
-            raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")
+        self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
 
         mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.copy()
         mapping.update(MODEL_FOR_CTC_MAPPING_NAMES)
@@ -301,11 +369,16 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         # No parameters on this pipeline right now
         preprocess_params = {}
         if chunk_length_s is not None:
+            if self.type == "seq2seq" and not ignore_warning:
+                logger.warning(
+                    "Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily"
+                    " be entirely accurate and will have caveats. More information:"
+                    " https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
+                    " ignore_warning=True)"
+                )
            preprocess_params["chunk_length_s"] = chunk_length_s
         if stride_length_s is not None:
             preprocess_params["stride_length_s"] = stride_length_s
-        if ignore_warning is not None:
-            preprocess_params["ignore_warning"] = ignore_warning
 
         forward_params = defaultdict(dict)
         if max_new_tokens is not None:
@@ -322,6 +395,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         if decoder_kwargs is not None:
             postprocess_params["decoder_kwargs"] = decoder_kwargs
         if return_timestamps is not None:
+            # Check whether we have a valid setting for return_timestamps and throw an error before we perform a forward pass
             if self.type == "seq2seq" and return_timestamps:
                 raise ValueError("We cannot return_timestamps yet on non-CTC models apart from Whisper!")
             if self.type == "ctc_with_lm" and return_timestamps != "word":
@@ -339,11 +413,13 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             forward_params["return_timestamps"] = return_timestamps
             postprocess_params["return_timestamps"] = return_timestamps
         if return_language is not None:
+            if self.type != "seq2seq_whisper":
+                raise ValueError("Only Whisper can return language for now.")
             postprocess_params["return_language"] = return_language
 
         return preprocess_params, forward_params, postprocess_params
 
-    def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False):
+    def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
         if isinstance(inputs, str):
             if inputs.startswith("http://") or inputs.startswith("https://"):
                 # We need to actually check for a real protocol, otherwise it's impossible to use a local file
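
The reason the warning and the timestamp checks can move here: `_sanitize_parameters` returns three dicts that the base pipeline routes to `preprocess`, `_forward`, and `postprocess`, and the overridden `__init__` now calls it once up front. A toy illustration of that contract (class and parameter names are illustrative only, not from the PR):

```python
class ToyChunkPipeline:
    def __init__(self, **kwargs):
        # Same pattern as the override above: validate once, at construction.
        self._preprocess_params, self._forward_params, self._postprocess_params = (
            self._sanitize_parameters(**kwargs)
        )

    def _sanitize_parameters(self, chunk_length_s=None, return_timestamps=None, **kwargs):
        preprocess, forward, postprocess = {}, {}, {}
        if chunk_length_s is not None:
            preprocess["chunk_length_s"] = chunk_length_s
        if return_timestamps is not None:
            if return_timestamps == "char":
                raise ValueError("char timestamps unsupported")  # fail early
            forward["return_timestamps"] = return_timestamps
            postprocess["return_timestamps"] = return_timestamps
        return preprocess, forward, postprocess

ToyChunkPipeline(chunk_length_s=8, return_timestamps=True)  # ok
# ToyChunkPipeline(return_timestamps="char")                # raises ValueError at init
```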
@@ -378,8 +454,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             extra = inputs
             inputs = _inputs
             if in_sampling_rate != self.feature_extractor.sampling_rate:
-                import torch
-
                 if is_torchaudio_available():
                     from torchaudio import functional as F
                 else:
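
The removed in-function `import torch` works because `torch` is now imported at module level under `is_torch_available()` (first hunk). For context, this code path resamples when the input rate disagrees with the feature extractor; a sketch with made-up rates:

```python
import numpy as np
import torch
from torchaudio import functional as F  # requires torchaudio

in_sampling_rate = 8_000
target_sampling_rate = 16_000  # e.g. self.feature_extractor.sampling_rate

waveform = np.random.randn(8_000).astype(np.float32)  # 1 s of audio at 8 kHz
resampled = F.resample(
    torch.from_numpy(waveform), in_sampling_rate, target_sampling_rate
).numpy()
print(resampled.shape)  # (16000,)
```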
@@ -409,14 +483,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
 
         if chunk_length_s:
-            if self.type == "seq2seq" and not ignore_warning:
-                logger.warning(
-                    "Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily"
-                    " be entirely accurate and will have caveats. More information:"
-                    " https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
-                    " ignore_warning=True)"
-                )
-                self._preprocess_params["ignore_warning"] = True
             if stride_length_s is None:
                 stride_length_s = chunk_length_s / 6
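
For context on the surviving lines: when chunking is enabled and no stride is given, `preprocess` defaults the stride to one sixth of the chunk length, and both values are then converted from seconds to samples. A sketch assuming the 16 kHz rate used by Whisper's feature extractor:

```python
# Illustrative values; the real code reads the rate from self.feature_extractor.
sampling_rate = 16_000
chunk_length_s = 8
stride_length_s = None

if stride_length_s is None:
    stride_length_s = chunk_length_s / 6  # default from the lines above

chunk_len = int(round(chunk_length_s * sampling_rate))  # 128000 samples
stride = int(round(stride_length_s * sampling_rate))    # 21333 samples of overlap
print(chunk_len, stride)
```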
@@ -456,6 +522,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
     def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
         if generate_kwargs is None:
             generate_kwargs = {}
+
         if return_timestamps and self.type == "seq2seq_whisper":
             generate_kwargs["return_timestamps"] = return_timestamps
             if return_timestamps == "word":
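
Context for the `return_timestamps == "word"` branch (its body is cut off by the hunk boundary above): for Whisper the pipeline forwards the timestamp request to `generate`. A sketch of asking Whisper for timestamps directly, assuming the standard `WhisperForConditionalGeneration.generate` API of this release; the input is silent dummy audio, so the transcription itself is meaningless:

```python
import numpy as np
from transformers import AutoProcessor, WhisperForConditionalGeneration

processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

inputs = processor(np.zeros(16_000), sampling_rate=16_000, return_tensors="pt")
# Segment-level timestamps are encoded as timestamp tokens in the output.
generated = model.generate(inputs.input_features, return_timestamps=True)
print(processor.batch_decode(generated, decode_with_timestamps=True))
```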
@@ -525,9 +592,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         # Optional return types
         optional = {}
 
-        if return_language is not None and self.type != "seq2seq_whisper":
-            raise ValueError("Only whisper can return language for now.")
-
         final_items = []
         key = "logits" if self.type == "ctc_with_lm" else "tokens"
         stride = None
...
tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -343,6 +343,58 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
         )
         # fmt: on
 
+    @require_torch
+    def test_return_timestamps_in_init(self):
+        # segment-level timestamps are accepted
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
+        tokenizer = AutoTokenizer.from_pretrained("openai/whisper-tiny")
+        feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")
+        dummy_speech = np.ones(100)
+
+        pipe = pipeline(
+            task="automatic-speech-recognition",
+            model=model,
+            feature_extractor=feature_extractor,
+            tokenizer=tokenizer,
+            chunk_length_s=8,
+            stride_length_s=1,
+            return_timestamps=True,
+        )
+        _ = pipe(dummy_speech)
+
+        # word-level timestamps are accepted
+        pipe = pipeline(
+            task="automatic-speech-recognition",
+            model=model,
+            feature_extractor=feature_extractor,
+            tokenizer=tokenizer,
+            chunk_length_s=8,
+            stride_length_s=1,
+            return_timestamps="word",
+        )
+        _ = pipe(dummy_speech)
+
+        # char-level timestamps are not accepted
+        with self.assertRaisesRegex(
+            ValueError,
+            "^Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
+            "Use `return_timestamps='word'` or `return_timestamps=True` respectively.$",
+        ):
+            pipe = pipeline(
+                task="automatic-speech-recognition",
+                model=model,
+                feature_extractor=feature_extractor,
+                tokenizer=tokenizer,
+                chunk_length_s=8,
+                stride_length_s=1,
+                return_timestamps="char",
+            )
+            _ = pipe(dummy_speech)
+
     @require_torch
     @slow
     def test_torch_whisper(self):
...
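
To exercise the new test path against real audio rather than `np.ones(100)`, a sketch using the dummy LibriSpeech set that the transformers test-suite commonly loads (dataset id taken from that suite; any 16 kHz sample works):

```python
from datasets import load_dataset
from transformers import pipeline

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",
    return_timestamps="word",
)
out = asr(ds[0]["audio"])  # dict with "array" and "sampling_rate" keys
print(out["text"])
print(out["chunks"][:3])  # word-level {"text", "timestamp": (start, end)} entries
```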