extension.py 3.44 KB
Newer Older
1
import os
moto's avatar
moto committed
2
3
4
import platform
import subprocess
from pathlib import Path
moto's avatar
moto committed
5
import distutils.sysconfig
moto's avatar
moto committed
6

moto's avatar
moto committed
7
8
from setuptools import Extension
from setuptools.command.build_ext import build_ext
9
import torch
moto's avatar
moto committed
10
11
12

# Names exported by ``from <this module> import *``.
__all__ = [
    'get_ext_modules',
    'CMakeBuild',
]

_THIS_DIR = Path(__file__).parent.resolve()
_ROOT_DIR = _THIS_DIR.parent.parent.resolve()
moto's avatar
moto committed
18
_TORCHAUDIO_DIR = _ROOT_DIR / 'torchaudio'
moto's avatar
moto committed
19

20

21
22
def _get_build(var):
    val = os.environ.get(var, '0')
23
24
25
26
27
28
    trues = ['1', 'true', 'TRUE', 'on', 'ON', 'yes', 'YES']
    falses = ['0', 'false', 'FALSE', 'off', 'OFF', 'no', 'NO']
    if val in trues:
        return True
    if val not in falses:
        print(
29
            f'WARNING: Unexpected environment variable value `{var}={val}`. '
30
31
32
33
            f'Expected one of {trues + falses}')
    return False


34
35
# Build switches, read from the environment once at import time via
# _get_build (an unset variable means "off"). CMakeBuild.build_extension
# forwards them to CMake as -DBUILD_SOX and -DBUILD_TRANSDUCER.
_BUILD_SOX = _get_build("BUILD_SOX")
_BUILD_TRANSDUCER = _get_build("BUILD_TRANSDUCER")
moto's avatar
moto committed
36
37


moto's avatar
moto committed
38
39
40
41
def get_ext_modules():
    """Return the extension modules to register with setuptools.

    Returns:
        ``None`` when running on Windows; otherwise a one-element list
        holding the ``torchaudio._torchaudio`` extension. The extension is
        declared with an empty source list.
    """
    on_windows = platform.system() == 'Windows'
    if on_windows:
        return None
    native_ext = Extension(name='torchaudio._torchaudio', sources=[])
    return [native_ext]
moto's avatar
moto committed
42
43


moto's avatar
moto committed
44
45
46
47
48
49
50
51
52
# Based off of
# https://github.com/pybind/cmake_example/blob/580c5fd29d4651db99d8874714b07c0c49a53f8a/setup.py
class CMakeBuild(build_ext):
    """``build_ext`` command that produces the extension by invoking CMake.

    Each extension is built by configuring the project at ``_ROOT_DIR`` with
    CMake and running its ``install`` target; only the extension's name is
    used (to work out the install directory).
    """

    def run(self):
        """Check that ``cmake`` can be launched, then run the regular build.

        Raises:
            RuntimeError: If launching ``cmake --version`` raises ``OSError``.
        """
        try:
            subprocess.check_output(['cmake', '--version'])
        except OSError as err:
            # Keep the OSError as the explicit cause so the reason stays visible.
            raise RuntimeError("CMake is not available.") from err
        super().run()

    def build_extension(self, ext):
        """Configure, build and install the project with CMake.

        The install prefix is the directory in which setuptools expects the
        built file for ``ext`` to appear.

        Args:
            ext: The extension being built; only ``ext.name`` is read.

        Raises:
            subprocess.CalledProcessError: If either CMake invocation exits
                with a non-zero status.
        """
        extdir = os.path.abspath(
            os.path.dirname(self.get_ext_fullpath(ext.name)))

        # required for auto-detection of auxiliary "native" libs
        if not extdir.endswith(os.path.sep):
            extdir += os.path.sep

        cfg = "Debug" if self.debug else "Release"

        cmake_args = [
            f"-DCMAKE_BUILD_TYPE={cfg}",
            f"-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path}",
            f"-DCMAKE_INSTALL_PREFIX={extdir}",
            '-DCMAKE_VERBOSE_MAKEFILE=ON',
            f"-DPython_INCLUDE_DIR={distutils.sysconfig.get_python_inc()}",
            f"-DBUILD_SOX:BOOL={'ON' if _BUILD_SOX else 'OFF'}",
            f"-DBUILD_TRANSDUCER:BOOL={'ON' if _BUILD_TRANSDUCER else 'OFF'}",
            "-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON",
            "-DBUILD_LIBTORCHAUDIO:BOOL=OFF",
        ]
        build_args = [
            '--target', 'install'
        ]

        # Default to Ninja
        if 'CMAKE_GENERATOR' not in os.environ:
            cmake_args += ["-GNinja"]

        # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level
        # across all generators.
        if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
            # self.parallel is a Python 3 only way to set parallel jobs by hand
            # using -j in the build_ext call, not supported by pip or PyPA-build.
            parallel = getattr(self, "parallel", None)
            if parallel:
                # CMake 3.12+ only.
                build_args += [f"-j{parallel}"]

        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(self.build_temp, exist_ok=True)

        subprocess.check_call(
            ["cmake", str(_ROOT_DIR)] + cmake_args, cwd=self.build_temp)
        subprocess.check_call(
            ["cmake", "--build", "."] + build_args, cwd=self.build_temp)

    def get_ext_filename(self, fullname):
        """Return the extension filename with the ABI tag component removed.

        The name produced by the base class is split on ``'.'`` and the
        second-to-last component (the ABI tag, e.g. the middle part of
        ``mod.<abi-tag>.so``) is dropped.
        NOTE(review): presumably this matches the file name the CMake install
        step produces; confirm against the CMake project.

        Args:
            fullname: Dotted name of the extension module.

        Returns:
            The filename without the ABI tag. A name with fewer than three
            dot-separated components has no tag to drop and is returned
            unchanged.
        """
        ext_filename = super().get_ext_filename(fullname)
        ext_filename_parts = ext_filename.split('.')
        if len(ext_filename_parts) < 3:
            # Without this guard "pkg/mod.so" would collapse to just "so".
            return ext_filename
        without_abi = ext_filename_parts[:-2] + ext_filename_parts[-1:]
        return '.'.join(without_abi)