Unverified commit a192f61e, authored by Nicolas Patry, committed by GitHub
Browse files

Change the chunk_iter function to handle (#16730)

* Change the chunk_iter function to handle the subtle cases where the
last chunk gets ignored because all of its data is already contained in
the `left_strided` data of that chunk.

In that case the previous chunk is effectively the last one, so we need to remove the right striding on that previous chunk.

* Remove commented line.
parent cc034f72
...@@ -58,9 +58,8 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right): ...@@ -58,9 +58,8 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right):
chunk = inputs[i : i + chunk_len] chunk = inputs[i : i + chunk_len]
processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt") processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
_stride_left = 0 if i == 0 else stride_left _stride_left = 0 if i == 0 else stride_left
is_last = i + step >= inputs_len is_last = i + step + stride_left >= inputs_len
_stride_right = 0 if is_last else stride_right _stride_right = 0 if is_last else stride_right
if chunk.shape[0] > _stride_left: if chunk.shape[0] > _stride_left:
yield {"is_last": is_last, "stride": (chunk.shape[0], _stride_left, _stride_right), **processed} yield {"is_last": is_last, "stride": (chunk.shape[0], _stride_left, _stride_right), **processed}
......
...@@ -653,6 +653,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel ...@@ -653,6 +653,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
self.assertEqual([o["input_values"].shape for o in outs], [(1, 80), (1, 20)]) self.assertEqual([o["input_values"].shape for o in outs], [(1, 80), (1, 20)])
self.assertEqual([o["is_last"] for o in outs], [False, True]) self.assertEqual([o["is_last"] for o in outs], [False, True])
# Only one chunk is produced, since the first chunk is also the last one:
# its tail would normally be the right-strided part, but because no further
# chunk follows, we mark that part as non-strided.
# This test is specifically crafted to trigger a bug in which the next chunk
# would be ignored because all of its data would be contained in the
# left-strided data.
outs = list(chunk_iter(inputs, feature_extractor, 105, 5, 5))
self.assertEqual(len(outs), 1)
self.assertEqual([o["stride"] for o in outs], [(100, 0, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 100)])
self.assertEqual([o["is_last"] for o in outs], [True])
@require_torch @require_torch
def test_chunk_iterator_stride(self): def test_chunk_iterator_stride(self):
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h") feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment