Unverified commit 57699496, authored by Yih-Dar, committed by GitHub
Browse files

Fix some tests using `"common_voice"` (#27147)



* Use mozilla-foundation/common_voice_11_0

* Update expected values

* Update expected values

* For test_word_time_stamp_integration

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 691fd8fd
...@@ -97,7 +97,7 @@ def _test_wav2vec2_with_lm_invalid_pool(in_queue, out_queue, timeout): ...@@ -97,7 +97,7 @@ def _test_wav2vec2_with_lm_invalid_pool(in_queue, out_queue, timeout):
try: try:
_ = in_queue.get(timeout=timeout) _ = in_queue.get(timeout=timeout)
ds = load_dataset("common_voice", "es", split="test", streaming=True) ds = load_dataset("mozilla-foundation/common_voice_11_0", "es", split="test", streaming=True)
sample = next(iter(ds)) sample = next(iter(ds))
resampled_audio = torchaudio.functional.resample( resampled_audio = torchaudio.functional.resample(
...@@ -119,7 +119,7 @@ def _test_wav2vec2_with_lm_invalid_pool(in_queue, out_queue, timeout): ...@@ -119,7 +119,7 @@ def _test_wav2vec2_with_lm_invalid_pool(in_queue, out_queue, timeout):
transcription = processor.batch_decode(logits.cpu().numpy(), pool).text transcription = processor.batch_decode(logits.cpu().numpy(), pool).text
unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out) unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out)
unittest.TestCase().assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") unittest.TestCase().assertEqual(transcription[0], "habitan aguas poco profundas y rocosas")
# force batch_decode to internally create a spawn pool, which should trigger a warning if different than fork # force batch_decode to internally create a spawn pool, which should trigger a warning if different than fork
multiprocessing.set_start_method("spawn", force=True) multiprocessing.set_start_method("spawn", force=True)
...@@ -127,7 +127,7 @@ def _test_wav2vec2_with_lm_invalid_pool(in_queue, out_queue, timeout): ...@@ -127,7 +127,7 @@ def _test_wav2vec2_with_lm_invalid_pool(in_queue, out_queue, timeout):
transcription = processor.batch_decode(logits.cpu().numpy()).text transcription = processor.batch_decode(logits.cpu().numpy()).text
unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out) unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out)
unittest.TestCase().assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") unittest.TestCase().assertEqual(transcription[0], "habitan aguas poco profundas y rocosas")
except Exception: except Exception:
error = f"{traceback.format_exc()}" error = f"{traceback.format_exc()}"
...@@ -1833,7 +1833,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase): ...@@ -1833,7 +1833,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
@require_pyctcdecode @require_pyctcdecode
@require_torchaudio @require_torchaudio
def test_wav2vec2_with_lm(self): def test_wav2vec2_with_lm(self):
ds = load_dataset("common_voice", "es", split="test", streaming=True) ds = load_dataset("mozilla-foundation/common_voice_11_0", "es", split="test", streaming=True)
sample = next(iter(ds)) sample = next(iter(ds))
resampled_audio = torchaudio.functional.resample( resampled_audio = torchaudio.functional.resample(
...@@ -1852,12 +1852,12 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase): ...@@ -1852,12 +1852,12 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
transcription = processor.batch_decode(logits.cpu().numpy()).text transcription = processor.batch_decode(logits.cpu().numpy()).text
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") self.assertEqual(transcription[0], "habitan aguas poco profundas y rocosas")
@require_pyctcdecode @require_pyctcdecode
@require_torchaudio @require_torchaudio
def test_wav2vec2_with_lm_pool(self): def test_wav2vec2_with_lm_pool(self):
ds = load_dataset("common_voice", "es", split="test", streaming=True) ds = load_dataset("mozilla-foundation/common_voice_11_0", "es", split="test", streaming=True)
sample = next(iter(ds)) sample = next(iter(ds))
resampled_audio = torchaudio.functional.resample( resampled_audio = torchaudio.functional.resample(
...@@ -1878,7 +1878,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase): ...@@ -1878,7 +1878,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
with multiprocessing.get_context("fork").Pool(2) as pool: with multiprocessing.get_context("fork").Pool(2) as pool:
transcription = processor.batch_decode(logits.cpu().numpy(), pool).text transcription = processor.batch_decode(logits.cpu().numpy(), pool).text
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") self.assertEqual(transcription[0], "habitan aguas poco profundas y rocosas")
# user-managed pool + num_processes should trigger a warning # user-managed pool + num_processes should trigger a warning
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool( with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
...@@ -1889,7 +1889,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase): ...@@ -1889,7 +1889,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
self.assertIn("num_process", cl.out) self.assertIn("num_process", cl.out)
self.assertIn("it will be ignored", cl.out) self.assertIn("it will be ignored", cl.out)
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") self.assertEqual(transcription[0], "habitan aguas poco profundas y rocosas")
@require_pyctcdecode @require_pyctcdecode
@require_torchaudio @require_torchaudio
...@@ -1957,7 +1957,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase): ...@@ -1957,7 +1957,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
LANG_MAP = {"it": "ita", "es": "spa", "fr": "fra", "en": "eng"} LANG_MAP = {"it": "ita", "es": "spa", "fr": "fra", "en": "eng"}
def run_model(lang): def run_model(lang):
ds = load_dataset("common_voice", lang, split="test", streaming=True) ds = load_dataset("mozilla-foundation/common_voice_11_0", lang, split="test", streaming=True)
sample = next(iter(ds)) sample = next(iter(ds))
wav2vec2_lang = LANG_MAP[lang] wav2vec2_lang = LANG_MAP[lang]
...@@ -1982,10 +1982,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase): ...@@ -1982,10 +1982,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
return transcription return transcription
TRANSCRIPTIONS = { TRANSCRIPTIONS = {
"it": "mi hanno fatto un'offerta che non potevo proprio rifiutare", "it": "il libro ha suscitato molte polemiche a causa dei suoi contenuti",
"es": "bien y qué regalo vas a abrir primero", "es": "habitan aguas poco profundas y rocosas",
"fr": "un vrai travail intéressant va enfin être mené sur ce sujet", "fr": "ce dernier est volé tout au long de l'histoire romaine",
"en": "twas the time of day and olof spen slept during the summer", "en": "joe keton disapproved of films and buster also had reservations about the media",
} }
for lang in LANG_MAP.keys(): for lang in LANG_MAP.keys():
......
...@@ -434,7 +434,7 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): ...@@ -434,7 +434,7 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
def test_word_time_stamp_integration(self): def test_word_time_stamp_integration(self):
import torch import torch
ds = load_dataset("common_voice", "en", split="train", streaming=True) ds = load_dataset("mozilla-foundation/common_voice_11_0", "en", split="train", streaming=True)
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000)) ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
ds_iter = iter(ds) ds_iter = iter(ds)
sample = next(ds_iter) sample = next(ds_iter)
...@@ -442,7 +442,6 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): ...@@ -442,7 +442,6 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
processor = AutoProcessor.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm") processor = AutoProcessor.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm") model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
# compare to filename `common_voice_en_100038.mp3` of dataset viewer on https://huggingface.co/datasets/common_voice/viewer/en/train
input_values = processor(sample["audio"]["array"], return_tensors="pt").input_values input_values = processor(sample["audio"]["array"], return_tensors="pt").input_values
with torch.no_grad(): with torch.no_grad():
...@@ -461,6 +460,7 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): ...@@ -461,6 +460,7 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
] ]
EXPECTED_TEXT = "WHY DOES MILISANDRA LOOK LIKE SHE WANTS TO CONSUME JOHN SNOW ON THE RIVER AT THE WALL" EXPECTED_TEXT = "WHY DOES MILISANDRA LOOK LIKE SHE WANTS TO CONSUME JOHN SNOW ON THE RIVER AT THE WALL"
EXPECTED_TEXT = "THE TRACK APPEARS ON THE COMPILATION ALBUM CRAFT FORKS"
# output words # output words
self.assertEqual(" ".join(self.get_from_offsets(word_time_stamps, "word")), EXPECTED_TEXT) self.assertEqual(" ".join(self.get_from_offsets(word_time_stamps, "word")), EXPECTED_TEXT)
...@@ -471,8 +471,8 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): ...@@ -471,8 +471,8 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
end_times = torch.tensor(self.get_from_offsets(word_time_stamps, "end_time")) end_times = torch.tensor(self.get_from_offsets(word_time_stamps, "end_time"))
# fmt: off # fmt: off
expected_start_tensor = torch.tensor([1.4199, 1.6599, 2.2599, 3.0, 3.24, 3.5999, 3.7999, 4.0999, 4.26, 4.94, 5.28, 5.6599, 5.78, 5.94, 6.32, 6.5399, 6.6599]) expected_start_tensor = torch.tensor([0.6800, 0.8800, 1.1800, 1.8600, 1.9600, 2.1000, 3.0000, 3.5600, 3.9800])
expected_end_tensor = torch.tensor([1.5399, 1.8999, 2.9, 3.16, 3.5399, 3.72, 4.0199, 4.1799, 4.76, 5.1599, 5.5599, 5.6999, 5.86, 6.1999, 6.38, 6.6199, 6.94]) expected_end_tensor = torch.tensor([0.7800, 1.1000, 1.6600, 1.9200, 2.0400, 2.8000, 3.3000, 3.8800, 4.2800])
# fmt: on # fmt: on
self.assertTrue(torch.allclose(start_times, expected_start_tensor, atol=0.01)) self.assertTrue(torch.allclose(start_times, expected_start_tensor, atol=0.01))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment