"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "4c7e8d09008ea4e46dd09dccfbd518bb2b792e75"
Unverified Commit cbecf121 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix env. variable type issue in testing (#21609)



* fix env issue

* fix env issue

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 5987e0ab
...@@ -1728,7 +1728,7 @@ def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, d ...@@ -1728,7 +1728,7 @@ def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, d
return decorator return decorator
def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=600): def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
""" """
To run a test in a subprocess. In particular, this can avoid (GPU) memory issue. To run a test in a subprocess. In particular, this can avoid (GPU) memory issue.
...@@ -1739,9 +1739,12 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=600): ...@@ -1739,9 +1739,12 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=600):
The function implementing the actual testing logic. The function implementing the actual testing logic.
inputs (`dict`, *optional*, defaults to `None`): inputs (`dict`, *optional*, defaults to `None`):
The inputs that will be passed to `target_func` through an (input) queue. The inputs that will be passed to `target_func` through an (input) queue.
timeout (`int`, *optional*, defaults to 600): timeout (`int`, *optional*, defaults to `None`):
The timeout (in seconds) that will be passed to the input and output queues. The timeout (in seconds) that will be passed to the input and output queues. If not specified, the env.
variable `PYTEST_TIMEOUT` will be checked. If still `None`, its value will be set to `600`.
""" """
if timeout is None:
timeout = int(os.environ.get("PYTEST_TIMEOUT", 600))
start_methohd = "spawn" start_methohd = "spawn"
ctx = multiprocessing.get_context(start_methohd) ctx = multiprocessing.get_context(start_methohd)
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
import inspect import inspect
import math import math
import multiprocessing import multiprocessing
import os
import traceback import traceback
import unittest import unittest
...@@ -637,7 +636,4 @@ class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase): ...@@ -637,7 +636,4 @@ class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
@require_pyctcdecode @require_pyctcdecode
@require_librosa @require_librosa
def test_wav2vec2_with_lm_invalid_pool(self): def test_wav2vec2_with_lm_invalid_pool(self):
timeout = os.environ.get("PYTEST_TIMEOUT", 600) run_test_in_subprocess(test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None)
run_test_in_subprocess(
test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None, timeout=timeout
)
...@@ -19,7 +19,6 @@ import glob ...@@ -19,7 +19,6 @@ import glob
import inspect import inspect
import math import math
import multiprocessing import multiprocessing
import os
import traceback import traceback
import unittest import unittest
...@@ -682,7 +681,4 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase): ...@@ -682,7 +681,4 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
@require_pyctcdecode @require_pyctcdecode
@require_librosa @require_librosa
def test_wav2vec2_with_lm_invalid_pool(self): def test_wav2vec2_with_lm_invalid_pool(self):
timeout = os.environ.get("PYTEST_TIMEOUT", 600) run_test_in_subprocess(test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None)
run_test_in_subprocess(
test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None, timeout=timeout
)
...@@ -1713,10 +1713,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase): ...@@ -1713,10 +1713,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
@require_pyctcdecode @require_pyctcdecode
@require_torchaudio @require_torchaudio
def test_wav2vec2_with_lm_invalid_pool(self): def test_wav2vec2_with_lm_invalid_pool(self):
timeout = os.environ.get("PYTEST_TIMEOUT", 600) run_test_in_subprocess(test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None)
run_test_in_subprocess(
test_case=self, target_func=_test_wav2vec2_with_lm_invalid_pool, inputs=None, timeout=timeout
)
def test_inference_diarization(self): def test_inference_diarization(self):
model = Wav2Vec2ForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd").to(torch_device) model = Wav2Vec2ForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd").to(torch_device)
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
""" Testing suite for the TensorFlow Whisper model. """ """ Testing suite for the TensorFlow Whisper model. """
import inspect import inspect
import os
import tempfile import tempfile
import traceback import traceback
import unittest import unittest
...@@ -891,10 +890,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): ...@@ -891,10 +890,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
@slow @slow
def test_large_logits_librispeech(self): def test_large_logits_librispeech(self):
timeout = os.environ.get("PYTEST_TIMEOUT", 600) run_test_in_subprocess(test_case=self, target_func=_test_large_logits_librispeech, inputs=None)
run_test_in_subprocess(
test_case=self, target_func=_test_large_logits_librispeech, inputs=None, timeout=timeout
)
@slow @slow
def test_tiny_en_generation(self): def test_tiny_en_generation(self):
...@@ -959,22 +955,15 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): ...@@ -959,22 +955,15 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
@slow @slow
def test_large_generation(self): def test_large_generation(self):
timeout = os.environ.get("PYTEST_TIMEOUT", 600) run_test_in_subprocess(test_case=self, target_func=_test_large_generation, inputs=None)
run_test_in_subprocess(test_case=self, target_func=_test_large_generation, inputs=None, timeout=timeout)
@slow @slow
def test_large_generation_multilingual(self): def test_large_generation_multilingual(self):
timeout = os.environ.get("PYTEST_TIMEOUT", 600) run_test_in_subprocess(test_case=self, target_func=_test_large_generation_multilingual, inputs=None)
run_test_in_subprocess(
test_case=self, target_func=_test_large_generation_multilingual, inputs=None, timeout=timeout
)
@slow @slow
def test_large_batched_generation(self): def test_large_batched_generation(self):
timeout = os.environ.get("PYTEST_TIMEOUT", 600) run_test_in_subprocess(test_case=self, target_func=_test_large_batched_generation, inputs=None)
run_test_in_subprocess(
test_case=self, target_func=_test_large_batched_generation, inputs=None, timeout=timeout
)
@slow @slow
def test_tiny_en_batched_generation(self): def test_tiny_en_batched_generation(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment