import os
import platform
import subprocess
import sysconfig
from pathlib import Path

import distutils.sysconfig

from setuptools import Extension
from setuptools.command.build_ext import build_ext
import torch

# Names exported by ``from <this module> import *``.
__all__ = [
    "get_ext_modules",
    "CMakeBuild",
]

# Absolute directory holding this file.
_THIS_DIR = Path(__file__).parent.resolve()
# Two directory levels above this file; CMakeBuild hands this to `cmake` as the
# source directory.
_ROOT_DIR = _THIS_DIR.parent.parent.resolve()
# The `torchaudio` directory directly under _ROOT_DIR.
_TORCHAUDIO_DIR = _ROOT_DIR.joinpath("torchaudio")

20

21
22
23
24
def _get_build(var, default=False):
    """Read a boolean build switch from the environment variable ``var``.

    The value is compared case-insensitively, ignoring surrounding
    whitespace, against the truthy spellings ``1/true/on/yes`` and the falsy
    spellings ``0/false/off/no`` (so ``True``, ``On``, ``YES`` all work).

    Args:
        var: Name of the environment variable to read.
        default: Value returned when ``var`` is not set at all.

    Returns:
        ``default`` if ``var`` is unset; ``True`` for a truthy spelling;
        ``False`` otherwise. An unrecognized value prints a warning and
        yields ``False`` (not ``default``).
    """
    if var not in os.environ:
        return default

    val = os.environ[var]
    # Normalize so that e.g. `BUILD_SOX=True` (Python's own spelling) is not
    # rejected just because of its capitalization.
    normalized = val.strip().lower()
    trues = ['1', 'true', 'on', 'yes']
    falses = ['0', 'false', 'off', 'no']
    if normalized in trues:
        return True
    if normalized not in falses:
        print(
            f'WARNING: Unexpected environment variable value `{var}={val}`. '
            f'Expected one of {trues + falses} (case-insensitive)')
    return False


# Build switches, evaluated once when this module is imported. CMakeBuild
# forwards each one to CMake as a `-DBUILD_*:BOOL=ON/OFF` option.
_BUILD_SOX = _get_build("BUILD_SOX", default=False)
_BUILD_KALDI = _get_build("BUILD_KALDI", default=True)
_BUILD_TRANSDUCER = _get_build("BUILD_TRANSDUCER", default=False)


def get_ext_modules():
    """Return the extension module list for setuptools.

    Returns ``None`` on Windows. On every other platform it returns a
    one-element list describing ``torchaudio._torchaudio`` with an empty
    ``sources`` list; the build itself is presumably performed by
    ``CMakeBuild`` (which ignores ``ext.sources``) — confirm in setup.py.
    """
    on_windows = platform.system() == 'Windows'
    if on_windows:
        return None
    torchaudio_ext = Extension(name='torchaudio._torchaudio', sources=[])
    return [torchaudio_ext]


# Based off of
# https://github.com/pybind/cmake_example/blob/580c5fd29d4651db99d8874714b07c0c49a53f8a/setup.py
class CMakeBuild(build_ext):
    """``build_ext`` command that builds extensions by driving CMake.

    ``build_extension`` configures the CMake project rooted at ``_ROOT_DIR``
    and builds its ``install`` target into the directory where setuptools
    expects the extension. ``ext.sources`` is not used.
    """

    def run(self):
        """Check that ``cmake`` can be executed, then run ``build_ext``.

        Raises:
            RuntimeError: if running ``cmake --version`` raises ``OSError``
                (for example, no ``cmake`` executable on ``PATH``).
        """
        try:
            subprocess.check_output(['cmake', '--version'])
        except OSError as err:
            raise RuntimeError("CMake is not available.") from err
        super().run()

    def build_extension(self, ext):
        """Configure and build ``ext`` with CMake inside ``self.build_temp``.

        Args:
            ext: The extension being built. Only ``ext.name`` is used, to
                locate the output (install) directory.

        Raises:
            subprocess.CalledProcessError: if the configure or build step
                exits with a non-zero status.
        """
        extdir = os.path.abspath(
            os.path.dirname(self.get_ext_fullpath(ext.name)))

        # required for auto-detection of auxiliary "native" libs
        if not extdir.endswith(os.path.sep):
            extdir += os.path.sep

        cfg = "Debug" if self.debug else "Release"

        cmake_args = [
            f"-DCMAKE_BUILD_TYPE={cfg}",
            f"-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path}",
            f"-DCMAKE_INSTALL_PREFIX={extdir}",
            '-DCMAKE_VERBOSE_MAKEFILE=ON',
            # distutils is deprecated (PEP 632) and gone from the stdlib in
            # Python 3.12; sysconfig's "include" path replaces get_python_inc().
            f"-DPython_INCLUDE_DIR={sysconfig.get_path('include')}",
            f"-DBUILD_SOX:BOOL={'ON' if _BUILD_SOX else 'OFF'}",
            f"-DBUILD_KALDI:BOOL={'ON' if _BUILD_KALDI else 'OFF'}",
            f"-DBUILD_TRANSDUCER:BOOL={'ON' if _BUILD_TRANSDUCER else 'OFF'}",
            "-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON",
            "-DBUILD_LIBTORCHAUDIO:BOOL=OFF",
        ]
        build_args = [
            '--target', 'install'
        ]

        # Default to Ninja
        if 'CMAKE_GENERATOR' not in os.environ:
            cmake_args += ["-GNinja"]

        # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level
        # across all generators.
        if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
            # self.parallel is a Python 3 only way to set parallel jobs by hand
            # using -j in the build_ext call, not supported by pip or PyPA-build.
            if getattr(self, "parallel", None):
                # CMake 3.12+ only.
                build_args += [f"-j{self.parallel}"]

        # exist_ok avoids the race of a separate exists() check before creating.
        os.makedirs(self.build_temp, exist_ok=True)

        subprocess.check_call(
            ["cmake", str(_ROOT_DIR)] + cmake_args, cwd=self.build_temp)
        subprocess.check_call(
            ["cmake", "--build", "."] + build_args, cwd=self.build_temp)

    def get_ext_filename(self, fullname):
        """Return ``build_ext``'s file name for ``fullname`` without the ABI tag.

        The second-to-last dot-separated component is dropped, e.g.
        ``_torchaudio.cpython-39-x86_64-linux-gnu.so`` -> ``_torchaudio.so``.
        Names with fewer than two dots have no tag and are returned unchanged.
        """
        ext_filename = super().get_ext_filename(fullname)
        parts = ext_filename.split('.')
        if len(parts) < 3:
            # No ABI tag to strip. Slicing unconditionally would discard the
            # module name itself and leave only the suffix (e.g. "so").
            return ext_filename
        return '.'.join(parts[:-2] + parts[-1:])