Unverified Commit e26c6f03 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix `Wav2Vec2` CI OOM (#24190)



fix
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 8f093fb7
......@@ -17,6 +17,7 @@
from __future__ import annotations
import copy
import gc
import glob
import inspect
import math
......@@ -709,6 +710,11 @@ class TFWav2Vec2UtilsTest(unittest.TestCase):
@require_tf
@slow
class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
def _load_datasamples(self, num_samples):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
......
......@@ -14,6 +14,7 @@
# limitations under the License.
""" Testing suite for the PyTorch Wav2Vec2 model. """
import gc
import math
import multiprocessing
import os
......@@ -1374,6 +1375,12 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
@require_soundfile
@slow
class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
def _load_datasamples(self, num_samples):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment