setup.py 5.79 KB
Newer Older
Soumith Chintala's avatar
Soumith Chintala committed
1
#!/usr/bin/env python
2
import distutils.command.clean
3
import os
4
import re
moto's avatar
moto committed
5
import shutil
6
import subprocess
moto's avatar
moto committed
7
from pathlib import Path
8

9
import torch
10
from setuptools import find_packages, setup
moto's avatar
moto committed
11
from tools import setup_helpers
12

moto's avatar
moto committed
13
ROOT_DIR = Path(__file__).parent.resolve()
14
15


16
17
def _run_cmd(cmd):
    try:
18
        return subprocess.check_output(cmd, cwd=ROOT_DIR, stderr=subprocess.DEVNULL).decode("ascii").strip()
19
20
21
22
    except Exception:
        return None


moto's avatar
moto committed
23
def _get_version(sha):
    """Return the build version string.

    ``version.txt`` at the repository root is always read first. A non-empty
    ``BUILD_VERSION`` environment variable then overrides it outright.
    Otherwise the file's contents are returned, suffixed with ``+`` and the
    first 7 characters of ``sha`` when ``sha`` is not ``None``.
    """
    base_version = (ROOT_DIR / "version.txt").read_text().strip()
    env_override = os.getenv("BUILD_VERSION")
    if env_override:
        return env_override
    if sha is None:
        return base_version
    return base_version + "+" + sha[:7]
31
32


zhanggzh's avatar
zhanggzh committed
33
def _make_version_file(version):
    """Write ``src/torchaudio/version.py`` defining ``__version__``.

    NOTE: this module defines ``_make_version_file`` a second time further
    down; that later definition replaces this one when the module is
    executed, so this function is never the one ``_main`` calls.
    """
    target = ROOT_DIR.joinpath("src", "torchaudio", "version.py")
    with target.open("w") as out:
        out.write(f"__version__ = '{version}'\n")
37

38

moto's avatar
moto committed
39
def _get_pytorch_version():
40
    if "PYTORCH_VERSION" in os.environ:
moto's avatar
moto committed
41
        return f"torch=={os.environ['PYTORCH_VERSION']}"
42
    return "torch"
43

moto's avatar
moto committed
44
45
46
47
48
49
50

class clean(distutils.command.clean.clean):
    """``setup.py clean`` that also deletes built extensions and ``build/``."""

    def run(self):
        # Standard distutils cleanup comes first.
        distutils.command.clean.clean.run(self)

        # Delete every shared object found anywhere under src/.
        for so_file in (ROOT_DIR / "src").glob("**/*.so"):
            print(f"removing '{so_file}'")
            so_file.unlink()

        # Delete the top-level build directory tree, if present.
        for build_dir in (ROOT_DIR / "build",):
            if not build_dir.exists():
                continue
            print(f"removing '{build_dir}' (and everything under it)")
            shutil.rmtree(str(build_dir), ignore_errors=True)
peterjc123's avatar
peterjc123 committed
62

63

zhanggzh's avatar
zhanggzh committed
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
def _make_version_file(version):
    ROCM_PATH = os.getenv('ROCM_PATH')
    dtk_path = ROCM_PATH + '/.info/rocm_version'
    with open(dtk_path, 'r') as file:
        content = file.read().strip()
    dtk_version = content.replace('.', '')

    hcu_version = f"{version}+das.dtk{dtk_version}"

    version_path = ROOT_DIR / "packaging" / "torchaudio" /"version.py"
    with open(version_path, "w") as f:
        f.write(f"__version__ = '{version}'\n")
        f.write(f"__hcu_version__ = '{hcu_version}'\n")
    return hcu_version


moto's avatar
moto committed
80
def _parse_url(path):
81
    with open(path, "r") as file_:
moto's avatar
moto committed
82
        for line in file_:
83
            match = re.match(r"^\s*URL\s+(https:\/\/.+)$", line)
moto's avatar
moto committed
84
85
86
87
88
89
90
            if match:
                url = match.group(1)
                yield url


def _fetch_archives(src):
    for dest, url in src:
91
        if not dest.exists():
92
            print(f" --- Fetching {os.path.basename(dest)}")
93
94
95
            torch.hub.download_url_to_file(url, dest, progress=False)


zhanggzh's avatar
zhanggzh committed
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import re

def get_version(file_path: str = "version.txt", encoding: str = "utf-8") -> "str | None":
    """Read the version string from the first line of ``file_path``.

    The line must look like a semantic version, optionally prefixed with ``v``
    and optionally followed by a ``-suffix`` (e.g. ``v1.2.3`` or ``4.5.6-beta``).

    Args:
        file_path: File to read; a relative path is resolved against the
            current working directory. Defaults to ``version.txt``.
        encoding: Text encoding used to decode the file.

    Returns:
        The version string, or ``None`` if the file is missing, cannot be
        decoded or read, or its first line is not a valid version. Each
        failure prints an error message instead of raising.
    """
    # The return annotation is quoted on purpose: evaluating ``str | None`` at
    # definition time needs Python 3.10+, but 3.8/3.9 are still advertised.
    try:
        with open(file_path, "r", encoding=encoding) as file:
            # Strip a UTF-8 byte-order mark if an editor added one: str.strip()
            # does not remove U+FEFF, and it would make the regex below fail.
            line = file.readline().lstrip("\ufeff").strip()

            # Semantic-version pattern, e.g. v1.2.3 or 4.5.6-beta.
            if re.match(r"^v?(?:\d+\.){2}\d+(-\w+)?$", line):
                return line
            else:
                print(f"[错误] 无效的版本号格式: {line}")
                return None

    except FileNotFoundError:
        print(f"[错误] 文件不存在: {file_path}")
    except UnicodeDecodeError:
        print("[错误] 编码不匹配,请尝试 encoding='gbk'")
    except Exception as e:
        print(f"[错误] 读取失败: {str(e)}")

    return None

moto's avatar
moto committed
122
def _main():
    """Gather build metadata, write the version module, and run ``setup()``.

    Git branch/SHA/tag are only printed for diagnostics. The package version
    comes from the first line of ``version.txt`` in the repository root (via
    ``get_version``). ``_make_version_file`` then extends it into the HCU
    version that is passed to ``setup()``.

    Raises:
        RuntimeError: if ``version.txt`` does not yield a valid version.
    """
    sha = _run_cmd(["git", "rev-parse", "HEAD"])
    branch = _run_cmd(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    tag = _run_cmd(["git", "describe", "--tags", "--exact-match", "@"])
    print("-- Git branch:", branch)
    print("-- Git SHA:", sha)
    print("-- Git tag:", tag)
    pytorch_package_dep = _get_pytorch_version()
    print("-- PyTorch dependency:", pytorch_package_dep)
    # Resolve against the repo root (as the rest of this file does) rather
    # than the current working directory.
    version = get_version(str(ROOT_DIR / "version.txt"))
    print("-- Building version", version)
    # get_version() reports problems by printing and returning None. Stop here
    # rather than building a package versioned "None+das.dtk...".
    if version is None:
        raise RuntimeError("could not determine a valid package version from version.txt")

    dcu_version = _make_version_file(version)

    # Read relative to the repo root and as UTF-8, so the build does not
    # depend on the CWD or on the locale's default encoding (e.g. on Windows).
    with open(ROOT_DIR / "README.md", encoding="utf-8") as f:
        long_description = f.read()

    setup(
        name="torchaudio",
        version=dcu_version,
        description="An audio package for PyTorch",
        long_description=long_description,
        long_description_content_type="text/markdown",
        url="https://github.com/pytorch/audio",
        author=(
            "Soumith Chintala, David Pollack, Sean Naren, Peter Goldsborough, "
            "Moto Hira, Caroline Chen, Jeff Hwang, Zhaoheng Ni, Xiaohui Zhang"
        ),
        author_email="soumith@pytorch.org",
        maintainer="Moto Hira, Caroline Chen, Jeff Hwang, Zhaoheng Ni, Xiaohui Zhang",
        maintainer_email="moto@meta.com",
        classifiers=[
            "Environment :: Plugins",
            "Intended Audience :: Developers",
            "Intended Audience :: Science/Research",
            "License :: OSI Approved :: BSD License",
            "Operating System :: MacOS :: MacOS X",
            "Operating System :: Microsoft :: Windows",
            "Operating System :: POSIX",
            "Programming Language :: C++",
            "Programming Language :: Python :: 3.8",
            "Programming Language :: Python :: 3.9",
            "Programming Language :: Python :: 3.10",
            "Programming Language :: Python :: 3.11",
            "Programming Language :: Python :: Implementation :: CPython",
            "Topic :: Multimedia :: Sound/Audio",
            "Topic :: Scientific/Engineering :: Artificial Intelligence",
        ],
        packages=find_packages(where="src"),
        package_dir={"": "src"},
        ext_modules=setup_helpers.get_ext_modules(),
        cmdclass={
            "build_ext": setup_helpers.CMakeBuild,
            "clean": clean,
        },
        install_requires=[pytorch_package_dep],
        zip_safe=False,
    )


182
if __name__ == "__main__":
    # Only build when executed as a script (e.g. ``python setup.py ...``),
    # not when this module is imported.
    _main()