"examples/vscode:/vscode.git/clone" did not exist on "18643ff29a946c4d21b67d288e6da98bb0c1b169"
Unverified commit 7f5d644e, authored by Sanchit Gandhi and committed by GitHub
Browse files

[pipeline] fix padding for 1-d tensors (#31776)



* [pipeline] fix padding for 1-d tensors

* add test

* make style

* Update tests/pipelines/test_pipelines_automatic_speech_recognition.py
Co-authored-by: Kamil Akesbi <45195979+kamilakesbi@users.noreply.github.com>

* Update tests/pipelines/test_pipelines_automatic_speech_recognition.py

---------
Co-authored-by: Kamil Akesbi <45195979+kamilakesbi@users.noreply.github.com>
parent 3fbaaaa6
...@@ -90,6 +90,9 @@ def _pad(items, key, padding_value, padding_side): ...@@ -90,6 +90,9 @@ def _pad(items, key, padding_value, padding_side):
# Others include `attention_mask` etc... # Others include `attention_mask` etc...
shape = items[0][key].shape shape = items[0][key].shape
dim = len(shape) dim = len(shape)
if dim == 1:
# We have a list of 1-dim torch tensors, which can be concatenated without padding
return torch.cat([item[key] for item in items], dim=0)
if key in ["pixel_values", "image"]: if key in ["pixel_values", "image"]:
# This is probable image so padding shouldn't be necessary # This is probable image so padding shouldn't be necessary
# B, C, H, W # B, C, H, W
......
...@@ -549,6 +549,23 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): ...@@ -549,6 +549,23 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
output = speech_recognizer([filename], chunk_length_s=5, batch_size=4) output = speech_recognizer([filename], chunk_length_s=5, batch_size=4)
self.assertEqual(output, [{"text": " A man said to the universe, Sir, I exist."}]) self.assertEqual(output, [{"text": " A man said to the universe, Sir, I exist."}])
@require_torch
@slow
def test_torch_whisper_batched(self):
    """Check that whisper-tiny transcribes two audio samples submitted as one batch."""
    asr = pipeline(task="automatic-speech-recognition", model="openai/whisper-tiny", framework="pt")

    # Only the first two validation samples are needed to fill a batch of size 2.
    samples = load_dataset(
        "hf-internal-testing/librispeech_asr_dummy",
        "clean",
        split="validation[:2]",
    )

    # Reference transcriptions, in the same order as the dataset samples.
    reference_texts = (
        " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
        " Nor is Mr. Quilters' manner less interesting than his matter.",
    )
    expected = [{"text": text} for text in reference_texts]

    transcriptions = asr(samples["audio"], batch_size=2)
    self.assertEqual(transcriptions, expected)
@slow @slow
def test_find_longest_common_subsequence(self): def test_find_longest_common_subsequence(self):
max_source_positions = 1500 max_source_positions = 1500
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment