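# setup.py for the sgl-kernel package: builds the warp-reduce CUDA extension
# and, when CUDA_VERSION is set, re-tags the built wheel in dist/ with "+cu<ver>".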
import base64
import csv
import hashlib
import os
import shutil
import tempfile
import zipfile
from pathlib import Path

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Absolute, symlink-resolved directory containing this setup.py; get_version()
# reads pyproject.toml relative to it, independent of the current directory.
root = Path(__file__).parent
root = root.resolve()


def get_version(pyproject=None):
    """Return the ``version`` string declared in ``pyproject.toml``.

    This is a small line scanner, not a TOML parser. It returns the value
    from the first line whose key, starting at column 0, is exactly
    ``version`` followed by ``=``. Single- or double-quoted values are
    accepted, and a trailing ``# comment`` is ignored.

    Args:
        pyproject: path of the file to read; defaults to ``pyproject.toml``
            in the directory of this setup.py.

    Returns:
        The version string, e.g. ``"0.0.2"``.

    Raises:
        RuntimeError: if no ``version = ...`` line is found (previously this
            silently returned ``None``).
    """
    if pyproject is None:
        pyproject = root / "pyproject.toml"
    with open(pyproject, encoding="utf-8") as f:
        for line in f:
            key, sep, value = line.partition("=")
            # rstrip only: as with a startswith() check, the key must begin
            # the line, but keys such as "version_file" must not match.
            if not sep or key.rstrip() != "version":
                continue
            value = value.strip()
            quote = value[:1]
            if quote in ('"', "'"):
                # Take text up to the closing quote; drops any "# comment".
                end = value.find(quote, 1)
                return value[1:end] if end != -1 else value[1:]
            return value.split("#", 1)[0].strip()
    raise RuntimeError(f"no 'version' entry found in {pyproject}")


def _rewrite_header(path, field, transform):
    """Rewrite each ``field:`` line in the header section of a dist-info file.

    METADATA and WHEEL use an email-style layout: header lines, then an
    optional blank line and a free-form body (e.g. the long description).
    Only the header section is edited, so body text that happens to start
    with ``field:`` is left untouched.

    Args:
        path: file to edit in place (UTF-8).
        field: header name without the colon, e.g. ``"Version"``.
        transform: callable mapping the old header value to the new one.
    """
    text = path.read_bytes().decode("utf-8")
    head, sep, body = text.partition("\n\n")
    prefix = f"{field}:"
    lines = [
        f"{prefix} {transform(line[len(prefix):].strip())}"
        if line.startswith(prefix)
        else line
        for line in head.split("\n")
    ]
    path.write_bytes(("\n".join(lines) + sep + body).encode("utf-8"))


def _write_record(tree, info_name):
    """Regenerate ``<info_name>/RECORD`` for the unpacked wheel rooted at *tree*.

    Every regular file gets a ``path,sha256=<urlsafe-b64 digest>,size`` row
    (the wheel RECORD format, digest without ``=`` padding). RECORD lists
    itself last with empty hash and size.
    """
    record = tree / info_name / "RECORD"
    rows = []
    for path in sorted(p for p in tree.rglob("*") if p.is_file() and p != record):
        data = path.read_bytes()
        digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest())
        rows.append(
            (
                path.relative_to(tree).as_posix(),
                "sha256=" + digest.rstrip(b"=").decode("ascii"),
                len(data),
            )
        )
    rows.append((record.relative_to(tree).as_posix(), "", ""))
    with open(record, "w", encoding="utf-8", newline="") as f:
        csv.writer(f, lineterminator="\n").writerows(rows)


def _retag_wheel(old_wheel, base_version, new_version, platform_tag):
    """Write a copy of *old_wheel* retagged as *new_version* / *platform_tag*.

    Returns the path of the new wheel (next to *old_wheel*). *old_wheel*
    itself is not modified or removed.
    """
    old_info = f"sgl_kernel-{base_version}.dist-info"
    new_info = f"sgl_kernel-{new_version}.dist-info"
    # Change only the version and the trailing platform components of
    # "name-version-pytag-abitag-platform.whl".
    new_name = old_wheel.name.replace(f"-{base_version}-", f"-{new_version}-", 1)
    new_name = new_name.replace("-linux_x86_64.whl", f"-{platform_tag}.whl")
    new_wheel = old_wheel.with_name(new_name)

    # A fresh temporary directory per wheel: no stale files from an earlier
    # run can leak into the archive, and it is removed even on failure.
    with tempfile.TemporaryDirectory() as tmp:
        tree = Path(tmp)
        with zipfile.ZipFile(old_wheel) as src:
            src.extractall(tree)
        info_dir = tree / new_info
        (tree / old_info).rename(info_dir)

        # Keep metadata consistent with the new filename; installers such as
        # pip check that the filename version matches METADATA's Version.
        _rewrite_header(info_dir / "METADATA", "Version", lambda _old: new_version)
        _rewrite_header(
            info_dir / "WHEEL",
            "Tag",
            lambda tag: tag.replace("-linux_x86_64", f"-{platform_tag}"),
        )
        # Paths and file hashes changed above, so RECORD must be rebuilt.
        _write_record(tree, new_info)

        files = [p for p in tree.rglob("*") if p.is_file()]
        # The wheel spec encourages placing .dist-info files at the end.
        files.sort(key=lambda p: (p.parent == info_dir, p.relative_to(tree).as_posix()))
        with zipfile.ZipFile(new_wheel, "w", zipfile.ZIP_DEFLATED) as dst:
            for path in files:
                dst.write(path, path.relative_to(tree).as_posix())
    return new_wheel


def rename_wheel(wheel_dir="dist", base_version=None, platform_tag="manylinux2014_x86_64"):
    """Tag built sgl_kernel wheels with the CUDA version from ``CUDA_VERSION``.

    Example: with ``CUDA_VERSION=12.1``, the wheel
    ``sgl_kernel-0.0.2-cp39-abi3-linux_x86_64.whl`` is replaced by
    ``sgl_kernel-0.0.2+cu121-cp39-abi3-manylinux2014_x86_64.whl``. The wheel's
    dist-info directory, METADATA ``Version``, WHEEL ``Tag`` lines and RECORD
    are updated to match the new name.

    The function does nothing when ``CUDA_VERSION`` is unset or empty, or
    when *wheel_dir* holds no untagged wheel for *base_version*. This matters
    because the module calls it unconditionally after ``setup()``, so it runs
    for non-wheel commands too, and repeated runs must be harmless.

    Args:
        wheel_dir: directory holding the built wheels.
        base_version: version without a local tag; defaults to get_version().
        platform_tag: platform tag that replaces ``linux_x86_64``.
    """
    cuda_version = os.environ.get("CUDA_VERSION")
    if not cuda_version:
        return
    if base_version is None:
        base_version = get_version()
    new_version = f"{base_version}+cu{cuda_version.replace('.', '')}"

    # Only untagged wheels of this version match. An already-tagged wheel has
    # "+cu..." right after the version, so the "-" in the pattern excludes it.
    # sorted() materialises the listing before new files are added.
    for old_wheel in sorted(Path(wheel_dir).glob(f"sgl_kernel-{base_version}-*.whl")):
        _retag_wheel(old_wheel, base_version, new_version, platform_tag)
        old_wheel.unlink()


# GPU architectures (compute capability without the dot) that the extension's
# device code is compiled for; each becomes one -gencode flag below.
_CUDA_ARCHS = ("75", "80", "89", "90")
_NVCC_FLAGS = [
    "-O3",
    "-Xcompiler",
    "-fPIC",
    *(f"-gencode=arch=compute_{arch},code=sm_{arch}" for arch in _CUDA_ARCHS),
]

# NOTE(review): with package_dir={"": "src"}, setuptools looks for the
# sgl_kernel package in src/sgl_kernel, while the C++/CUDA sources live in
# src/sgl-kernel -- confirm both directories exist as intended.
setup(
    name="sgl-kernel",
    version=get_version(),
    packages=["sgl_kernel"],
    package_dir={"": "src"},
    ext_modules=[
        CUDAExtension(
            "sgl_kernel.ops.warp_reduce_cuda",
            [
                "src/sgl-kernel/csrc/warp_reduce.cc",
                "src/sgl-kernel/csrc/warp_reduce_kernel.cu",
            ],
            extra_compile_args={
                "nvcc": _NVCC_FLAGS,
                "cxx": ["-O3"],
            },
        )
    ],
    cmdclass={"build_ext": BuildExtension},
    install_requires=["torch"],
)

# Runs after every setup() invocation; it is a no-op unless CUDA_VERSION is set.
rename_wheel()