"""Helpers for building the ``torchaudio._torchaudio`` C++ extension.

When the ``BUILD_SOX`` environment variable is truthy, the extension is linked
against static sox/codec libraries under ``third_party/build``; otherwise it
is linked against a ``sox`` library found by the linker.
"""
import os
import platform
import subprocess
from pathlib import Path

from torch.utils.cpp_extension import (
    CppExtension,
    BuildExtension as TorchBuildExtension
)

__all__ = [
    'get_ext_modules',
    'BuildExtension',
]

_THIS_DIR = Path(__file__).parent.resolve()
_ROOT_DIR = _THIS_DIR.parent.parent.resolve()
_CSRC_DIR = _ROOT_DIR / 'torchaudio' / 'csrc'
_TP_BASE_DIR = _ROOT_DIR / 'third_party'
_TP_INSTALL_DIR = _TP_BASE_DIR / 'build'


def _get_build_sox():
    """Return whether ``BUILD_SOX`` requests building sox from source.

    An unset variable is treated as ``'0'``. A value that is in neither the
    accepted true nor false spellings prints a warning and yields ``False``.
    """
    val = os.environ.get('BUILD_SOX', '0')
    trues = ['1', 'true', 'TRUE', 'on', 'ON', 'yes', 'YES']
    falses = ['0', 'false', 'FALSE', 'off', 'OFF', 'no', 'NO']
    if val in trues:
        return True
    if val not in falses:
        print(
            f'WARNING: Unexpected environment variable value `BUILD_SOX={val}`. '
            f'Expected one of {trues + falses}')
    return False


# Evaluated once at import time; the helpers below read this cached value, so
# changing BUILD_SOX after import has no effect.
_BUILD_SOX = _get_build_sox()


def _get_eca(debug):
    """Return extra compile args: no optimisation plus debug info, or -O3."""
    if debug:
        return ["-O0", "-g"]
    return ["-O3"]


def _get_ela(debug):
    """Return extra link args; Windows debug builds use the MSVC-style flag."""
    if not debug:
        return ["-O3"]
    if platform.system() == "Windows":
        return ["/DEBUG:FULL"]
    return ["-O0", "-g"]


def _get_srcs():
    """Return every ``.cpp`` file under ``torchaudio/csrc``, recursively.

    The list is sorted because ``Path.glob`` yields entries in filesystem
    order, which would otherwise make the source/link order non-reproducible.
    """
    return sorted(str(p) for p in _CSRC_DIR.glob('**/*.cpp'))


def _get_include_dirs():
    """Return include directories: the repo root, plus third-party headers
    when sox is built from source (presumably installed by the third-party
    build script -- confirm)."""
    dirs = [str(_ROOT_DIR)]
    if _BUILD_SOX:
        dirs.append(str(_TP_INSTALL_DIR / 'include'))
    return dirs


def _get_extra_objects():
    """Return the static libraries to link when sox is built from source."""
    if not _BUILD_SOX:
        return []
    # NOTE: The order of the libraries listed below matters: each library must
    # come before the libraries it depends on (e.g. libsox comes first).
    libs = ['libsox.a', 'libmad.a', 'libFLAC.a', 'libmp3lame.a']
    return [str(_TP_INSTALL_DIR / 'lib' / lib) for lib in libs]


def _get_libraries():
    """Return shared libraries to link: none when sox is linked statically."""
    return [] if _BUILD_SOX else ['sox']


def _build_codecs():
    """Run ``build_third_party.sh`` next to this file; raise on failure."""
    subprocess.run(
        args=[str(_THIS_DIR / 'build_third_party.sh')],
        check=True,
    )


def _configure_third_party():
    """Prepare third-party dependencies before compiling the extension."""
    _build_codecs()


_EXT_NAME = 'torchaudio._torchaudio'


def get_ext_modules(debug=False):
    """Return the extension modules to build, or ``None`` on Windows.

    Args:
        debug: If true, compile and link with debug flags instead of -O3.
    """
    if platform.system() == 'Windows':
        return None
    return [
        CppExtension(
            _EXT_NAME,
            _get_srcs(),
            libraries=_get_libraries(),
            include_dirs=_get_include_dirs(),
            extra_compile_args=_get_eca(debug),
            extra_objects=_get_extra_objects(),
            extra_link_args=_get_ela(debug),
        ),
    ]


class BuildExtension(TorchBuildExtension):
    """``torch`` ``BuildExtension`` that builds third-party deps on demand."""

    def build_extension(self, ext):
        # Third-party codecs are only needed for our extension and only when
        # sox is built from source.
        if ext.name == _EXT_NAME and _BUILD_SOX:
            _configure_third_party()
        super().build_extension(ext)