Unverified Commit b80b2218 authored by Arthur, committed by GitHub

[ci-daily] Fix pipeline tests (#21257)

* use streaming dataset

* fix whisper's test

* add rescale argument to chunk_iter
parent 275ad9d8
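
The `chunk_iter` change below replaces the precomputed `ratio` argument with a `rescale` flag: the generator now derives the ratio itself from the feature extractor's output length (`input_features` or `input_values`) and skips rescaling entirely when the caller passes `rescale=False`, which the pipeline does for `seq2seq_whisper` since Whisper's extractor pads every chunk to the same number of mel frames. A minimal, self-contained sketch of that stride arithmetic (the helper names and example numbers are illustrative, not the pipeline's exact code):

import numpy as np


def rescale_stride(strides, ratio):
    # Map (chunk_len, stride_left, stride_right) tuples from the raw-audio
    # sample space into the model's output space (tokens, logits, frames).
    return [
        (int(round(n * ratio)), int(round(left * ratio)), int(round(right * ratio)))
        for n, left, right in strides
    ]


def stride_for_chunk(chunk, processed_len, stride_left, stride_right, rescale=True):
    # Sketch of the new logic in `chunk_iter`: if the feature extractor changed
    # the sequence length, rescale the stride by processed_len / chunk_len,
    # unless rescaling was disabled (as it is for seq2seq Whisper).
    chunk_len = chunk.shape[0]
    stride = (chunk_len, stride_left, stride_right)
    if processed_len != chunk_len and rescale:
        ratio = processed_len / chunk_len
        stride = rescale_stride([stride], ratio)[0]
    return stride


# 30 s of 16 kHz audio whose CTC logits are ~320x shorter than the waveform:
print(stride_for_chunk(np.zeros(480_000), 1500, 80_000, 80_000))
# -> (1500, 250, 250)

# Whisper's extractor always returns the same number of mel frames, so the
# pipeline now passes rescale=False and keeps the stride in sample space:
print(stride_for_chunk(np.zeros(480_000), 3000, 80_000, 80_000, rescale=False))
# -> (480000, 80000, 80000)

Keeping Whisper's stride in sample space matters because `_find_timestamp_sequence` below divides strides by `feature_extractor.sampling_rate` when merging timestamps, so it expects audio samples rather than mel frames.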
@@ -56,7 +56,7 @@ def rescale_stride(stride, ratio):
     return new_strides
 
 
-def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, ratio, dtype=None):
+def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, rescale=True, dtype=None):
     inputs_len = inputs.shape[0]
     step = chunk_len - stride_left - stride_right
     for i in range(0, inputs_len, step):
@@ -68,9 +68,15 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
         _stride_left = 0 if i == 0 else stride_left
         is_last = i + step + stride_left >= inputs_len
         _stride_right = 0 if is_last else stride_right
         chunk_len = chunk.shape[0]
         stride = (chunk_len, _stride_left, _stride_right)
-        if ratio != 1:
+        if "input_features" in processed:
+            processed_len = processed["input_features"].shape[-1]
+        elif "input_values" in processed:
+            processed_len = processed["input_values"].shape[-1]
+        if processed_len != chunk.shape[-1] and rescale:
+            ratio = processed_len / chunk_len
             stride = rescale_stride([stride], ratio)[0]
         if chunk.shape[0] > _stride_left:
             yield {"is_last": is_last, "stride": stride, **processed}
@@ -101,10 +107,10 @@ def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source
         sequence = sequence[begin_idx:]
 
         timestamp_tokens = sequence >= timestamp_begin
-        consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
-        last_timestamp = np.where(timestamp_tokens)[0][-1]
-        consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
-        if seq_idx != 0:
+        if seq_idx != 0 and sum(timestamp_tokens) > 0:
+            consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
+            last_timestamp = np.where(timestamp_tokens)[0][-1]
+            consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
             time -= stride_left + stride_right
             offset = int((time / feature_extractor.sampling_rate) / time_precision)
             overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
@@ -400,13 +406,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                     " only 1 version"
                 )
             forward_params["generate_kwargs"].update(generate_kwargs)
-        if return_timestamps is not None:
-            forward_params["generate_kwargs"]["return_timestamps"] = return_timestamps
 
         postprocess_params = {}
         if decoder_kwargs is not None:
             postprocess_params["decoder_kwargs"] = decoder_kwargs
         if return_timestamps is not None:
+            forward_params["return_timestamps"] = return_timestamps
             postprocess_params["return_timestamps"] = return_timestamps
         if self.model.config.model_type == "whisper":
             # Whisper is highly specific, if we want timestamps, we need to
@@ -502,9 +507,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             if chunk_len < stride_left + stride_right:
                 raise ValueError("Chunk length must be superior to stride length")
 
+            rescale = self.type != "seq2seq_whisper"
             # make sure that
             for item in chunk_iter(
-                inputs, self.feature_extractor, chunk_len, stride_left, stride_right, align_to, self.torch_dtype
+                inputs, self.feature_extractor, chunk_len, stride_left, stride_right, rescale, self.torch_dtype
             ):
                 yield item
         else:
@@ -520,12 +526,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                 processed["stride"] = stride
             yield {"is_last": True, **processed, **extra}
 
-    def _forward(self, model_inputs, generate_kwargs=None):
+    def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
         if generate_kwargs is None:
             generate_kwargs = {}
         is_last = model_inputs.pop("is_last")
-        return_timestamps = generate_kwargs.pop("return_timestamps", False)
 
         if self.type == "seq2seq":
             encoder = self.model.get_encoder()
@@ -635,9 +640,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                     # Simply cast from pyctcdecode format to wav2vec2 format to leverage
                     # pre-existing code later
                     chunk_offset = beams[0][2]
-                    word_offsets = []
+                    offsets = []
                     for word, (start_offset, end_offset) in chunk_offset:
-                        word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
+                        offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
             else:
                 skip_special_tokens = self.type != "ctc"
                 text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
......
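
Taken together, the pipeline-side edits route `return_timestamps` to `_forward` through `forward_params` (and to `postprocess` through `postprocess_params`) instead of stashing it inside `generate_kwargs`. From the caller's point of view nothing changes; a usage sketch, where the checkpoint name and audio file are placeholder assumptions:

from transformers import pipeline

# Placeholder checkpoint and audio path, for illustration only.
asr = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-small",
    chunk_length_s=30,
)

# `return_timestamps` is routed to `_forward` via forward_params and to
# `postprocess` via postprocess_params by `_sanitize_parameters`.
result = asr("sample.flac", return_timestamps=True)
print(result["text"])
print(result["chunks"])  # e.g. [{"text": "...", "timestamp": (0.0, 5.44)}, ...]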
@@ -201,8 +201,9 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
     @require_torch
     @require_pyctcdecode
     def test_large_model_pt_with_lm(self):
-        dataset = load_dataset("Narsil/asr_dummy")
-        filename = dataset["test"][3]["file"]
+        dataset = load_dataset("Narsil/asr_dummy", streaming=True)
+        third_item = next(iter(dataset["test"].skip(3)))
+        filename = third_item["file"]
         speech_recognizer = pipeline(
             task="automatic-speech-recognition",
......
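
The test change swaps the full download of `Narsil/asr_dummy` for a streaming load. With `streaming=True`, `load_dataset` returns an `IterableDataset`, which has no integer indexing, so `dataset["test"][3]` becomes `skip(3)` followed by pulling the next item. The same pattern in isolation (a sketch; only the dataset id and column name come from the test above):

from datasets import load_dataset

# Streaming avoids materializing the whole dataset on disk; an IterableDataset
# has no __getitem__, so the element at index 3 is reached by skipping 3 items.
dataset = load_dataset("Narsil/asr_dummy", streaming=True)
third_item = next(iter(dataset["test"].skip(3)))
filename = third_item["file"]
print(filename)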