Unverified Commit 279008ad authored by Connor Henderson's avatar Connor Henderson Committed by GitHub
Browse files

fix: Change is_last chunk calc and add conditional break in chunk_iter (#21612)

* fix: Change is_last chunk calc and add conditional break

* format fix

* account for 0 and full stride_rights, add comment

* add new test

* make style

* update slow whisper asr test timestamps

* use nested_simplify on output and round timestamp to hundredths place
parent 4446b6b0
...@@ -56,14 +56,15 @@ def rescale_stride(stride, ratio): ...@@ -56,14 +56,15 @@ def rescale_stride(stride, ratio):
def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, rescale=True, dtype=None): def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, rescale=True, dtype=None):
inputs_len = inputs.shape[0] inputs_len = inputs.shape[0]
step = chunk_len - stride_left - stride_right step = chunk_len - stride_left - stride_right
for i in range(0, inputs_len, step): for chunk_start_idx in range(0, inputs_len, step):
# add start and end paddings to the chunk chunk_end_idx = chunk_start_idx + chunk_len
chunk = inputs[i : i + chunk_len] chunk = inputs[chunk_start_idx:chunk_end_idx]
processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt") processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
if dtype is not None: if dtype is not None:
processed = processed.to(dtype=dtype) processed = processed.to(dtype=dtype)
_stride_left = 0 if i == 0 else stride_left _stride_left = 0 if chunk_start_idx == 0 else stride_left
is_last = i + step + stride_left >= inputs_len # all right strides must be full, otherwise it is the last item
is_last = chunk_end_idx > inputs_len if stride_right > 0 else chunk_end_idx >= inputs_len
_stride_right = 0 if is_last else stride_right _stride_right = 0 if is_last else stride_right
chunk_len = chunk.shape[0] chunk_len = chunk.shape[0]
...@@ -77,6 +78,8 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, ...@@ -77,6 +78,8 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
stride = rescale_stride([stride], ratio)[0] stride = rescale_stride([stride], ratio)[0]
if chunk.shape[0] > _stride_left: if chunk.shape[0] > _stride_left:
yield {"is_last": is_last, "stride": stride, **processed} yield {"is_last": is_last, "stride": stride, **processed}
if is_last:
break
def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions): def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions):
......
...@@ -526,7 +526,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel ...@@ -526,7 +526,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
output = pipe(array, chunk_length_s=10) output = pipe(array, chunk_length_s=10)
self.assertDictEqual( self.assertDictEqual(
output, nested_simplify(output),
{ {
"chunks": [ "chunks": [
{"text": " A man said to the universe, Sir, I exist.", "timestamp": (0.0, 5.5)}, {"text": " A man said to the universe, Sir, I exist.", "timestamp": (0.0, 5.5)},
...@@ -548,11 +548,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel ...@@ -548,11 +548,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
}, },
{ {
"text": " the thousands of spectators, retrievality is not worth thinking about.", "text": " the thousands of spectators, retrievality is not worth thinking about.",
"timestamp": (19.6, 24.98), "timestamp": (19.6, 26.66),
}, },
{ {
"text": " His instant panic was followed by a small, sharp blow high on his chest.", "text": " His instant panic was followed by a small, sharp blow high on his chest.",
"timestamp": (24.98, 30.98), "timestamp": (26.66, 31.06),
}, },
], ],
"text": ( "text": (
...@@ -1110,6 +1110,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel ...@@ -1110,6 +1110,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
self.assertEqual([o["stride"] for o in outs], [(90, 0, 0), (30, 20, 0)]) self.assertEqual([o["stride"] for o in outs], [(90, 0, 0), (30, 20, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 90), (1, 30)]) self.assertEqual([o["input_values"].shape for o in outs], [(1, 90), (1, 30)])
outs = list(chunk_iter(inputs, feature_extractor, 36, 6, 6, ratio))
self.assertEqual(len(outs), 4)
self.assertEqual([o["stride"] for o in outs], [(36, 0, 6), (36, 6, 6), (36, 6, 6), (28, 6, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 36), (1, 36), (1, 36), (1, 28)])
inputs = torch.LongTensor([i % 2 for i in range(100)]) inputs = torch.LongTensor([i % 2 for i in range(100)])
input_values = feature_extractor(inputs, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")[ input_values = feature_extractor(inputs, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")[
"input_values" "input_values"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment