setup.py 5.68 KB
Newer Older
Soumith Chintala's avatar
Soumith Chintala committed
1
#!/usr/bin/env python
2
import distutils.command.clean
3
import os
4
import re
moto's avatar
moto committed
5
import shutil
6
import subprocess
moto's avatar
moto committed
7
from pathlib import Path
8

9
import torch
flyingdown's avatar
flyingdown committed
10
from torch.utils.cpp_extension import ROCM_HOME
11
from setuptools import find_packages, setup
moto's avatar
moto committed
12
from tools import setup_helpers
13

moto's avatar
moto committed
14
# Absolute path of the directory containing this setup.py. Git commands run
# with this as their cwd, and version.txt / torchaudio/version.py are located
# relative to it. Note that the *parent directory* is resolved (not the file
# itself), so a symlinked setup.py maps to the directory holding the link.
ROOT_DIR = Path(__file__).parent.resolve()
15
16


flyingdown's avatar
flyingdown committed
17
def _run_cmd(cmd, shell=False):
    """Run ``cmd`` from the repository root and return its stripped stdout.

    This is a best-effort helper: any failure (command not found, non-zero
    exit status, output that is not pure ASCII, ...) yields ``None`` instead
    of raising. stderr is discarded.
    """
    try:
        raw_output = subprocess.check_output(
            cmd,
            cwd=ROOT_DIR,
            stderr=subprocess.DEVNULL,
            shell=shell,
        )
        text = raw_output.decode("ascii")
    except Exception:
        return None
    return text.strip()


mayp777's avatar
UPDATE  
mayp777 committed
24

moto's avatar
moto committed
25
def _get_version(sha):
    """Compute the package version string.

    The base version is read from ``version.txt`` at the repository root
    (always, even when it ends up overridden). A non-empty ``BUILD_VERSION``
    environment variable replaces it entirely; otherwise, when a git SHA is
    known, the local label ``+das.opt2`` is appended.
    """
    with open(ROOT_DIR / "version.txt", "r") as version_file:
        base_version = version_file.read().strip()

    override = os.getenv("BUILD_VERSION")
    if override:
        return override
    if sha is None:
        return base_version
    return base_version + "+das.opt2"
33
34


moto's avatar
moto committed
35
def _make_version_file(version, sha):
    """Write ``torchaudio/version.py`` and return the DCU distribution version.

    Args:
        version: Version string produced by ``_get_version``. It may carry a
            PEP 440 local label (e.g. ``+das.opt2``).
        sha: Git commit SHA, or ``None`` if it could not be determined.

    Returns:
        ``f"{version}.dtk{dtk}"``, which ``_main`` passes to ``setup()``.

    Raises:
        RuntimeError: If ``ROCM_HOME`` is unknown or the DTK version file
            under it cannot be read.
    """
    sha = "Unknown" if sha is None else sha

    # Ask the host C++ compiler for its default _GLIBCXX_USE_CXX11_ABI value.
    # ``grep -F`` replaces the obsolescent ``fgrep`` alias.
    abi = _run_cmd(
        ["echo '#include <string>' | gcc -x c++ -E -dM - | grep -F _GLIBCXX_USE_CXX11_ABI | awk '{print $3}'"],
        shell=True,
    )

    # Fail with a clear message instead of a TypeError (None path) or an
    # AttributeError (None.split) further down.
    if ROCM_HOME is None:
        raise RuntimeError("ROCM_HOME is not set; cannot determine the DTK version.")
    rocm_version_path = os.path.join(ROCM_HOME, '.info/rocm_version')
    dtk = _run_cmd(["cat", rocm_version_path])
    if not dtk:
        raise RuntimeError(f"Could not read the DTK version from '{rocm_version_path}'.")
    # Keep the first two dot-separated components and append a fixed "2",
    # e.g. "X.Y.Z" -> "XY2".
    dtk = ''.join(dtk.split('.')[:2]) + "2"

    torch_version = torch.__version__
    dcu_version = f"{version}.dtk{dtk}"

    # __version__ is the public version: drop any PEP 440 local label
    # ("+das.opt2" or whatever BUILD_VERSION carried). Slicing a fixed number
    # of characters would truncate versions that have no such label.
    version_write = version.split("+", 1)[0]

    version_path = ROOT_DIR / "torchaudio" / "version.py"
    with open(version_path, "w") as f:
        f.write(f"__version__ = '{version_write}'\n")
        f.write(f"git_version = '{sha}'\n")
        f.write(f"abi = 'abi{abi}'\n")
        f.write(f"dtk = '{dtk}'\n")
        f.write(f"torch_version = '{torch_version}'\n")
        f.write(f"dcu_version = '{dcu_version}'\n")
    return dcu_version

54

moto's avatar
moto committed
55
def _get_pytorch_version():
56
    if "PYTORCH_VERSION" in os.environ:
moto's avatar
moto committed
57
        return f"torch=={os.environ['PYTORCH_VERSION']}"
58
    return "torch"
59

moto's avatar
moto committed
60
61
62
63
64
65
66

class clean(distutils.command.clean.clean):
    """``setup.py clean`` that also removes compiled extensions and ``build/``."""

    def run(self):
        """Run the stock distutils clean, then delete torchaudio build artifacts."""
        # Let distutils perform its usual cleanup first.
        super().run()

        # Compiled ``.so`` extension libraries anywhere under the package tree.
        for extension_lib in (ROOT_DIR / "torchaudio").glob("**/*.so"):
            print(f"removing '{extension_lib}'")
            extension_lib.unlink()

        # Entire build trees; errors during deletion are ignored.
        for build_dir in (ROOT_DIR / "build",):
            if not build_dir.exists():
                continue
            print(f"removing '{build_dir}' (and everything under it)")
            shutil.rmtree(str(build_dir), ignore_errors=True)
peterjc123's avatar
peterjc123 committed
78

79

moto's avatar
moto committed
80
def _get_packages(branch_name, tag):
    """Return the packages to ship, excluding build/test/tooling directories.

    ``torchaudio.prototype`` is additionally excluded (with a notice printed)
    when building from a ``release/*`` branch or from a version-like tag.
    """
    excluded_patterns = [
        "build*",
        "test*",
        "torchaudio.csrc*",
        "third_party*",
        "tools*",
    ]

    on_release_branch = branch_name is not None and branch_name.startswith("release/")
    # NOTE(review): re.match anchors only at the start, so any tag beginning
    # with "v" followed by a digit or "." qualifies (e.g. "v2-custom").
    on_release_tag = tag is not None and re.match(r"v[\d.]+(-rc\d+)?", tag) is not None

    if on_release_branch or on_release_tag:
        print("Excluding torchaudio.prototype from the package.")
        excluded_patterns.append("torchaudio.prototype*")

    return find_packages(exclude=excluded_patterns)


moto's avatar
moto committed
99
def _parse_url(path):
100
    with open(path, "r") as file_:
moto's avatar
moto committed
101
        for line in file_:
102
            match = re.match(r"^\s*URL\s+(https:\/\/.+)$", line)
moto's avatar
moto committed
103
104
105
106
107
108
109
            if match:
                url = match.group(1)
                yield url


def _fetch_archives(src):
    for dest, url in src:
110
        if not dest.exists():
111
            print(f" --- Fetching {os.path.basename(dest)}")
112
113
114
            torch.hub.download_url_to_file(url, dest, progress=False)


moto's avatar
moto committed
115
def _main():
    """Collect build metadata, write the version file, and invoke ``setup()``.

    Prints the detected git branch/SHA/tag, the PyTorch requirement and the
    computed version, writes ``torchaudio/version.py`` via
    ``_make_version_file`` and passes the resulting DCU version to
    ``setuptools.setup``.
    """
    sha = _run_cmd(["git", "rev-parse", "HEAD"])
    branch = _run_cmd(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    tag = _run_cmd(["git", "describe", "--tags", "--exact-match", "@"])
    print("-- Git branch:", branch)
    print("-- Git SHA:", sha)
    print("-- Git tag:", tag)
    pytorch_package_dep = _get_pytorch_version()
    print("-- PyTorch dependency:", pytorch_package_dep)
    version = _get_version(sha)
    print("-- Building version", version)

    dcu_version = _make_version_file(version, sha)

    # Resolve README relative to the repository root (like every other path in
    # this file) and decode it explicitly, so the build neither depends on the
    # current working directory nor fails on non-UTF-8 locales.
    with open(ROOT_DIR / "README.md", encoding="utf-8") as f:
        long_description = f.read()

    setup(
        name="torchaudio",
        version=dcu_version,
        description="An audio package for PyTorch",
        long_description=long_description,
        long_description_content_type="text/markdown",
        url="https://github.com/pytorch/audio",
        author=(
            "Soumith Chintala, David Pollack, Sean Naren, Peter Goldsborough, "
            "Moto Hira, Caroline Chen, Jeff Hwang, Zhaoheng Ni, Xiaohui Zhang"
        ),
        author_email="soumith@pytorch.org",
        maintainer="Moto Hira, Caroline Chen, Jeff Hwang, Zhaoheng Ni, Xiaohui Zhang",
        maintainer_email="moto@meta.com",
        classifiers=[
            "Environment :: Plugins",
            "Intended Audience :: Developers",
            "Intended Audience :: Science/Research",
            "License :: OSI Approved :: BSD License",
            "Operating System :: MacOS :: MacOS X",
            "Operating System :: Microsoft :: Windows",
            "Operating System :: POSIX",
            "Programming Language :: C++",
            "Programming Language :: Python :: 3.8",
            "Programming Language :: Python :: 3.9",
            "Programming Language :: Python :: 3.10",
            "Programming Language :: Python :: 3.11",
            "Programming Language :: Python :: Implementation :: CPython",
            "Topic :: Multimedia :: Sound/Audio",
            "Topic :: Scientific/Engineering :: Artificial Intelligence",
        ],
        packages=_get_packages(branch, tag),
        ext_modules=setup_helpers.get_ext_modules(),
        cmdclass={
            "build_ext": setup_helpers.CMakeBuild,
            "clean": clean,
        },
        install_requires=[pytorch_package_dep],
        zip_safe=False,
    )


174
if __name__ == "__main__":
    # Build/install only when executed as a script (``python setup.py ...``),
    # not when this module is merely imported.
    _main()