import os
import platform
import subprocess
from pathlib import Path

from torch.utils.cpp_extension import (
    CppExtension,
    BuildExtension as TorchBuildExtension
)

# Public API of this helper module.
__all__ = ['get_ext_modules', 'BuildExtension']

# Directory containing this module, and the project root two levels above it.
_THIS_DIR = Path(__file__).parent.resolve()
_ROOT_DIR = _THIS_DIR.parent.parent.resolve()
# C++ sources compiled into the extension.
_CSRC_DIR = _ROOT_DIR.joinpath('torchaudio', 'csrc')
# Third-party tree; its ``install`` prefix supplies headers and static libs.
_TP_BASE_DIR = _ROOT_DIR.joinpath('third_party')
_TP_INSTALL_DIR = _TP_BASE_DIR.joinpath('install')


23
24
def _get_build(var):
    val = os.environ.get(var, '0')
25
26
27
28
29
30
    trues = ['1', 'true', 'TRUE', 'on', 'ON', 'yes', 'YES']
    falses = ['0', 'false', 'FALSE', 'off', 'OFF', 'no', 'NO']
    if val in trues:
        return True
    if val not in falses:
        print(
31
            f'WARNING: Unexpected environment variable value `{var}={val}`. '
32
33
34
35
            f'Expected one of {trues + falses}')
    return False


# Feature switches, read from the environment once at import time.
_BUILD_SOX, _BUILD_TRANSDUCER = (
    _get_build(name) for name in ("BUILD_SOX", "BUILD_TRANSDUCER")
)


def _get_eca(debug):
    """Return the extra compiler arguments for the extension.

    Args:
        debug: When true, use ``-O0 -g``; otherwise ``-O3``.

    Returns:
        A new list of flags, with ``-DBUILD_TRANSDUCER`` appended when the
        transducer switch is enabled.
    """
    opt_flags = ["-O0", "-g"] if debug else ["-O3"]
    feature_flags = ['-DBUILD_TRANSDUCER'] if _BUILD_TRANSDUCER else []
    return opt_flags + feature_flags


def _get_ela(debug):
    ela = []
    if debug:
        if platform.system() == "Windows":
            ela += ["/DEBUG:FULL"]
        else:
            ela += ["-O0", "-g"]
    else:
        ela += ["-O3"]
    return ela


def _get_srcs():
    """Return every ``.cpp`` file under the C++ source tree, as sorted strings.

    ``Path.glob`` yields matches in an unspecified, filesystem-dependent
    order; sorting keeps the compile order (and thus the build) reproducible
    across machines.
    """
    return sorted(str(p) for p in _CSRC_DIR.glob('**/*.cpp'))


def _get_include_dirs():
    """Return the header search paths for the extension, as strings.

    The project root is always included; the third-party install prefix and
    the transducer submodule headers are added only when the corresponding
    build switch is on.
    """
    candidates = [
        (True, _ROOT_DIR),
        (_BUILD_SOX, _TP_INSTALL_DIR / 'include'),
        (_BUILD_TRANSDUCER, _TP_BASE_DIR / 'transducer' / 'submodule' / 'include'),
    ]
    return [str(path) for enabled, path in candidates if enabled]


def _get_extra_objects():
    """Return the static archives to link into the extension, as strings.

    When ``_BUILD_SOX`` is on, the sox and codec archives from the third-party
    install prefix are listed; when ``_BUILD_TRANSDUCER`` is on, the
    ``libwarprnnt.a`` archive from the third-party build tree is appended.
    """
    objs = []
    if _BUILD_SOX:
        # NOTE: The order of the libraries listed below matters.
        # The most important thing is that dependencies come after the library
        # that uses them, e.g. sox comes first, flac/vorbis come before ogg,
        # and vorbisenc/vorbisfile come before vorbis.
        sox_archives = (
            'libsox.a',
            'libmad.a',
            'libFLAC.a',
            'libmp3lame.a',
            'libopusfile.a',
            'libopus.a',
            'libvorbisenc.a',
            'libvorbisfile.a',
            'libvorbis.a',
            'libogg.a',
            'libopencore-amrnb.a',
            'libopencore-amrwb.a',
        )
        lib_dir = _TP_INSTALL_DIR / 'lib'
        objs.extend(str(lib_dir / name) for name in sox_archives)
    if _BUILD_TRANSDUCER:
        objs.append(str(_TP_BASE_DIR / 'build' / 'transducer' / 'libwarprnnt.a'))
    return objs


def _get_libraries():
    """Return the names of libraries to link by name.

    When sox is built from the third-party tree, ``libsox.a`` is linked as an
    extra object by ``_get_extra_objects()``, so nothing is requested here;
    otherwise the ``sox`` library is requested.
    """
    if _BUILD_SOX:
        return []
    return ['sox']


def _build_third_party():
    """Configure and build the third-party CMake project.

    Creates ``third_party/build`` if needed, then runs ``cmake ..`` followed by
    ``cmake --build .`` inside it. Raises ``subprocess.CalledProcessError`` if
    either command exits with a non-zero status.
    """
    build_dir = str(_TP_BASE_DIR / 'build')
    os.makedirs(build_dir, exist_ok=True)
    for command in (['cmake', '..'], ['cmake', '--build', '.']):
        subprocess.run(args=command, cwd=build_dir, check=True)


# Fully qualified name of the compiled extension module.
_EXT_NAME = "torchaudio._torchaudio"


def get_ext_modules(debug=False):
    """Describe the C++ extension modules to build.

    Args:
        debug: Passed to the compile/link flag helpers to select debug flags
            instead of optimised ones.

    Returns:
        A one-element list holding the ``CppExtension`` for ``_EXT_NAME``, or
        ``None`` on Windows, where no extension is built.
    """
    on_windows = platform.system() == 'Windows'
    if on_windows:
        return None
    extension = CppExtension(
        _EXT_NAME,
        _get_srcs(),
        libraries=_get_libraries(),
        include_dirs=_get_include_dirs(),
        extra_compile_args=_get_eca(debug),
        extra_objects=_get_extra_objects(),
        extra_link_args=_get_ela(debug),
    )
    return [extension]


class BuildExtension(TorchBuildExtension):
    """torch ``BuildExtension`` that builds the third-party project first.

    When the torchaudio extension is built with ``BUILD_SOX`` enabled, the
    third-party CMake project is configured and built before the extension
    itself (presumably producing the archives listed by
    ``_get_extra_objects()`` — confirm against the CMake install step).
    """

    def build_extension(self, ext):
        is_torchaudio_ext = ext.name == _EXT_NAME
        if is_torchaudio_ext and _BUILD_SOX:
            _build_third_party()
        super().build_extension(ext)