Commit da212020 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Move env util (#3499)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/3499

Differential Revision: D47803654

Pulled By: mthrok

fbshipit-source-id: 2b916fa66d84c91c01b4dfe6dd5ee3501159f451
parent f082e6c1
...@@ -11,7 +11,7 @@ from itertools import zip_longest ...@@ -11,7 +11,7 @@ from itertools import zip_longest
import torch import torch
import torchaudio import torchaudio
from torch.testing._internal.common_utils import TestCase as PytorchTestCase from torch.testing._internal.common_utils import TestCase as PytorchTestCase
from torchaudio._internal.module_utils import is_module_available from torchaudio._internal.module_utils import eval_env, is_module_available
from torchaudio.utils.ffmpeg_utils import get_video_decoders, get_video_encoders from torchaudio.utils.ffmpeg_utils import get_video_decoders, get_video_encoders
from .backend_utils import set_audio_backend from .backend_utils import set_audio_backend
...@@ -143,24 +143,6 @@ def is_cuda_ctc_decoder_available(): ...@@ -143,24 +143,6 @@ def is_cuda_ctc_decoder_available():
return _IS_CUDA_CTC_DECODER_AVAILABLE return _IS_CUDA_CTC_DECODER_AVAILABLE
def _eval_env(var, default):
if var not in os.environ:
return default
val = os.environ.get(var, "0")
trues = ["1", "true", "TRUE", "on", "ON", "yes", "YES"]
falses = ["0", "false", "FALSE", "off", "OFF", "no", "NO"]
if val in trues:
return True
if val not in falses:
# fmt: off
raise RuntimeError(
f"Unexpected environment variable value `{var}={val}`. "
f"Expected one of {trues + falses}")
# fmt: on
return False
def _fail(reason): def _fail(reason):
def deco(test_item): def deco(test_item):
if isinstance(test_item, type): if isinstance(test_item, type):
...@@ -185,7 +167,7 @@ def _pass(test_item): ...@@ -185,7 +167,7 @@ def _pass(test_item):
return test_item return test_item
_IN_CI = _eval_env("CI", default=False) _IN_CI = eval_env("CI", default=False)
def _skipIf(condition, reason, key): def _skipIf(condition, reason, key):
...@@ -195,7 +177,7 @@ def _skipIf(condition, reason, key): ...@@ -195,7 +177,7 @@ def _skipIf(condition, reason, key):
# In CI, default to fail, so as to prevent accidental skip. # In CI, default to fail, so as to prevent accidental skip.
# In other env, default to skip # In other env, default to skip
var = f"TORCHAUDIO_TEST_ALLOW_SKIP_IF_{key}" var = f"TORCHAUDIO_TEST_ALLOW_SKIP_IF_{key}"
skip_allowed = _eval_env(var, default=not _IN_CI) skip_allowed = eval_env(var, default=not _IN_CI)
if skip_allowed: if skip_allowed:
return unittest.skip(reason) return unittest.skip(reason)
return _fail(f"{reason} But the test cannot be skipped. (CI={_IN_CI}, {var}={skip_allowed}.)") return _fail(f"{reason} But the test cannot be skipped. (CI={_IN_CI}, {var}={skip_allowed}.)")
...@@ -268,7 +250,7 @@ skipIfNoCuCtcDecoder = _skipIf( ...@@ -268,7 +250,7 @@ skipIfNoCuCtcDecoder = _skipIf(
key="NO_CUCTC_DECODER", key="NO_CUCTC_DECODER",
) )
skipIfRocm = _skipIf( skipIfRocm = _skipIf(
_eval_env("TORCHAUDIO_TEST_WITH_ROCM", default=False), eval_env("TORCHAUDIO_TEST_WITH_ROCM", default=False),
reason="The test doesn't currently work on the ROCm stack.", reason="The test doesn't currently work on the ROCm stack.",
key="ON_ROCM", key="ON_ROCM",
) )
......
import importlib.util import importlib.util
import os
import warnings import warnings
from functools import wraps from functools import wraps
from typing import Optional from typing import Optional
def eval_env(var, default):
"""Check if environment varable has True-y value"""
if var not in os.environ:
return default
val = os.environ.get(var, "0")
trues = ["1", "true", "TRUE", "on", "ON", "yes", "YES"]
falses = ["0", "false", "FALSE", "off", "OFF", "no", "NO"]
if val in trues:
return True
if val not in falses:
# fmt: off
raise RuntimeError(
f"Unexpected environment variable value `{var}={val}`. "
f"Expected one of {trues + falses}")
# fmt: on
return False
def is_module_available(*modules: str) -> bool: def is_module_available(*modules: str) -> bool:
r"""Returns if a top-level module with :attr:`name` exists *without** r"""Returns if a top-level module with :attr:`name` exists *without**
importing it. This is generally safer than try-catch block around a importing it. This is generally safer than try-catch block around a
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment