Unverified Commit 34368421 authored by Yih-Dar, committed by GitHub

Run some TF Whisper tests in subprocesses to avoid GPU OOM (#19772)



* Run some TF Whisper tests in subprocesses to avoid GPU OOM
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent e0b825a8
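
In short: TensorFlow tends to hold on to GPU memory for the lifetime of the pytest process, so several `whisper-large` tests run back to back can OOM a CI GPU. Running each heavy test in a freshly spawned subprocess guarantees the memory is returned to the system when the child exits. The diff below adds a `run_test_in_subprocess` helper to the testing utilities and rewrites the large TF Whisper integration tests on top of it. A minimal sketch of the underlying `multiprocessing` pattern (names here are illustrative, not part of the diff):

```python
import multiprocessing


def _child(out_queue):
    # The heavy (GPU) work would happen here; all of its memory is freed
    # for good when this process exits.
    out_queue.put({"error": None})
    out_queue.join()  # wait until the parent has consumed the result


if __name__ == "__main__":
    # "spawn" starts a fresh interpreter, so no GPU state is inherited.
    ctx = multiprocessing.get_context("spawn")
    out_queue = ctx.JoinableQueue(1)
    process = ctx.Process(target=_child, args=(out_queue,))
    process.start()
    results = out_queue.get(timeout=600)
    out_queue.task_done()
    process.join(timeout=600)
    assert results["error"] is None
```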
@@ -17,6 +17,7 @@ import contextlib
 import functools
 import inspect
 import logging
+import multiprocessing
 import os
 import re
 import shlex
@@ -1672,3 +1673,43 @@ def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None):
         return wrapper

     return decorator
+
+
+def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=600):
+    """
+    Run a test in a subprocess. In particular, this can avoid (GPU) memory issues.
+
+    Args:
+        test_case (`unittest.TestCase`):
+            The test that will run `target_func`.
+        target_func (`Callable`):
+            The function implementing the actual testing logic.
+        inputs (`dict`, *optional*, defaults to `None`):
+            The inputs that will be passed to `target_func` through an (input) queue.
+        timeout (`int`, *optional*, defaults to 600):
+            The timeout (in seconds) that will be passed to the input and output queues.
+    """
+    start_method = "spawn"
+    ctx = multiprocessing.get_context(start_method)
+
+    input_queue = ctx.Queue(1)
+    output_queue = ctx.JoinableQueue(1)
+
+    # We can't send `unittest.TestCase` to the child process: otherwise we get issues regarding pickle.
+    input_queue.put(inputs, timeout=timeout)
+
+    process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout))
+    process.start()
+    # Kill the child process if we can't get outputs from it in time: otherwise, the hanging subprocess prevents
+    # the test from exiting properly.
+    try:
+        results = output_queue.get(timeout=timeout)
+        output_queue.task_done()
+    except Exception as e:
+        process.terminate()
+        test_case.fail(e)
+    process.join(timeout=timeout)
+
+    if results["error"] is not None:
+        test_case.fail(f'{results["error"]}')
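
A target function handed to `run_test_in_subprocess` is expected to follow the queue protocol visible in the Whisper tests below: read its inputs from the input queue, run the checks, and always report an `{"error": ...}` dict through the output queue. A minimal usage sketch (the test and target names are hypothetical):

```python
import traceback
import unittest

from transformers.testing_utils import run_test_in_subprocess


def _test_dummy(in_queue, out_queue, timeout):
    error = None
    try:
        inputs = in_queue.get(timeout=timeout)
        # The actual (GPU-heavy) checks would run here.
        unittest.TestCase().assertEqual(inputs["x"], 1)
    except Exception:
        error = f"{traceback.format_exc()}"

    # Always report back, even on success: the parent blocks on this queue.
    out_queue.put({"error": error}, timeout=timeout)
    out_queue.join()


class DummyTest(unittest.TestCase):
    def test_dummy(self):
        run_test_in_subprocess(test_case=self, target_func=_test_dummy, inputs={"x": 1}, timeout=600)
```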
@@ -15,13 +15,15 @@
 """ Testing suite for the TensorFlow Whisper model. """

 import inspect
+import os
 import tempfile
+import traceback
 import unittest

 import numpy as np

 from transformers import WhisperConfig, WhisperFeatureExtractor, WhisperProcessor
-from transformers.testing_utils import is_tf_available, require_tf, require_tokenizers, slow
+from transformers.testing_utils import is_tf_available, require_tf, require_tokenizers, run_test_in_subprocess, slow
 from transformers.utils import cached_property
 from transformers.utils.import_utils import is_datasets_available
@@ -626,6 +628,184 @@ class TFWhisperModelTest(TFModelTesterMixin, unittest.TestCase):
         self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
+
+
+def _load_datasamples(num_samples):
+    ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+    # automatic decoding with librispeech
+    speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
+
+    return [x["array"] for x in speech_samples]
+
+
+def _test_large_logits_librispeech(in_queue, out_queue, timeout):
+    error = None
+    try:
+        _ = in_queue.get(timeout=timeout)
+
+        set_seed(0)
+
+        model = TFWhisperModel.from_pretrained("openai/whisper-large")
+
+        input_speech = _load_datasamples(1)
+
+        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
+        processed_inputs = processor(audio=input_speech, text="This part of the speech", return_tensors="tf")
+        input_features = processed_inputs.input_features
+        labels = processed_inputs.labels
+
+        logits = model(
+            input_features,
+            decoder_input_ids=labels,
+            output_hidden_states=False,
+            output_attentions=False,
+            use_cache=False,
+        )
+
+        logits = logits.last_hidden_state @ tf.transpose(model.model.decoder.embed_tokens.weights[0])
+
+        # fmt: off
+        EXPECTED_LOGITS = tf.convert_to_tensor(
+            [
+                2.1382, 0.9381, 4.4671, 3.5589, 2.4022, 3.8576, -0.6521, 2.5472,
+                1.8301, 1.9957, 2.3432, 1.4678, 0.5459, 2.2597, 1.5179, 2.5357,
+                1.1624, 0.6194, 1.0757, 1.8259, 2.4076, 1.6601, 2.3503, 1.3376,
+                1.9891, 1.8635, 3.8931, 5.3699, 4.4772, 3.9184
+            ]
+        )
+        # fmt: on
+
+        unittest.TestCase().assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4))
+    except Exception:
+        error = f"{traceback.format_exc()}"
+
+    results = {"error": error}
+    out_queue.put(results, timeout=timeout)
+    out_queue.join()
+
+
+def _test_large_generation(in_queue, out_queue, timeout):
+    error = None
+    try:
+        _ = in_queue.get(timeout=timeout)
+
+        set_seed(0)
+        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
+        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
+
+        input_speech = _load_datasamples(1)
+        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
+
+        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
+        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
+        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad"
+        unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)
+    except Exception:
+        error = f"{traceback.format_exc()}"
+
+    results = {"error": error}
+    out_queue.put(results, timeout=timeout)
+    out_queue.join()
+
+
+def _test_large_generation_multilingual(in_queue, out_queue, timeout):
+    error = None
+    try:
+        _ = in_queue.get(timeout=timeout)
+
+        set_seed(0)
+        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
+        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
+
+        ds = load_dataset("common_voice", "ja", split="test", streaming=True)
+        ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
+        input_speech = next(iter(ds))["audio"]["array"]
+        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
+
+        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
+        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
+        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
+        unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)
+
+        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
+        generated_ids = model.generate(
+            input_features,
+            do_sample=False,
+            max_length=20,
+        )
+        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        EXPECTED_TRANSCRIPT = " Kimura san ni denwa wo kaite moraimashita"
+        unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)
+
+        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
+        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
+        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
+        unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)
+    except Exception:
+        error = f"{traceback.format_exc()}"
+
+    results = {"error": error}
+    out_queue.put(results, timeout=timeout)
+    out_queue.join()
+
+
+def _test_large_batched_generation(in_queue, out_queue, timeout):
+    error = None
+    try:
+        _ = in_queue.get(timeout=timeout)
+
+        set_seed(0)
+        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
+        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
+
+        input_speech = _load_datasamples(4)
+        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
+
+        generated_ids_1 = model.generate(input_features[0:2], max_length=20)
+        generated_ids_2 = model.generate(input_features[2:4], max_length=20)
+        generated_ids = np.concatenate([generated_ids_1, generated_ids_2])
+
+        # fmt: off
+        EXPECTED_LOGITS = tf.convert_to_tensor(
+            [
+                [50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
+                [50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
+                [50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
+                [50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
+            ]
+        )
+        # fmt: on
+
+        unittest.TestCase().assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS))
+
+        # fmt: off
+        EXPECTED_TRANSCRIPT = [
+            ' Mr. Quilter is the apostle of the middle classes and we are glad to',
+            " Nor is Mr. Quilter's manner less interesting than his matter.",
+            " He tells us that at this festive season of the year, with Christmas and roast beef",
+            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"
+        ]
+        # fmt: on
+
+        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
+        unittest.TestCase().assertListEqual(transcript, EXPECTED_TRANSCRIPT)
+    except Exception:
+        error = f"{traceback.format_exc()}"
+
+    results = {"error": error}
+    out_queue.put(results, timeout=timeout)
+    out_queue.join()
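
Note the recurring pattern in the four functions above: assertions run on a throwaway `unittest.TestCase()` instance rather than on the real test case, because the test case itself cannot be pickled across the `spawn` boundary (see the comment in `run_test_in_subprocess`). The throwaway instance still raises `AssertionError` with a readable message, which the `except` block turns into a traceback string for the parent. A tiny sketch of that mechanism:

```python
import traceback
import unittest

try:
    # A bare TestCase gives access to the rich assert helpers without
    # pickling the real test case into the child process.
    unittest.TestCase().assertEqual("a", "b")
except Exception:
    error = f"{traceback.format_exc()}"  # what would be sent back via out_queue
    print(error)
```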
 @require_tf
 @require_tokenizers
 class TFWhisperModelIntegrationTests(unittest.TestCase):
@@ -634,12 +814,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
         return WhisperProcessor.from_pretrained("openai/whisper-base")

     def _load_datasamples(self, num_samples):
-        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
-        # automatic decoding with librispeech
-        speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
-
-        return [x["array"] for x in speech_samples]
+        return _load_datasamples(num_samples)
     @slow
     def test_tiny_logits_librispeech(self):
@@ -719,40 +894,11 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
     @slow
     def test_large_logits_librispeech(self):
-        set_seed(0)
-
-        model = TFWhisperModel.from_pretrained("openai/whisper-large")
-
-        input_speech = self._load_datasamples(1)
-
-        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
-        processed_inputs = processor(audio=input_speech, text="This part of the speech", return_tensors="tf")
-        input_features = processed_inputs.input_features
-        labels = processed_inputs.labels
-
-        logits = model(
-            input_features,
-            decoder_input_ids=labels,
-            output_hidden_states=False,
-            output_attentions=False,
-            use_cache=False,
-        )
-
-        logits = logits.last_hidden_state @ tf.transpose(model.model.decoder.embed_tokens.weights[0])
-
-        # fmt: off
-        EXPECTED_LOGITS = tf.convert_to_tensor(
-            [
-                2.1382, 0.9381, 4.4671, 3.5589, 2.4022, 3.8576, -0.6521, 2.5472,
-                1.8301, 1.9957, 2.3432, 1.4678, 0.5459, 2.2597, 1.5179, 2.5357,
-                1.1624, 0.6194, 1.0757, 1.8259, 2.4076, 1.6601, 2.3503, 1.3376,
-                1.9891, 1.8635, 3.8931, 5.3699, 4.4772, 3.9184
-            ]
-        )
-        # fmt: on
-
-        self.assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4))
+        timeout = int(os.environ.get("PYTEST_TIMEOUT", 600))
+        run_test_in_subprocess(
+            test_case=self, target_func=_test_large_logits_librispeech, inputs=None, timeout=timeout
+        )
     @slow
     def test_tiny_en_generation(self):
         set_seed(0)
@@ -816,90 +962,22 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
     @slow
     def test_large_generation(self):
-        set_seed(0)
-        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
-        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
-
-        input_speech = self._load_datasamples(1)
-        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
-
-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
-        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
-        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-        EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad"
-        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
+        timeout = int(os.environ.get("PYTEST_TIMEOUT", 600))
+        run_test_in_subprocess(test_case=self, target_func=_test_large_generation, inputs=None, timeout=timeout)
     @slow
     def test_large_generation_multilingual(self):
-        set_seed(0)
-        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
-        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
-
-        ds = load_dataset("common_voice", "ja", split="test", streaming=True)
-        ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
-        input_speech = next(iter(ds))["audio"]["array"]
-        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
-
-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
-        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
-        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-        EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
-        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
-
-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
-        generated_ids = model.generate(
-            input_features,
-            do_sample=False,
-            max_length=20,
-        )
-        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-        EXPECTED_TRANSCRIPT = " Kimura san ni denwa wo kaite moraimashita"
-        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
-
-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
-        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
-        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-        EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
-        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
+        timeout = int(os.environ.get("PYTEST_TIMEOUT", 600))
+        run_test_in_subprocess(
+            test_case=self, target_func=_test_large_generation_multilingual, inputs=None, timeout=timeout
+        )
     @slow
     def test_large_batched_generation(self):
-        set_seed(0)
-        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
-        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
-
-        input_speech = self._load_datasamples(4)
-        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
-        generated_ids = model.generate(input_features, max_length=20)
-
-        # fmt: off
-        EXPECTED_LOGITS = tf.convert_to_tensor(
-            [
-                [50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
-                [50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
-                [50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
-                [50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
-            ]
-        )
-        # fmt: on
-
-        self.assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS))
-
-        # fmt: off
-        EXPECTED_TRANSCRIPT = [
-            ' Mr. Quilter is the apostle of the middle classes and we are glad to',
-            " Nor is Mr. Quilter's manner less interesting than his matter.",
-            " He tells us that at this festive season of the year, with Christmas and roast beef",
-            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"
-        ]
-        # fmt: on
-
-        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
-        self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)
+        timeout = int(os.environ.get("PYTEST_TIMEOUT", 600))
+        run_test_in_subprocess(
+            test_case=self, target_func=_test_large_batched_generation, inputs=None, timeout=timeout
+        )
     @slow
     def test_tiny_en_batched_generation(self):
...
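
One last detail worth calling out: environment variables are strings, so the timeout read from `PYTEST_TIMEOUT` in the tests above has to be converted to a number before it is handed to the queues, whose `get`/`put` timeouts do arithmetic on the value. The pattern, isolated:

```python
import os

# os.environ.get returns a str when the variable is set (the int default 600
# is used only when it is unset), so convert explicitly before use.
timeout = int(os.environ.get("PYTEST_TIMEOUT", 600))
```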