Unverified Commit 3c448374 authored by Caroline Chen's avatar Caroline Chen Committed by GitHub
Browse files

Make kaldi selective in build (#1342)

parent 5521f6c7
......@@ -18,7 +18,10 @@ _ROOT_DIR = _THIS_DIR.parent.parent.resolve()
_TORCHAUDIO_DIR = _ROOT_DIR / 'torchaudio'
def _get_build(var):
def _get_build(var, default=False):
if var not in os.environ:
return default
val = os.environ.get(var, '0')
trues = ['1', 'true', 'TRUE', 'on', 'ON', 'yes', 'YES']
falses = ['0', 'false', 'FALSE', 'off', 'OFF', 'no', 'NO']
......@@ -32,6 +35,7 @@ def _get_build(var):
_BUILD_SOX = _get_build("BUILD_SOX")
_BUILD_KALDI = _get_build("BUILD_KALDI", True)
_BUILD_TRANSDUCER = _get_build("BUILD_TRANSDUCER")
......@@ -68,7 +72,7 @@ class CMakeBuild(build_ext):
'-DCMAKE_VERBOSE_MAKEFILE=ON',
f"-DPython_INCLUDE_DIR={distutils.sysconfig.get_python_inc()}",
f"-DBUILD_SOX:BOOL={'ON' if _BUILD_SOX else 'OFF'}",
"-DBUILD_KALDI:BOOL=ON",
f"-DBUILD_KALDI:BOOL={'ON' if _BUILD_KALDI else 'OFF'}",
f"-DBUILD_TRANSDUCER:BOOL={'ON' if _BUILD_TRANSDUCER else 'OFF'}",
"-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON",
"-DBUILD_LIBTORCHAUDIO:BOOL=OFF",
......
......@@ -15,7 +15,7 @@ from .case_utils import (
skipIfNoCuda,
skipIfNoExec,
skipIfNoModule,
skipIfNoExtension,
skipIfNoKaldi,
skipIfNoSox,
skipIfNoSoxBackend,
)
......@@ -31,5 +31,5 @@ from .parameterized_utils import (
__all__ = ['get_asset_path', 'get_whitenoise', 'get_sinusoid', 'set_audio_backend',
'TempDirMixin', 'HttpServerMixin', 'TestBaseMixin', 'PytorchTestCase', 'TorchaudioTestCase',
'skipIfNoCuda', 'skipIfNoExec', 'skipIfNoModule', 'skipIfNoExtension', 'skipIfNoSox',
'skipIfNoCuda', 'skipIfNoExec', 'skipIfNoModule', 'skipIfNoKaldi', 'skipIfNoSox',
'skipIfNoSoxBackend', 'get_wav_data', 'normalize_wav', 'load_wav', 'save_wav', 'load_params']
......@@ -10,7 +10,8 @@ from torch.testing._internal.common_utils import TestCase as PytorchTestCase
import torchaudio
from torchaudio._internal.module_utils import (
is_module_available,
is_sox_available
is_sox_available,
is_kaldi_available
)
from .backend_utils import set_audio_backend
......@@ -99,11 +100,4 @@ skipIfNoSoxBackend = unittest.skipIf(
'sox' not in torchaudio.list_audio_backends(), 'Sox backend not available')
skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')
skipIfNoSox = unittest.skipIf(not is_sox_available(), reason='Sox not available')
def skipIfNoExtension(test_item):
if is_module_available('torchaudio._torchaudio'):
return test_item
if 'TORCHAUDIO_TEST_FAIL_IF_NO_EXTENSION' in os.environ:
raise RuntimeError('torchaudio C++ extension is not available.')
return unittest.skip('torchaudio C++ extension is not available')(test_item)
skipIfNoKaldi = unittest.skipIf(not is_kaldi_available(), reason='Kaldi not available')
......@@ -206,7 +206,7 @@ class TestFunctional(common_utils.TorchaudioTestCase):
self.assert_batch_consistency(
F.vad, waveforms, sample_rate=sample_rate)
@common_utils.skipIfNoExtension
@common_utils.skipIfNoKaldi
def test_compute_kaldi_pitch(self):
sample_rate = 44100
n_channels = 2
......
......@@ -548,7 +548,7 @@ class Functional(common_utils.TestBaseMixin):
tensor = common_utils.get_whitenoise(sample_rate=44100)
self._assert_consistency(func, tensor)
@common_utils.skipIfNoExtension
@common_utils.skipIfNoKaldi
def test_compute_kaldi_pitch(self):
if self.dtype != torch.float32 or self.device != torch.device('cpu'):
raise unittest.SkipTest("Only float32, cpu is supported.")
......
......@@ -60,6 +60,23 @@ def deprecated(direction: str, version: Optional[str] = None):
return decorator
def is_kaldi_available():
return is_module_available('torchaudio._torchaudio') and torch.ops.torchaudio.is_kaldi_available()
def requires_kaldi():
if is_kaldi_available():
def decorator(func):
return func
else:
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(f'{func.__module__}.{func.__name__} requires kaldi')
return wrapped
return decorator
def is_sox_available():
return is_module_available('torchaudio._torchaudio') and torch.ops.torchaudio.is_sox_available()
......
......@@ -88,6 +88,10 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
target_compile_definitions(_torchaudio PRIVATE INCLUDE_SOX)
endif()
if (BUILD_KALDI)
target_compile_definitions(_torchaudio PRIVATE INCLUDE_KALDI)
endif()
target_include_directories(
_torchaudio
PRIVATE
......
......@@ -12,10 +12,19 @@ bool is_sox_available() {
#endif
}
bool is_kaldi_available() {
#ifdef INCLUDE_KALDI
return true;
#else
return false;
#endif
}
} // namespace
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("torchaudio::is_sox_available", &is_sox_available);
m.def("torchaudio::is_kaldi_available", &is_kaldi_available);
}
} // namespace torchaudio
......@@ -1114,6 +1114,7 @@ def apply_codec(
return augmented
@_mod_utils.requires_kaldi()
def compute_kaldi_pitch(
waveform: torch.Tensor,
sample_rate: float,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment