setup.py 5.44 KB
Newer Older
Soumith Chintala's avatar
Soumith Chintala committed
1
#!/usr/bin/env python
2
import os
3
import re
4
import sys
moto's avatar
moto committed
5
import shutil
6
import subprocess
moto's avatar
moto committed
7
from pathlib import Path
Soumith Chintala's avatar
Soumith Chintala committed
8
from setuptools import setup, find_packages
9
import distutils.command.clean
10

11
import torch
moto's avatar
moto committed
12
from tools import setup_helpers
13

moto's avatar
moto committed
14
# Absolute path of the directory that contains this setup.py. All other paths
# in this file are built from it, and git commands are run with it as cwd.
ROOT_DIR = Path(__file__).parent.resolve()
15
16


17
18
19
20
21
22
23
def _run_cmd(cmd):
    """Run ``cmd`` in ``ROOT_DIR`` and return its stripped standard output.

    This is used for best-effort git metadata queries. Any expected failure
    yields ``None`` instead of aborting the build: git missing, a non-zero
    exit status (e.g. ``--exact-match`` on an untagged commit), or output
    that is not valid UTF-8.

    Args:
        cmd (list of str): Command and arguments, executed without a shell.

    Returns:
        str or None: Decoded, whitespace-stripped stdout, or ``None`` on failure.
    """
    try:
        # Silence stderr so expected failures (e.g. "no tag exactly matches")
        # do not print "fatal: ..." noise into every build log.
        output = subprocess.check_output(cmd, cwd=ROOT_DIR, stderr=subprocess.DEVNULL)
    except (OSError, subprocess.CalledProcessError):
        return None
    try:
        # git ref names may contain non-ASCII characters; UTF-8 is a superset
        # of ASCII, so plain-ASCII output decodes exactly as before.
        return output.decode('utf-8').strip()
    except UnicodeDecodeError:
        return None


moto's avatar
moto committed
24
25
26
27
28
29
30
def _get_version(sha):
    version = '0.11.0a0'
    if os.getenv('BUILD_VERSION'):
        version = os.getenv('BUILD_VERSION')
    elif sha is not None:
        version += '+' + sha[:7]
    return version
31
32


moto's avatar
moto committed
33
34
35
36
37
38
def _make_version_file(version, sha):
    """Write ``torchaudio/version.py`` with the build version and git SHA.

    Args:
        version (str): Value written as ``__version__``.
        sha (str or None): Git commit SHA written as ``git_version``;
            ``None`` is recorded as ``'Unknown'``.
    """
    sha = 'Unknown' if sha is None else sha
    version_path = ROOT_DIR / 'torchaudio' / 'version.py'
    # ``!r`` always emits a valid Python string literal, even when the value
    # contains a quote or backslash (``version`` may come from $BUILD_VERSION).
    # For ordinary values it produces the same ``'...'`` text as before.
    with open(version_path, 'w', encoding='utf-8') as f:
        f.write(f"__version__ = {version!r}\n")
        f.write(f"git_version = {sha!r}\n")
39

40

moto's avatar
moto committed
41
42
43
44
def _get_pytorch_version():
    if 'PYTORCH_VERSION' in os.environ:
        return f"torch=={os.environ['PYTORCH_VERSION']}"
    return 'torch'
45

moto's avatar
moto committed
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63

class clean(distutils.command.clean.clean):
    """``setup.py clean`` that also removes in-tree build artifacts.

    After the default distutils clean, this deletes compiled extension
    modules under ``torchaudio/`` and the top-level ``build`` directory.
    """

    def run(self):
        # Run default behavior first
        distutils.command.clean.clean.run(self)

        # Remove torchaudio extension modules. They are ``*.so`` on
        # Linux/macOS but ``*.pyd`` on Windows, so match both. Collect the
        # matches before deleting so the tree isn't modified mid-walk.
        extension_paths = [
            path
            for pattern in ('*.so', '*.pyd')
            for path in (ROOT_DIR / 'torchaudio').rglob(pattern)
        ]
        for path in extension_paths:
            print(f'removing \'{path}\'')
            path.unlink()
        # Remove build directory
        build_dirs = [
            ROOT_DIR / 'build',
        ]
        for path in build_dirs:
            if path.exists():
                print(f'removing \'{path}\' (and everything under it)')
                shutil.rmtree(str(path), ignore_errors=True)
peterjc123's avatar
peterjc123 committed
64

65

moto's avatar
moto committed
66
def _get_packages(branch_name, tag):
    """Return the Python packages to install.

    Build, test, C++ source, third-party and tooling directories are always
    excluded. ``torchaudio.prototype`` and all of its subpackages are also
    excluded for release builds. A build counts as a release when the branch
    is ``release/...`` or the tag matches ``v<digits/dots>[-rcN]``.

    Args:
        branch_name (str or None): Current git branch, if known.
        tag (str or None): Git tag exactly matching HEAD, if any.

    Returns:
        list of str: Package names found by ``find_packages``.
    """
    exclude = [
        "build*",
        "test*",
        "torchaudio.csrc*",
        "third_party*",
        "tools*",
    ]
    exclude_prototype = False
    if branch_name is not None and branch_name.startswith('release/'):
        exclude_prototype = True
    # NOTE: re.match anchors only at the start of the string, so any tag
    # beginning with "v" followed by a digit or dot is treated as a release.
    if tag is not None and re.match(r'v[\d.]+(-rc\d+)?', tag):
        exclude_prototype = True
    if exclude_prototype:
        print('Excluding torchaudio.prototype from the package.')
        # find_packages matches exclude patterns against full dotted package
        # names, so subpackages need their own pattern to be excluded too.
        exclude.extend(["torchaudio.prototype", "torchaudio.prototype.*"])
    return find_packages(exclude=exclude)


85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def _init_submodule():
    """Initialize and check out the repository's git submodules.

    Exits the process with status 1, after printing a manual-recovery hint,
    if git is unavailable or either submodule command fails.
    """
    print(' --- Initializing submodules')
    try:
        # Run in the repository root (as _run_cmd does) so this also works
        # when setup.py is invoked from a different working directory.
        subprocess.check_call(['git', 'submodule', 'init'], cwd=ROOT_DIR)
        subprocess.check_call(['git', 'submodule', 'update'], cwd=ROOT_DIR)
    except (OSError, subprocess.CalledProcessError):
        print(' --- Submodule initialization failed')
        print('Please run:\n\tgit submodule update --init --recursive')
        sys.exit(1)
    print(' --- Initialized submodule')


def _parse_sox_sources():
    """Yield ``(archive_path, url)`` for each URL listed in sox's CMakeLists.

    Scans ``third_party/sox/CMakeLists.txt`` for lines of the form
    ``URL https://...``. Each URL is paired with a destination file of the
    same basename under ``third_party/sox/archives``, which is created if
    missing. Because this is a generator, nothing (including the ``mkdir``)
    runs until it is iterated.
    """
    sox_dir = ROOT_DIR / 'third_party' / 'sox'
    cmake_file = sox_dir / 'CMakeLists.txt'
    archive_dir = sox_dir / 'archives'
    archive_dir.mkdir(exist_ok=True)
    with open(cmake_file, 'r', encoding='utf-8') as file_:
        for line in file_:
            # ``\S+`` stops at whitespace, so trailing spaces (or anything
            # after the URL on the same line) never become part of the URL
            # or of the archive file name.
            match = re.match(r'^\s*URL\s+(https://\S+)', line)
            if match:
                url = match.group(1)
                path = archive_dir / os.path.basename(url)
                yield path, url


def _fetch_sox_archives():
    """Download each sox source archive that is not already on disk."""
    for archive_path, archive_url in _parse_sox_sources():
        if archive_path.exists():
            continue
        print(f' --- Fetching {archive_path.name}')
        torch.hub.download_url_to_file(archive_url, archive_path, progress=False)


def _fetch_third_party_libraries():
    """Ensure third-party sources are present before building.

    Initializes git submodules when the kaldi submodule's CMakeLists.txt is
    missing. On every OS except Windows (``os.name == 'nt'``), it also
    downloads the sox archives.
    """
    kaldi_cmake = ROOT_DIR / 'third_party' / 'kaldi' / 'submodule' / 'CMakeLists.txt'
    submodule_missing = not kaldi_cmake.exists()
    if submodule_missing:
        _init_submodule()
    is_windows = os.name == 'nt'
    if not is_windows:
        _fetch_sox_archives()


moto's avatar
moto committed
125
126
127
128
129
130
131
132
133
134
135
136
137
def _main():
    """Gather build metadata, prepare sources, and run ``setuptools.setup``."""
    # Git metadata; each value is None when it cannot be determined.
    sha = _run_cmd(['git', 'rev-parse', 'HEAD'])
    branch = _run_cmd(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
    tag = _run_cmd(['git', 'describe', '--tags', '--exact-match', '@'])
    git_report = (
        ('-- Git branch:', branch),
        ('-- Git SHA:', sha),
        ('-- Git tag:', tag),
    )
    for label, value in git_report:
        print(label, value)

    torch_requirement = _get_pytorch_version()
    print('-- PyTorch dependency:', torch_requirement)
    version = _get_version(sha)
    print('-- Building version', version)

    # Generate torchaudio/version.py, then make sure third-party sources
    # exist before the extension build runs.
    _make_version_file(version, sha)
    _fetch_third_party_libraries()

    classifiers = [
        "Environment :: Plugins",
        "Intended Audience :: Developers",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: BSD License",
        "Operating System :: MacOS :: MacOS X",
        "Operating System :: Microsoft :: Windows",
        "Operating System :: POSIX",
        "Programming Language :: C++",
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: Implementation :: CPython",
        "Topic :: Multimedia :: Sound/Audio",
        "Topic :: Scientific/Engineering :: Artificial Intelligence"
    ]
    # Package discovery (which may print a message) runs before
    # extension-module discovery.
    packages = _get_packages(branch, tag)
    ext_modules = setup_helpers.get_ext_modules()
    commands = {
        'build_ext': setup_helpers.CMakeBuild,
        'clean': clean,
    }

    setup(
        name="torchaudio",
        version=version,
        description="An audio package for PyTorch",
        url="https://github.com/pytorch/audio",
        author="Soumith Chintala, David Pollack, Sean Naren, Peter Goldsborough",
        author_email="soumith@pytorch.org",
        classifiers=classifiers,
        packages=packages,
        ext_modules=ext_modules,
        cmdclass=commands,
        install_requires=[torch_requirement],
        zip_safe=False,
    )


# Run the build only when this file is executed as a script.
if __name__ == "__main__":
    _main()