extension.py 4.19 KB
Newer Older
1
import os
moto's avatar
moto committed
2
3
4
5
import platform
import subprocess
from pathlib import Path

6
import torch
moto's avatar
moto committed
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from torch.utils.cpp_extension import (
    CppExtension,
    BuildExtension as TorchBuildExtension
)

# Public names of this module (what a star-import exposes).
__all__ = [
    'get_ext_modules',
    'BuildExtension',
]

# Directory layout, anchored at the directory containing this file.
_THIS_DIR = Path(__file__).parent.resolve()
# The repository root is two levels above this file's directory.
_ROOT_DIR = _THIS_DIR.parent.parent.resolve()
# C++ sources of the extension.
_CSRC_DIR = _ROOT_DIR.joinpath('torchaudio', 'csrc')
# Third-party sources and the prefix their build is installed into.
_TP_BASE_DIR = _ROOT_DIR.joinpath('third_party')
_TP_INSTALL_DIR = _TP_BASE_DIR.joinpath('install')
moto's avatar
moto committed
22

23

24
25
def _get_build(var):
    """Interpret the environment variable ``var`` as an on/off build switch.

    An unset variable is treated as ``'0'``. Matching is exact and
    case-sensitive against the spellings listed in the body. A value that is
    neither a recognised "on" nor a recognised "off" spelling prints a warning
    to stdout and is treated as off.

    Args:
        var: Name of the environment variable to read.

    Returns:
        ``True`` if the value is a recognised "on" spelling, ``False``
        otherwise.
    """
    enabled_values = ['1', 'true', 'TRUE', 'on', 'ON', 'yes', 'YES']
    disabled_values = ['0', 'false', 'FALSE', 'off', 'OFF', 'no', 'NO']
    raw = os.environ.get(var, '0')

    is_enabled = raw in enabled_values
    if not is_enabled and raw not in disabled_values:
        # Unknown spelling: warn, then fall through to "disabled".
        print(
            f'WARNING: Unexpected environment variable value `{var}={raw}`. '
            f'Expected one of {enabled_values + disabled_values}')
    return is_enabled


37
38
# Build switches, evaluated once at import time from the BUILD_SOX and
# BUILD_TRANSDUCER environment variables (an unset variable means disabled).
_BUILD_SOX = _get_build("BUILD_SOX")
_BUILD_TRANSDUCER = _get_build("BUILD_TRANSDUCER")
moto's avatar
moto committed
39
40
41
42
43
44
45
46


def _get_eca(debug):
    """Return the extra compile arguments for the extension.

    Args:
        debug: If truthy, compile without optimisation and with debug
            symbols (``-O0 -g``); otherwise compile with ``-O3``.

    Returns:
        A new list of compiler flags. ``-DBUILD_TRANSDUCER`` is appended when
        the transducer build switch is enabled.
    """
    flags = ["-O0", "-g"] if debug else ["-O3"]
    if _BUILD_TRANSDUCER:
        flags.append('-DBUILD_TRANSDUCER')
    return flags


def _get_ela(debug):
    """Return the extra link arguments for the extension.

    Args:
        debug: If truthy, request debug information from the linker
            (``/DEBUG:FULL`` on Windows, ``-O0 -g`` elsewhere); otherwise
            link with ``-O3``.

    Returns:
        A new list of linker flags.
    """
    if not debug:
        return ["-O3"]
    if platform.system() == "Windows":
        return ["/DEBUG:FULL"]
    return ["-O0", "-g"]


def _get_srcs():
    """Return the C++ source files of the extension as strings.

    The list starts with ``pybind.cpp``, followed by every ``.cpp`` file found
    recursively under the ``sox`` directory of the C++ source tree (in the
    order the glob yields them), and finally ``transducer.cpp`` when the
    transducer build switch is enabled.
    """
    sources = [_CSRC_DIR / 'pybind.cpp', *_CSRC_DIR.glob('sox/**/*.cpp')]
    if _BUILD_TRANSDUCER:
        sources.append(_CSRC_DIR / 'transducer.cpp')
    return list(map(str, sources))
moto's avatar
moto committed
70
71
72
73
74
75


def _get_include_dirs():
    """Return the include directories for the extension as strings.

    The repository root is always included. The ``include`` directory of the
    third-party install prefix is added when either the sox or the transducer
    build switch is enabled.
    """
    include_dirs = [str(_ROOT_DIR)]
    uses_third_party = _BUILD_SOX or _BUILD_TRANSDUCER
    if uses_third_party:
        include_dirs += [str(_TP_INSTALL_DIR / 'include')]
    return include_dirs


def _get_extra_objects():
    """Return paths of the static libraries to link into the extension.

    Each entry is a file under the ``lib`` directory of the third-party
    install prefix. The list is empty when neither the sox nor the transducer
    build switch is enabled.
    """
    archives = []
    if _BUILD_SOX:
        # NOTE: The order of the libraries listed below matters.
        #
        # The most important thing is that dependencies come after the
        # library that needs them, e.g., sox comes first, flac/vorbis come
        # before ogg, and vorbisenc/vorbisfile come before vorbis.
        archives.extend([
            'libsox.a',
            'libmad.a',
            'libFLAC.a',
            'libmp3lame.a',
            'libopusfile.a',
            'libopus.a',
            'libvorbisenc.a',
            'libvorbisfile.a',
            'libvorbis.a',
            'libogg.a',
            'libopencore-amrnb.a',
            'libopencore-amrwb.a',
        ])
    if _BUILD_TRANSDUCER:
        archives.append('libwarprnnt.a')

    lib_dir = _TP_INSTALL_DIR / 'lib'
    return [str(lib_dir / archive) for archive in archives]
moto's avatar
moto committed
107
108
109


def _get_libraries():
    """Return the libraries to link against by name.

    When sox is built from the bundled sources nothing is linked by name;
    otherwise the extension links against ``sox``.
    """
    if _BUILD_SOX:
        return []
    return ['sox']
moto's avatar
moto committed
111
112


113
114
115
116
117
118
119
120
def _get_cxx11_abi():
    """Return the ``_GLIBCXX_USE_CXX11_ABI`` define matching the torch build.

    Returns:
        The string ``'-D_GLIBCXX_USE_CXX11_ABI=<value>'``, where ``<value>``
        is the integer form of ``torch._C._GLIBCXX_USE_CXX11_ABI``, or ``0``
        when that flag cannot be read.
    """
    try:
        value = int(torch._C._GLIBCXX_USE_CXX11_ABI)
    except (ImportError, AttributeError):
        # ``torch`` is imported at module level, so a torch build that lacks
        # the flag surfaces here as AttributeError, not ImportError. Both are
        # handled so the fallback to 0 is actually reachable.
        value = 0
    return f'-D_GLIBCXX_USE_CXX11_ABI={value}'


moto's avatar
moto committed
121
122
def _build_third_party(base_build_dir):
    """Configure and build the bundled third-party libraries with CMake.

    CMake is run inside ``<base_build_dir>/third_party`` (created if missing):
    first to configure the third-party source tree, then to build it. The
    ``install`` target is requested only when the transducer build switch is
    enabled.

    Args:
        base_build_dir: Directory under which the ``third_party`` build
            directory is created.

    Raises:
        subprocess.CalledProcessError: If either CMake invocation exits with
            a non-zero status.
    """
    def _on_off(enabled):
        # CMake boolean spelling for a Python truth value.
        return "ON" if enabled else "OFF"

    build_dir = os.path.join(base_build_dir, 'third_party')
    os.makedirs(build_dir, exist_ok=True)

    configure_cmd = [
        'cmake',
        f"-DCMAKE_CXX_FLAGS='{_get_cxx11_abi()}'",
        '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
        f'-DCMAKE_INSTALL_PREFIX={_TP_INSTALL_DIR}',
        f'-DBUILD_SOX={_on_off(_BUILD_SOX)}',
        f'-DBUILD_TRANSDUCER={_on_off(_BUILD_TRANSDUCER)}',
        str(_TP_BASE_DIR),
    ]
    subprocess.run(configure_cmd, cwd=build_dir, check=True)

    build_cmd = ['cmake', '--build', '.']
    if _BUILD_TRANSDUCER:
        build_cmd.extend(['--target', 'install'])
    subprocess.run(build_cmd, cwd=build_dir, check=True)


# Dotted name of the extension module; BuildExtension compares against it to
# decide whether the third-party build has to run first.
_EXT_NAME = 'torchaudio._torchaudio'


def get_ext_modules(debug=False):
    """Return the extension modules to build.

    Args:
        debug: Forwarded to the compile- and link-argument helpers to select
            debug or optimised flags.

    Returns:
        ``None`` on Windows; otherwise a one-element list holding the
        ``CppExtension`` for the torchaudio extension module.
    """
    if platform.system() == 'Windows':
        return None

    extension = CppExtension(
        _EXT_NAME,
        _get_srcs(),
        libraries=_get_libraries(),
        include_dirs=_get_include_dirs(),
        extra_compile_args=_get_eca(debug),
        extra_objects=_get_extra_objects(),
        extra_link_args=_get_ela(debug),
    )
    return [extension]


class BuildExtension(TorchBuildExtension):
    """Build command that builds the third-party libraries before the extension."""

    def build_extension(self, ext):
        """Build ``ext``, building the bundled third-party libraries first if needed.

        The third-party build runs for the torchaudio extension whenever the
        sox or the transducer build switch is enabled: both switches make the
        extension link against static libraries from the third-party install
        prefix, so those libraries must exist before the extension is linked.

        Args:
            ext: The extension being built.
        """
        if ext.name == _EXT_NAME and (_BUILD_SOX or _BUILD_TRANSDUCER):
            _build_third_party(self.build_temp)
        super().build_extension(ext)