Unverified Commit ea118ae2 authored by Antonio Carlos Falcão Petri, committed by GitHub

Fix bug in Wav2Vec2's GPU tests (#19803)

* Fix tests when running on GPU

* Fix tests that require mp.set_start_method
parent f1e42bc5
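
Note on the fix, for context: on GPU, calling .numpy() on a CUDA tensor raises an error, so the PyTorch assertions now convert with logits.cpu().numpy() (the Flax tests use np.array(logits)). Separately, multiprocessing.set_start_method("spawn") may only be set once per process unless forced, so each test_wav2vec2_with_lm_invalid_pool now delegates to run_test_in_subprocess from transformers.testing_utils, with the test body moved into a module-level _test_wav2vec2_with_lm_invalid_pool(in_queue, out_queue, timeout) helper. The snippet that follows is a minimal sketch of how such a driver can work, not the actual run_test_in_subprocess implementation; the _sketch name and the queue details are illustrative assumptions that simply mirror the child-side protocol visible in this diff (block on in_queue.get, report an {"error": ...} dict, finish with out_queue.join).

# Minimal sketch (an assumption, not the transformers implementation) of a
# run-in-subprocess driver compatible with the _test_*_invalid_pool helpers added below.
import multiprocessing


def run_test_in_subprocess_sketch(test_case, target_func, inputs=None, timeout=600):
    # Spawn a fresh interpreter so set_start_method(..., force=True) inside the
    # child cannot affect the parent test process.
    ctx = multiprocessing.get_context("spawn")
    in_queue = ctx.Queue()
    out_queue = ctx.JoinableQueue()

    process = ctx.Process(target=target_func, args=(in_queue, out_queue, timeout))
    process.start()
    in_queue.put(inputs, timeout=timeout)     # unblocks the child's in_queue.get()
    results = out_queue.get(timeout=timeout)  # child sends {"error": None or a traceback string}
    out_queue.task_done()                     # lets the child's out_queue.join() return
    process.join(timeout=timeout)

    if results["error"] is not None:
        test_case.fail(f"Subprocess failed with:\n{results['error']}")

Running the test body in a freshly spawned child keeps the forced start method (and any CUDA state) from leaking into the rest of the test session.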
@@ -15,6 +15,8 @@
 import inspect
 import math
 import multiprocessing
+import os
+import traceback
 import unittest

 import numpy as np
@@ -31,6 +33,7 @@ from transformers.testing_utils import (
     require_librosa,
     require_pyctcdecode,
     require_soundfile,
+    run_test_in_subprocess,
     slow,
 )
@@ -54,6 +57,7 @@ if is_flax_available():
 if is_pyctcdecode_available():
+    import pyctcdecode.decoder
     from transformers import Wav2Vec2ProcessorWithLM
     from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
@@ -62,6 +66,46 @@ if is_librosa_available():
     import librosa


+def _test_wav2vec2_with_lm_invalid_pool(in_queue, out_queue, timeout):
+    error = None
+    try:
+        _ = in_queue.get(timeout=timeout)
+
+        ds = load_dataset("common_voice", "es", split="test", streaming=True)
+        sample = next(iter(ds))
+        resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
+
+        model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
+        processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
+
+        input_values = processor(resampled_audio, return_tensors="np").input_values
+        logits = model(input_values).logits
+
+        # use a spawn pool, which should trigger a warning if different than fork
+        with CaptureLogger(pyctcdecode.decoder.logger) as cl, multiprocessing.get_context("spawn").Pool(1) as pool:
+            transcription = processor.batch_decode(np.array(logits), pool).text
+
+        unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out)
+        unittest.TestCase().assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
+
+        # force batch_decode to internally create a spawn pool, which should trigger a warning if different than fork
+        multiprocessing.set_start_method("spawn", force=True)
+        with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
+            transcription = processor.batch_decode(np.array(logits)).text
+
+        unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out)
+        unittest.TestCase().assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
+    except Exception:
+        error = f"{traceback.format_exc()}"
+
+    results = {"error": error}
+    out_queue.put(results, timeout=timeout)
+    out_queue.join()
+
+
 class FlaxWav2Vec2ModelTester:
     def __init__(
         self,
@@ -575,7 +619,7 @@ class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
         # test user-managed pool
         with multiprocessing.get_context("fork").Pool(2) as pool:
-            transcription = processor.batch_decode(logits.numpy(), pool).text
+            transcription = processor.batch_decode(np.array(logits), pool).text

         self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
@@ -583,7 +627,7 @@ class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
         with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
             2
         ) as pool:
-            transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text
+            transcription = processor.batch_decode(np.array(logits), pool, num_processes=2).text

         self.assertIn("num_process", cl.out)
         self.assertIn("it will be ignored", cl.out)
@@ -593,22 +637,7 @@ class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
     @require_pyctcdecode
     @require_librosa
     def test_wav2vec2_with_lm_invalid_pool(self):
-        ds = load_dataset("common_voice", "es", split="test", streaming=True)
-        sample = next(iter(ds))
-
-        resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
-
-        model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
-        processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
-
-        input_values = processor(resampled_audio, return_tensors="np").input_values
-        logits = model(input_values).logits
-
-        # change default start method, which should trigger a warning if different than fork
-        multiprocessing.set_start_method("spawn")
-        with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
-            transcription = processor.batch_decode(logits.numpy()).text
-
-        self.assertIn("Falling back to sequential decoding.", cl.out)
-        self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
+        timeout = os.environ.get("PYTEST_TIMEOUT", 600)
+        run_test_in_subprocess(
+            test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None, timeout=timeout
+        )
@@ -19,6 +19,8 @@ import glob
 import inspect
 import math
 import multiprocessing
+import os
+import traceback
 import unittest

 import numpy as np
@@ -27,7 +29,15 @@ from datasets import load_dataset
 from huggingface_hub import snapshot_download

 from transformers import Wav2Vec2Config, is_tf_available
-from transformers.testing_utils import CaptureLogger, is_flaky, require_librosa, require_pyctcdecode, require_tf, slow
+from transformers.testing_utils import (
+    CaptureLogger,
+    is_flaky,
+    require_librosa,
+    require_pyctcdecode,
+    require_tf,
+    run_test_in_subprocess,
+    slow,
+)
 from transformers.utils import is_librosa_available, is_pyctcdecode_available

 from ...test_configuration_common import ConfigTester
@@ -42,6 +52,7 @@ if is_tf_available():
 if is_pyctcdecode_available():
+    import pyctcdecode.decoder
     from transformers import Wav2Vec2ProcessorWithLM
     from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
@@ -50,6 +61,45 @@ if is_librosa_available():
     import librosa


+def _test_wav2vec2_with_lm_invalid_pool(in_queue, out_queue, timeout):
+    error = None
+    try:
+        _ = in_queue.get(timeout=timeout)
+
+        downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
+        file_path = glob.glob(downloaded_folder + "/*")[0]
+        sample = librosa.load(file_path, sr=16_000)[0]
+
+        model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
+        processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
+
+        input_values = processor(sample, return_tensors="tf").input_values
+        logits = model(input_values).logits
+
+        # use a spawn pool, which should trigger a warning if different than fork
+        with CaptureLogger(pyctcdecode.decoder.logger) as cl, multiprocessing.get_context("spawn").Pool(1) as pool:
+            transcription = processor.batch_decode(logits.numpy(), pool).text
+
+        unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out)
+        unittest.TestCase().assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
+
+        # force batch_decode to internally create a spawn pool, which should trigger a warning if different than fork
+        multiprocessing.set_start_method("spawn", force=True)
+        with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
+            transcription = processor.batch_decode(logits.numpy()).text
+
+        unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out)
+        unittest.TestCase().assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
+    except Exception:
+        error = f"{traceback.format_exc()}"
+
+    results = {"error": error}
+    out_queue.put(results, timeout=timeout)
+    out_queue.join()
+
+
 @require_tf
 class TFWav2Vec2ModelTester:
     def __init__(
@@ -627,21 +677,7 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
     @require_pyctcdecode
     @require_librosa
     def test_wav2vec2_with_lm_invalid_pool(self):
-        downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
-        file_path = glob.glob(downloaded_folder + "/*")[0]
-        sample = librosa.load(file_path, sr=16_000)[0]
-
-        model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
-        processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
-
-        input_values = processor(sample, return_tensors="tf").input_values
-        logits = model(input_values).logits
-
-        # change default start method, which should trigger a warning if different than fork
-        multiprocessing.set_start_method("spawn")
-        with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
-            transcription = processor.batch_decode(logits.numpy()).text
-
-        self.assertIn("Falling back to sequential decoding.", cl.out)
-        self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
+        timeout = os.environ.get("PYTEST_TIMEOUT", 600)
+        run_test_in_subprocess(
+            test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None, timeout=timeout
+        )
@@ -19,6 +19,7 @@ import multiprocessing
 import os
 import pickle
 import tempfile
+import traceback
 import unittest

 import numpy as np
@@ -34,6 +35,7 @@ from transformers.testing_utils import (
     require_soundfile,
     require_torch,
     require_torchaudio,
+    run_test_in_subprocess,
     slow,
     torch_device,
 )
@@ -75,6 +77,7 @@ if is_torchaudio_available():
 if is_pyctcdecode_available():
+    import pyctcdecode.decoder
     from transformers import Wav2Vec2ProcessorWithLM
     from transformers.models.wav2vec2_with_lm import processing_wav2vec2_with_lm
@@ -83,6 +86,51 @@ if is_torch_fx_available():
     from transformers.utils.fx import symbolic_trace


+def _test_wav2vec2_with_lm_invalid_pool(in_queue, out_queue, timeout):
+    error = None
+    try:
+        _ = in_queue.get(timeout=timeout)
+
+        ds = load_dataset("common_voice", "es", split="test", streaming=True)
+        sample = next(iter(ds))
+        resampled_audio = torchaudio.functional.resample(
+            torch.tensor(sample["audio"]["array"]), 48_000, 16_000
+        ).numpy()
+
+        model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to(
+            torch_device
+        )
+        processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
+
+        input_values = processor(resampled_audio, return_tensors="pt").input_values
+
+        with torch.no_grad():
+            logits = model(input_values.to(torch_device)).logits
+
+        # use a spawn pool, which should trigger a warning if different than fork
+        with CaptureLogger(pyctcdecode.decoder.logger) as cl, multiprocessing.get_context("spawn").Pool(1) as pool:
+            transcription = processor.batch_decode(logits.cpu().numpy(), pool).text
+
+        unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out)
+        unittest.TestCase().assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
+
+        # force batch_decode to internally create a spawn pool, which should trigger a warning if different than fork
+        multiprocessing.set_start_method("spawn", force=True)
+        with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
+            transcription = processor.batch_decode(logits.cpu().numpy()).text
+
+        unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out)
+        unittest.TestCase().assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
+    except Exception:
+        error = f"{traceback.format_exc()}"
+
+    results = {"error": error}
+    out_queue.put(results, timeout=timeout)
+    out_queue.join()
+
+
 class Wav2Vec2ModelTester:
     def __init__(
         self,
@@ -1636,7 +1684,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         # test user-managed pool
         with multiprocessing.get_context("fork").Pool(2) as pool:
-            transcription = processor.batch_decode(logits.numpy(), pool).text
+            transcription = processor.batch_decode(logits.cpu().numpy(), pool).text

         self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
@@ -1644,7 +1692,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
             2
         ) as pool:
-            transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text
+            transcription = processor.batch_decode(logits.cpu().numpy(), pool, num_processes=2).text

         self.assertIn("num_process", cl.out)
         self.assertIn("it will be ignored", cl.out)
@@ -1654,30 +1702,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
     @require_pyctcdecode
     @require_torchaudio
     def test_wav2vec2_with_lm_invalid_pool(self):
-        ds = load_dataset("common_voice", "es", split="test", streaming=True)
-        sample = next(iter(ds))
-
-        resampled_audio = torchaudio.functional.resample(
-            torch.tensor(sample["audio"]["array"]), 48_000, 16_000
-        ).numpy()
-
-        model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm").to(
-            torch_device
-        )
-        processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
-
-        input_values = processor(resampled_audio, return_tensors="pt").input_values
-
-        with torch.no_grad():
-            logits = model(input_values.to(torch_device)).logits
-
-        # change default start method, which should trigger a warning if different than fork
-        multiprocessing.set_start_method("spawn")
-        with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
-            transcription = processor.batch_decode(logits.numpy()).text
-
-        self.assertIn("Falling back to sequential decoding.", cl.out)
-        self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
+        timeout = os.environ.get("PYTEST_TIMEOUT", 600)
+        run_test_in_subprocess(
+            test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None, timeout=timeout
+        )

     def test_inference_diarization(self):
         model = Wav2Vec2ForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd").to(torch_device)