Unverified Commit e26c6f03 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix `Wav2Vec2` CI OOM (#24190)



fix
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 8f093fb7
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
from __future__ import annotations from __future__ import annotations
import copy import copy
import gc
import glob import glob
import inspect import inspect
import math import math
...@@ -709,6 +710,11 @@ class TFWav2Vec2UtilsTest(unittest.TestCase): ...@@ -709,6 +710,11 @@ class TFWav2Vec2UtilsTest(unittest.TestCase):
@require_tf @require_tf
@slow @slow
class TFWav2Vec2ModelIntegrationTest(unittest.TestCase): class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
def _load_datasamples(self, num_samples): def _load_datasamples(self, num_samples):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech # automatic decoding with librispeech
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
""" Testing suite for the PyTorch Wav2Vec2 model. """ """ Testing suite for the PyTorch Wav2Vec2 model. """
import gc
import math import math
import multiprocessing import multiprocessing
import os import os
...@@ -1374,6 +1375,12 @@ class Wav2Vec2UtilsTest(unittest.TestCase): ...@@ -1374,6 +1375,12 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
@require_soundfile @require_soundfile
@slow @slow
class Wav2Vec2ModelIntegrationTest(unittest.TestCase): class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
def _load_datasamples(self, num_samples): def _load_datasamples(self, num_samples):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech # automatic decoding with librispeech
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment