Unverified Commit 38f95d18 authored by Anton Lozhkov, committed by GitHub

Large audio chunking for the existing ASR pipeline (#14896)



* Naive ASR chunking

* Fixing batching for ASR.
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
parent d33dc796
src/transformers/pipelines/automatic_speech_recognition.py

@@ -18,7 +18,7 @@ import numpy as np
 from ..file_utils import is_torch_available
 from ..utils import logging
-from .base import Pipeline
+from .base import ChunkPipeline


 if TYPE_CHECKING:
@@ -66,7 +66,7 @@ def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.array:
     return audio


-class AutomaticSpeechRecognitionPipeline(Pipeline):
+class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
     """
     Pipeline that aims at extracting spoken text contained within some audio.
@@ -85,8 +85,13 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         tokenizer ([`PreTrainedTokenizer`]):
             The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
             [`PreTrainedTokenizer`].
-        modelcard (`str` or [`ModelCard`], *optional*):
-            Model card attributed to the model for this pipeline.
+        chunk_length_ms (`int`, *optional*, defaults to 0):
+            The input length (in milliseconds) of each chunk. If `0` (the default), chunking is disabled. Only
+            available for CTC models.
+        stride_length_ms (`int`, *optional*, defaults to `chunk_length_ms // 6`):
+            The length of stride on the left and right of each chunk. Used only with `chunk_length_ms > 0`. The
+            stride lets the model *see* more context and infer letters better than it could in isolation, and the
+            pipeline discards the stride regions at the end so the final reconstitution is as seamless as possible.
         framework (`str`, *optional*):
             The framework to use, either `"pt"` for PyTorch or `"tf"` for TensorFlow. The specified framework must
             be installed.
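Reviewer note: to illustrate how the new options are meant to be used, here is a minimal sketch. The checkpoint mirrors the one used in the test below; the file path is an invented example, and only `chunk_length_ms` / `stride_length_ms` are the options actually added in this PR.

```python
from transformers import pipeline

# `chunk_length_ms` is only valid for CTC models, so a CTC checkpoint is assumed here.
asr = pipeline(
    task="automatic-speech-recognition",
    model="facebook/wav2vec2-base-960h",  # example CTC checkpoint
    chunk_length_ms=10_000,  # split the audio into 10 s chunks...
    stride_length_ms=2_000,  # ...each padded with 2 s of overlapping context per side
)

# Long recordings no longer need to fit through the model in a single pass.
print(asr("very_long_audio.flac")["text"])  # hypothetical input file
```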
@@ -133,9 +138,14 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
     def _sanitize_parameters(self, **kwargs):
-        # No parameters on this pipeline right now
-        return {}, {}, {}
+        preprocess_params = {}
+        if "chunk_length_ms" in kwargs:
+            preprocess_params["chunk_length_ms"] = kwargs["chunk_length_ms"]
+        if "stride_length_ms" in kwargs:
+            preprocess_params["stride_length_ms"] = kwargs["stride_length_ms"]
+        return preprocess_params, {}, {}

-    def preprocess(self, inputs):
+    def preprocess(self, inputs, chunk_length_ms=0, stride_length_ms=None):
         if isinstance(inputs, str):
             with open(inputs, "rb") as f:
                 inputs = f.read()
@@ -148,13 +158,48 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         if len(inputs.shape) != 1:
             raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")

-        processed = self.feature_extractor(
-            inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
-        )
-        return processed
+        if chunk_length_ms:
+            if stride_length_ms is None:
+                stride_length_ms = chunk_length_ms // 6
+            inputs_len = len(inputs)
+            chunk_len = chunk_length_ms * self.feature_extractor.sampling_rate // 1000
+            stride_len = stride_length_ms * self.feature_extractor.sampling_rate // 1000
+
+            if self.model.__class__ not in MODEL_FOR_CTC_MAPPING.values():
+                raise ValueError(
+                    "`chunk_length_ms` is only valid for CTC models, use other chunking options for other models"
+                )
+            if chunk_len < stride_len:
+                raise ValueError("Chunk length must be at least the stride length")
+
+            step = chunk_len
+            for i in range(0, inputs_len, step):
+                # Extend each chunk by up to `stride_len` samples of context on the left
+                # and right, clamped to the input boundaries.
+                start = 0 if i - stride_len < 0 else i - stride_len
+                stop = inputs_len if i + chunk_len + stride_len > inputs_len else i + chunk_len + stride_len
+                chunk = inputs[start:stop]
+                processed = self.feature_extractor(
+                    chunk, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
+                )
+                stride_left = i - start
+                stride_right = max(stop - (i + chunk_len), 0)
+                is_last = i + step >= inputs_len
+                yield {"is_last": is_last, "stride": (stop - start, stride_left, stride_right), **processed}
+        else:
+            processed = self.feature_extractor(
+                inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
+            )
+            yield {"is_last": True, **processed}
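Reviewer note: to make the windowing arithmetic above easier to review, here is a self-contained sketch of the same start/stop/stride computation. The helper name and tuple output are mine, not part of the pipeline.

```python
def chunk_windows(inputs_len: int, chunk_len: int, stride_len: int):
    """Mirror preprocess(): step through the audio in chunk_len hops, extending
    each chunk by up to stride_len samples of context on both sides."""
    for i in range(0, inputs_len, chunk_len):
        start = max(0, i - stride_len)                      # left context, clamped at 0
        stop = min(inputs_len, i + chunk_len + stride_len)  # right context, clamped at the end
        stride_left = i - start
        stride_right = max(stop - (i + chunk_len), 0)
        is_last = i + chunk_len >= inputs_len
        yield start, stop, stride_left, stride_right, is_last

# 25 s of 16 kHz audio, 10 s chunks, default stride of 10_000 // 6 ms:
sr = 16_000
for window in chunk_windows(25 * sr, 10 * sr, (10_000 // 6) * sr // 1000):
    print(window)
# (0, 186656, 0, 26656, False)
# (133344, 346656, 26656, 26656, False)
# (293344, 400000, 26656, 0, True)
```

The first and last windows only get context on one side, which is why `stride` has to be carried per chunk rather than assumed constant.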
     def _forward(self, model_inputs):
         model_class = self.model.__class__
+        is_last = model_inputs.pop("is_last")
         if model_class in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
             encoder = self.model.get_encoder()
             # we need to pass `processed.get("attention_mask")` here since audio encoder
@@ -164,17 +209,33 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
             tokens = self.model.generate(
                 encoder_outputs=encoder(**model_inputs), attention_mask=model_inputs.get("attention_mask")
             )
+            tokens = tokens.squeeze(0)
         elif model_class in MODEL_FOR_CTC_MAPPING.values():
+            stride = model_inputs.pop("stride", None)
             outputs = self.model(**model_inputs)
-            tokens = outputs.logits.squeeze(0).argmax(dim=-1)
+            tokens = outputs.logits.argmax(dim=-1)
+            if stride is not None:
+                if isinstance(stride, tuple):
+                    stride = [stride]
+                max_token_n = tokens.shape[-1]
+                max_input_n = max(input_n for input_n, _, _ in stride)
+                ratio = max_token_n / max_input_n
+                # Rescale the sample-level strides to token positions and mask the
+                # overlapping regions with pad tokens so they are not decoded twice.
+                for i, (input_n, left, right) in enumerate(stride):
+                    token_n = int(input_n * ratio) + 1
+                    left_token = int(left / input_n * token_n)
+                    right_token = int((input_n - right) / input_n * token_n) + 1
+                    tokens[i, :left_token] = self.tokenizer.pad_token_id
+                    tokens[i, right_token:] = self.tokenizer.pad_token_id
         else:
             logger.warning("This is an unknown class, treating it as CTC.")
             outputs = self.model(**model_inputs)
-            tokens = outputs.logits.squeeze(0).argmax(dim=-1)
+            tokens = outputs.logits.argmax(dim=-1)
-        return tokens
+        return {"tokens": tokens, "is_last": is_last}
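Reviewer note: the subtle step above is the stride bookkeeping. Strides are recorded in audio samples, but the masking happens on token positions, so they are rescaled by `max_token_n / max_input_n`. A worked example with illustrative numbers (Wav2Vec2's feature encoder emits roughly one logit per 320 input samples, i.e. 20 ms at 16 kHz; nothing here is taken verbatim from the PR):

```python
# One padded chunk: (stop - start, stride_left, stride_right), all in samples.
input_n, left, right = 213_312, 26_656, 26_656

max_token_n = 666              # tokens.shape[-1]; about 213_312 / 320 logit frames
ratio = max_token_n / input_n  # tokens per input sample

token_n = int(input_n * ratio) + 1                            # 667
left_token = int(left / input_n * token_n)                    # first kept token: 83
right_token = int((input_n - right) / input_n * token_n) + 1  # one past last kept: 584

# tokens[i, :left_token] and tokens[i, right_token:] are then overwritten with
# pad_token_id, so only the centre of each chunk contributes to the transcription.
print(left_token, right_token)  # 83 584
```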
     def postprocess(self, model_outputs):
         skip_special_tokens = False if "CTC" in self.tokenizer.__class__.__name__ else True
-        recognized_string = self.tokenizer.decode(model_outputs, skip_special_tokens=skip_special_tokens)
+        tokens = np.concatenate([outputs["tokens"].numpy() for outputs in model_outputs], axis=-1)
+        tokens = tokens.squeeze(0)
+        recognized_string = self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
         return {"text": recognized_string}
tests/test_pipelines_automatic_speech_recognition.py

@@ -232,3 +232,30 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         filename = ds[40]["file"]
         output = speech_recognizer(filename)
         self.assertEqual(output, {"text": "a man said to the universe sir i exist"})
+
+    @require_torch
+    @slow
+    def test_chunking(self):
+        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
+        tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
+        feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
+        speech_recognizer = pipeline(
+            task="automatic-speech-recognition",
+            model=model,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            framework="pt",
+            chunk_length_ms=10_000,
+        )
+
+        from datasets import load_dataset
+
+        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
+        audio = ds[40]["audio"]["array"]
+
+        n_repeats = 100
+        audio = np.tile(audio, n_repeats)
+        output = speech_recognizer([audio], batch_size=2)
+        expected_text = "A MAN SAID TO THE UNIVERSE SIR I EXIST " * n_repeats
+        expected = [{"text": expected_text.strip()}]
+        self.assertEqual(output, expected)