"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "cad1b1192b3ad98808d24f898e28fb56f78720d3"
Unverified commit 25c451e5, authored by Nicolas Patry, committed by GitHub

Adding chunking for whisper (all seq2seq actually). Very crude matching algorithm. (#20104)

* Very crude matching algorithm.

* Fixing tests.

* Removing comments

* Adding warning + fix short matches.

* Cleanup tests.

* Quality.

* Less noisy.

* Fixup.
parent 938cb047
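This commit lets `chunk_length_s` be used with seq2seq ASR models such as Whisper: the audio is cut into overlapping chunks and the per-chunk transcriptions are merged with a crude longest-overlap heuristic. A minimal usage sketch, mirroring the new `test_torch_whisper` test below (the audio file name is a placeholder):

```python
from transformers import pipeline

# Sketch of the feature this commit enables: chunked long-form transcription
# with a seq2seq model. Model name matches the new test; the file is a placeholder.
asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny", framework="pt")

# Split the audio into 5 s chunks with overlapping strides, run them batched,
# then merge the per-chunk token sequences back into a single transcription.
output = asr("long_recording.flac", chunk_length_s=5, batch_size=4)
print(output["text"])
```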
@@ -30,7 +30,7 @@ if is_torch_available():
    from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING

-def rescale_stride(tokens_or_logits, stride, ratio):
+def rescale_stride(stride, ratio):
    """
    Rescales the stride values from audio space to tokens/logits space.
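The signature change drops the unused `tokens_or_logits` argument; `rescale_stride` now just converts stride tuples measured in audio samples into the processed-feature (or logits) space by a multiplicative ratio. A standalone sketch of that arithmetic, not the library's exact implementation:

```python
# Minimal standalone sketch of the stride rescaling idea (not the library's exact code).
# A stride is (chunk_len, stride_left, stride_right) measured in audio samples;
# multiplying by `ratio` re-expresses it in feature frames (or logits steps).
def rescale_stride_sketch(strides, ratio):
    return [
        (int(round(chunk_len * ratio)), int(round(left * ratio)), int(round(right * ratio)))
        for chunk_len, left, right in strides
    ]

# e.g. a 10 s chunk at 16 kHz with 1 s of stride on each side, mapped to ~3000 mel frames:
print(rescale_stride_sketch([(160_000, 16_000, 16_000)], 3000 / 160_000))
# -> [(3000, 300, 300)]
```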
@@ -60,8 +60,43 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right):
        _stride_left = 0 if i == 0 else stride_left
        is_last = i + step + stride_left >= inputs_len
        _stride_right = 0 if is_last else stride_right

+        if "input_features" in processed:
+            processed_len = processed["input_features"].shape[-1]
+        elif "input_values" in processed:
+            processed_len = processed["input_values"].shape[-1]
+        chunk_len = chunk.shape[0]
+        stride = (chunk_len, _stride_left, _stride_right)
+        if processed_len != chunk.shape[-1]:
+            ratio = processed_len / chunk_len
+            stride = rescale_stride([stride], ratio)[0]
        if chunk.shape[0] > _stride_left:
-            yield {"is_last": is_last, "stride": (chunk.shape[0], _stride_left, _stride_right), **processed}
+            yield {"is_last": is_last, "stride": stride, **processed}
+def _find_longest_common_sequence(sequences, tokenizer):
+    # TODO Use a faster algorithm this can probably be done in O(n)
+    # using suffix array.
+    # It might be tedious to do because of fault tolerance.
+    # We actually have a really good property which is that the total sequence
+    # MUST be those subsequences in order.
+    # Also the algorithm should be more tolerant to errors.
+    sequence = [tok_id for tok_id in sequences[0][0].tolist() if tok_id not in tokenizer.all_special_ids]
+    for new_seq in sequences[1:]:
+        new_sequence = [tok_id for tok_id in new_seq[0].tolist() if tok_id not in tokenizer.all_special_ids]
+
+        index = 0
+        max_ = 0.0
+        for i in range(1, len(new_sequence) + 1):
+            # epsilon to favor long perfect matches
+            eps = i / 10000.0
+            matches = np.sum(np.array(sequence[-i:]) == np.array(new_sequence[:i]))
+            matching = matches / i + eps
+            if matches > 1 and matching > max_:
+                index = i
+                max_ = matching
+        sequence.extend(new_sequence[index:])
+    return np.array(sequence)
class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
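`_find_longest_common_sequence` is the "very crude matching algorithm" from the commit title: for each new chunk it finds the overlap length between the tail of the running token sequence and the head of the new one that maximizes the fraction of matching tokens (plus a small epsilon favoring longer perfect matches), then appends only the non-overlapping remainder. A toy version on plain Python lists with made-up token ids (the real function also strips `tokenizer.all_special_ids` and works on tensors):

```python
import numpy as np

# Toy illustration of the overlap-merging idea with made-up token ids.
def merge_overlapping(chunks):
    sequence = list(chunks[0])
    for new_sequence in chunks[1:]:
        index, max_ = 0, 0.0
        for i in range(1, len(new_sequence) + 1):
            eps = i / 10000.0  # favor longer perfect matches
            matches = np.sum(np.array(sequence[-i:]) == np.array(new_sequence[:i]))
            matching = matches / i + eps
            if matches > 1 and matching > max_:
                index, max_ = i, matching
        sequence.extend(new_sequence[index:])
    return sequence

# Two chunks whose transcriptions overlap on the tokens [40, 50, 60]:
print(merge_overlapping([[10, 20, 30, 40, 50, 60], [40, 50, 60, 70, 80]]))
# -> [10, 20, 30, 40, 50, 60, 70, 80]
```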
@@ -188,6 +223,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
            preprocess_params["chunk_length_s"] = kwargs["chunk_length_s"]
        if "stride_length_s" in kwargs:
            preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
+        if "ignore_warning" in kwargs:
+            preprocess_params["ignore_warning"] = kwargs["ignore_warning"]

        postprocess_params = {}
        if "decoder_kwargs" in kwargs:
@@ -197,7 +234,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
        return preprocess_params, {}, postprocess_params

-    def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
+    def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False):
        if isinstance(inputs, str):
            with open(inputs, "rb") as f:
                inputs = f.read()
@@ -249,10 +286,14 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
            raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")

        if chunk_length_s:
-            if self.type not in {"ctc", "ctc_with_lm"}:
-                raise ValueError(
-                    "`chunk_length_s` is only valid for CTC models, use other chunking options for other models"
+            if self.type == "seq2seq" and not ignore_warning:
+                logger.warning(
+                    "Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily"
+                    " be entirely accurate and will have caveats. More information:"
+                    " https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
+                    " ignore_warning=True)"
                )
+                self._preprocess_params["ignore_warning"] = True

            if stride_length_s is None:
                stride_length_s = chunk_length_s / 6
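With this change, requesting `chunk_length_s` on a seq2seq model logs an experimental-feature warning instead of raising a `ValueError`; the new `ignore_warning` kwarg silences it. A usage sketch following the wording of the warning itself (model name taken from the new test, audio path is a placeholder):

```python
from transformers import pipeline

# Usage sketch: silence the experimental-chunking warning added above
# by passing ignore_warning=True, as the warning message itself suggests.
asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",
    chunk_length_s=30,
    ignore_warning=True,  # acknowledge that seq2seq chunking is experimental
)
# output = asr("long_recording.flac")  # placeholder audio path
```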
@@ -262,7 +303,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
            # XXX: Carefuly, this variable will not exist in `seq2seq` setting.
            # Currently chunking is not possible at this level for `seq2seq` so
            # it's ok.
-            align_to = self.model.config.inputs_to_logits_ratio
+            align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
            chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
            stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
            stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)
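To make the rounding above concrete: strides default to `chunk_length_s / 6` per side, and chunk and stride lengths are snapped to multiples of `align_to` so they stay aligned with the model's logits. A worked example assuming a Wav2Vec2-style CTC model (16 kHz audio, `inputs_to_logits_ratio` of 320); Whisper-style configs lack that attribute, so `align_to` falls back to 1:

```python
# Worked example of the rounding above (hypothetical CTC model values).
sampling_rate = 16_000
inputs_to_logits_ratio = 320          # e.g. a Wav2Vec2-style encoder stride
chunk_length_s = 10
stride_length_s = chunk_length_s / 6  # default: ~1.67 s on each side

align_to = inputs_to_logits_ratio
chunk_len = int(round(chunk_length_s * sampling_rate / align_to) * align_to)
stride = int(round(stride_length_s * sampling_rate / align_to) * align_to)
print(chunk_len, stride)  # 160000 26560 (samples, both multiples of 320)
```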
@@ -329,9 +370,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                # the pieces are to be concatenated.
                ratio = 1 / self.model.config.inputs_to_logits_ratio
                if isinstance(stride, tuple):
-                    out["stride"] = rescale_stride(logits, [stride], ratio)[0]
+                    out["stride"] = rescale_stride([stride], ratio)[0]
                else:
-                    out["stride"] = rescale_stride(logits, stride, ratio)
+                    out["stride"] = rescale_stride(stride, ratio)
        # Leftover
        extra = model_inputs
        return {"is_last": is_last, **out, **extra}
@@ -347,10 +388,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
        final_items = []
        key = "logits" if self.type == "ctc_with_lm" else "tokens"
+        stride = None
        for outputs in model_outputs:
            items = outputs[key].numpy()
            stride = outputs.pop("stride", None)
-            if stride is not None:
+            if stride is not None and self.type in {"ctc", "ctc_with_lm"}:
                total_n, left, right = stride
                # Total_n might be < logits.shape[1]
                # because of padding, that's why
@@ -359,8 +401,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                right_n = total_n - right
                items = items[:, left:right_n]
            final_items.append(items)
-        items = np.concatenate(final_items, axis=1)
-        items = items.squeeze(0)
+        if stride and self.type == "seq2seq":
+            items = _find_longest_common_sequence(final_items, self.tokenizer)
+        else:
+            items = np.concatenate(final_items, axis=1)
+            items = items.squeeze(0)
        if self.type == "ctc_with_lm":
            if decoder_kwargs is None:
                decoder_kwargs = {}
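In the CTC branch of `postprocess`, each chunk's output is trimmed by its (already rescaled) left/right strides before concatenation so the overlapping borders are counted only once, while seq2seq outputs are merged through `_find_longest_common_sequence` instead. A small numeric sketch of the trimming, with made-up shapes:

```python
import numpy as np

# Hypothetical per-chunk CTC output: batch of 1, 100 frames.
items = np.zeros((1, 100))
total_n, left, right = 100, 10, 15   # stride already rescaled to logits space
right_n = total_n - right
trimmed = items[:, left:right_n]     # drop the overlapping borders
print(trimmed.shape)                 # (1, 75)
```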
@@ -144,12 +144,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
        waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
        output = speech_recognizer(waveform)
        self.assertEqual(output, {"text": "(Applaudissements)"})
-        with self.assertRaises(ValueError) as v:
-            _ = speech_recognizer(waveform, chunk_length_s=10)
-        self.assertEqual(
-            str(v.exception),
-            "`chunk_length_s` is only valid for CTC models, use other chunking options for other models",
-        )
+        output = speech_recognizer(waveform, chunk_length_s=10)
+        self.assertEqual(output, {"text": "(Applaudissements)"})

        # Non CTC models cannot use return_timestamps
        with self.assertRaises(ValueError) as v:
@@ -261,6 +257,22 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
        output = speech_recognizer(filename)
        self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})

+    @require_torch
+    @slow
+    def test_torch_whisper(self):
+        speech_recognizer = pipeline(
+            task="automatic-speech-recognition",
+            model="openai/whisper-tiny",
+            framework="pt",
+        )
+        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
+        filename = ds[40]["file"]
+        output = speech_recognizer(filename)
+        self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."})
+
+        output = speech_recognizer([filename], chunk_length_s=5, batch_size=4)
+        self.assertEqual(output, [{"text": " A man said to the universe, Sir, I exist."}])
+
    @require_torch
    @slow
    def test_torch_speech_encoder_decoder(self):