Unverified Commit 3c448374 authored by Caroline Chen's avatar Caroline Chen Committed by GitHub
Browse files

Make kaldi selective in build (#1342)

parent 5521f6c7
...@@ -18,7 +18,10 @@ _ROOT_DIR = _THIS_DIR.parent.parent.resolve() ...@@ -18,7 +18,10 @@ _ROOT_DIR = _THIS_DIR.parent.parent.resolve()
_TORCHAUDIO_DIR = _ROOT_DIR / 'torchaudio' _TORCHAUDIO_DIR = _ROOT_DIR / 'torchaudio'
def _get_build(var): def _get_build(var, default=False):
if var not in os.environ:
return default
val = os.environ.get(var, '0') val = os.environ.get(var, '0')
trues = ['1', 'true', 'TRUE', 'on', 'ON', 'yes', 'YES'] trues = ['1', 'true', 'TRUE', 'on', 'ON', 'yes', 'YES']
falses = ['0', 'false', 'FALSE', 'off', 'OFF', 'no', 'NO'] falses = ['0', 'false', 'FALSE', 'off', 'OFF', 'no', 'NO']
...@@ -32,6 +35,7 @@ def _get_build(var): ...@@ -32,6 +35,7 @@ def _get_build(var):
_BUILD_SOX = _get_build("BUILD_SOX") _BUILD_SOX = _get_build("BUILD_SOX")
_BUILD_KALDI = _get_build("BUILD_KALDI", True)
_BUILD_TRANSDUCER = _get_build("BUILD_TRANSDUCER") _BUILD_TRANSDUCER = _get_build("BUILD_TRANSDUCER")
...@@ -68,7 +72,7 @@ class CMakeBuild(build_ext): ...@@ -68,7 +72,7 @@ class CMakeBuild(build_ext):
'-DCMAKE_VERBOSE_MAKEFILE=ON', '-DCMAKE_VERBOSE_MAKEFILE=ON',
f"-DPython_INCLUDE_DIR={distutils.sysconfig.get_python_inc()}", f"-DPython_INCLUDE_DIR={distutils.sysconfig.get_python_inc()}",
f"-DBUILD_SOX:BOOL={'ON' if _BUILD_SOX else 'OFF'}", f"-DBUILD_SOX:BOOL={'ON' if _BUILD_SOX else 'OFF'}",
"-DBUILD_KALDI:BOOL=ON", f"-DBUILD_KALDI:BOOL={'ON' if _BUILD_KALDI else 'OFF'}",
f"-DBUILD_TRANSDUCER:BOOL={'ON' if _BUILD_TRANSDUCER else 'OFF'}", f"-DBUILD_TRANSDUCER:BOOL={'ON' if _BUILD_TRANSDUCER else 'OFF'}",
"-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON", "-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON",
"-DBUILD_LIBTORCHAUDIO:BOOL=OFF", "-DBUILD_LIBTORCHAUDIO:BOOL=OFF",
......
...@@ -15,7 +15,7 @@ from .case_utils import ( ...@@ -15,7 +15,7 @@ from .case_utils import (
skipIfNoCuda, skipIfNoCuda,
skipIfNoExec, skipIfNoExec,
skipIfNoModule, skipIfNoModule,
skipIfNoExtension, skipIfNoKaldi,
skipIfNoSox, skipIfNoSox,
skipIfNoSoxBackend, skipIfNoSoxBackend,
) )
...@@ -31,5 +31,5 @@ from .parameterized_utils import ( ...@@ -31,5 +31,5 @@ from .parameterized_utils import (
__all__ = ['get_asset_path', 'get_whitenoise', 'get_sinusoid', 'set_audio_backend', __all__ = ['get_asset_path', 'get_whitenoise', 'get_sinusoid', 'set_audio_backend',
'TempDirMixin', 'HttpServerMixin', 'TestBaseMixin', 'PytorchTestCase', 'TorchaudioTestCase', 'TempDirMixin', 'HttpServerMixin', 'TestBaseMixin', 'PytorchTestCase', 'TorchaudioTestCase',
'skipIfNoCuda', 'skipIfNoExec', 'skipIfNoModule', 'skipIfNoExtension', 'skipIfNoSox', 'skipIfNoCuda', 'skipIfNoExec', 'skipIfNoModule', 'skipIfNoKaldi', 'skipIfNoSox',
'skipIfNoSoxBackend', 'get_wav_data', 'normalize_wav', 'load_wav', 'save_wav', 'load_params'] 'skipIfNoSoxBackend', 'get_wav_data', 'normalize_wav', 'load_wav', 'save_wav', 'load_params']
...@@ -10,7 +10,8 @@ from torch.testing._internal.common_utils import TestCase as PytorchTestCase ...@@ -10,7 +10,8 @@ from torch.testing._internal.common_utils import TestCase as PytorchTestCase
import torchaudio import torchaudio
from torchaudio._internal.module_utils import ( from torchaudio._internal.module_utils import (
is_module_available, is_module_available,
is_sox_available is_sox_available,
is_kaldi_available
) )
from .backend_utils import set_audio_backend from .backend_utils import set_audio_backend
...@@ -99,11 +100,4 @@ skipIfNoSoxBackend = unittest.skipIf( ...@@ -99,11 +100,4 @@ skipIfNoSoxBackend = unittest.skipIf(
'sox' not in torchaudio.list_audio_backends(), 'Sox backend not available') 'sox' not in torchaudio.list_audio_backends(), 'Sox backend not available')
skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available') skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')
skipIfNoSox = unittest.skipIf(not is_sox_available(), reason='Sox not available') skipIfNoSox = unittest.skipIf(not is_sox_available(), reason='Sox not available')
skipIfNoKaldi = unittest.skipIf(not is_kaldi_available(), reason='Kaldi not available')
def skipIfNoExtension(test_item):
if is_module_available('torchaudio._torchaudio'):
return test_item
if 'TORCHAUDIO_TEST_FAIL_IF_NO_EXTENSION' in os.environ:
raise RuntimeError('torchaudio C++ extension is not available.')
return unittest.skip('torchaudio C++ extension is not available')(test_item)
...@@ -206,7 +206,7 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -206,7 +206,7 @@ class TestFunctional(common_utils.TorchaudioTestCase):
self.assert_batch_consistency( self.assert_batch_consistency(
F.vad, waveforms, sample_rate=sample_rate) F.vad, waveforms, sample_rate=sample_rate)
@common_utils.skipIfNoExtension @common_utils.skipIfNoKaldi
def test_compute_kaldi_pitch(self): def test_compute_kaldi_pitch(self):
sample_rate = 44100 sample_rate = 44100
n_channels = 2 n_channels = 2
......
...@@ -548,7 +548,7 @@ class Functional(common_utils.TestBaseMixin): ...@@ -548,7 +548,7 @@ class Functional(common_utils.TestBaseMixin):
tensor = common_utils.get_whitenoise(sample_rate=44100) tensor = common_utils.get_whitenoise(sample_rate=44100)
self._assert_consistency(func, tensor) self._assert_consistency(func, tensor)
@common_utils.skipIfNoExtension @common_utils.skipIfNoKaldi
def test_compute_kaldi_pitch(self): def test_compute_kaldi_pitch(self):
if self.dtype != torch.float32 or self.device != torch.device('cpu'): if self.dtype != torch.float32 or self.device != torch.device('cpu'):
raise unittest.SkipTest("Only float32, cpu is supported.") raise unittest.SkipTest("Only float32, cpu is supported.")
......
...@@ -60,6 +60,23 @@ def deprecated(direction: str, version: Optional[str] = None): ...@@ -60,6 +60,23 @@ def deprecated(direction: str, version: Optional[str] = None):
return decorator return decorator
def is_kaldi_available():
return is_module_available('torchaudio._torchaudio') and torch.ops.torchaudio.is_kaldi_available()
def requires_kaldi():
if is_kaldi_available():
def decorator(func):
return func
else:
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(f'{func.__module__}.{func.__name__} requires kaldi')
return wrapped
return decorator
def is_sox_available(): def is_sox_available():
return is_module_available('torchaudio._torchaudio') and torch.ops.torchaudio.is_sox_available() return is_module_available('torchaudio._torchaudio') and torch.ops.torchaudio.is_sox_available()
......
...@@ -88,6 +88,10 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION) ...@@ -88,6 +88,10 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
target_compile_definitions(_torchaudio PRIVATE INCLUDE_SOX) target_compile_definitions(_torchaudio PRIVATE INCLUDE_SOX)
endif() endif()
if (BUILD_KALDI)
target_compile_definitions(_torchaudio PRIVATE INCLUDE_KALDI)
endif()
target_include_directories( target_include_directories(
_torchaudio _torchaudio
PRIVATE PRIVATE
......
...@@ -12,10 +12,19 @@ bool is_sox_available() { ...@@ -12,10 +12,19 @@ bool is_sox_available() {
#endif #endif
} }
bool is_kaldi_available() {
#ifdef INCLUDE_KALDI
return true;
#else
return false;
#endif
}
} // namespace } // namespace
TORCH_LIBRARY_FRAGMENT(torchaudio, m) { TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("torchaudio::is_sox_available", &is_sox_available); m.def("torchaudio::is_sox_available", &is_sox_available);
m.def("torchaudio::is_kaldi_available", &is_kaldi_available);
} }
} // namespace torchaudio } // namespace torchaudio
...@@ -1114,6 +1114,7 @@ def apply_codec( ...@@ -1114,6 +1114,7 @@ def apply_codec(
return augmented return augmented
@_mod_utils.requires_kaldi()
def compute_kaldi_pitch( def compute_kaldi_pitch(
waveform: torch.Tensor, waveform: torch.Tensor,
sample_rate: float, sample_rate: float,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment