setup.py 2.92 KB
Newer Older
1
2
3
4
5
6
import os
import shutil
import zipfile
from pathlib import Path

from setuptools import setup
7
8
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Absolute path of the directory that contains this setup script; used to
# locate pyproject.toml regardless of the current working directory.
root = Path(__file__).parent.resolve()


def get_version():
    """Return the version string declared in ``pyproject.toml``.

    The ``pyproject.toml`` next to this script is scanned line by line for
    the first ``version = <value>`` assignment. Only a key that is exactly
    ``version`` matches (so ``version_file = ...`` is ignored), a trailing
    ``# comment`` is dropped, and surrounding single or double quotes are
    removed.

    Returns:
        The version value as a plain string, e.g. ``"0.0.2"``.

    Raises:
        RuntimeError: if the file contains no ``version`` assignment.
    """
    pyproject = root / "pyproject.toml"
    with open(pyproject, encoding="utf-8") as f:
        for line in f:
            key, sep, value = line.partition("=")
            if not sep or key.strip() != "version":
                continue
            # Discard an inline comment before removing the quotes so that
            # `version = "1.2.3"  # note` still yields `1.2.3`.
            value = value.split("#", 1)[0].strip()
            return value.strip("\"'")
    # Falling through would hand `None` to callers, which then format it
    # into file names as the literal text "None"; fail loudly instead.
    raise RuntimeError(f"no 'version' entry found in {pyproject}")


def rename_wheel():
    """Re-pack the wheel in ``dist/`` under a CUDA-specific version.

    Does nothing unless the ``CUDA_VERSION`` environment variable is set and
    ``dist/`` contains a wheel that has not been re-tagged yet. Otherwise
    the wheel is unpacked into ``dist/tmp``, its
    ``sgl_kernel-<version>.dist-info`` directory is renamed to
    ``sgl_kernel-<version>.post0+cu<cuda>.dist-info`` (``<cuda>`` being
    ``CUDA_VERSION`` with the dots removed), and the tree is zipped again
    under a file name in which ``linux_x86_64`` is replaced by
    ``manylinux2014_x86_64`` and the version by the tagged version. The
    original wheel is removed afterwards.

    NOTE(review): only the dist-info *directory name* is changed. The
    METADATA and RECORD files inside it are copied verbatim, so they still
    carry the old version / paths - confirm that the consuming installer
    accepts this.
    """
    cuda_env = os.environ.get("CUDA_VERSION")
    if not cuda_env:
        return
    cuda_version = cuda_env.replace(".", "")
    base_version = get_version()
    tagged_version = f"{base_version}.post0+cu{cuda_version}"

    wheel_dir = Path("dist")
    if not wheel_dir.is_dir():
        return
    # Skip wheels that already carry the tagged version so that re-running
    # the step does not try to re-tag its own output.
    old_wheel = next(
        (w for w in sorted(wheel_dir.glob("*.whl")) if tagged_version not in w.name),
        None,
    )
    if old_wheel is None:
        return

    platform = "manylinux2014_x86_64"
    new_name = old_wheel.name.replace("linux_x86_64", platform)
    new_name = new_name.replace(base_version, tagged_version)
    new_wheel = wheel_dir / new_name

    tmp_dir = wheel_dir / "tmp"
    # Start from an empty scratch directory: files left behind by an aborted
    # earlier run would otherwise be packed into the new wheel.
    if tmp_dir.exists():
        shutil.rmtree(tmp_dir)
    tmp_dir.mkdir()
    try:
        with zipfile.ZipFile(old_wheel, "r") as zip_ref:
            zip_ref.extractall(tmp_dir)

        old_info = tmp_dir / f"sgl_kernel-{base_version}.dist-info"
        new_info = tmp_dir / f"sgl_kernel-{tagged_version}.dist-info"
        old_info.rename(new_info)

        with zipfile.ZipFile(new_wheel, "w", zipfile.ZIP_DEFLATED) as new_zip:
            for file_path in sorted(tmp_dir.rglob("*")):
                if file_path.is_file():
                    new_zip.write(file_path, file_path.relative_to(tmp_dir))
    finally:
        # Always remove the scratch directory, even when re-packing failed.
        shutil.rmtree(tmp_dir, ignore_errors=True)

    # If no substitution applied, the new archive was written over the old
    # path; deleting "the old wheel" would then delete the result.
    if new_wheel != old_wheel:
        old_wheel.unlink()


Yineng Zhang's avatar
Yineng Zhang committed
52
53
54
55
56
57
58
59
60
def update_wheel_platform_tag():
    """Re-tag wheels in ``dist/`` from ``linux_x86_64`` to ``manylinux2014_x86_64``.

    Every ``*.whl`` file directly inside ``dist/`` whose name contains
    ``linux_x86_64`` is renamed with that substring replaced by
    ``manylinux2014_x86_64``. Wheels without that substring (including ones
    that were already re-tagged) are left untouched, so the function can be
    called repeatedly.

    This runs after every ``setup()`` invocation, including commands that
    build no wheel, so a missing ``dist/`` directory or an empty one is
    treated as "nothing to do" rather than an error.
    """
    wheel_dir = Path("dist")
    if not wheel_dir.is_dir():
        return
    # Materialise the listing first so the renames below cannot disturb the
    # directory iteration.
    for old_wheel in sorted(wheel_dir.glob("*.whl")):
        new_name = old_wheel.name.replace("linux_x86_64", "manylinux2014_x86_64")
        if new_name != old_wheel.name:
            old_wheel.rename(wheel_dir / new_name)


Ke Bao's avatar
Ke Bao committed
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# Flags handed to nvcc for the .cu sources (wired in through
# `extra_compile_args["nvcc"]` below).
nvcc_flags = [
    "-O3",
    # "-Xcompiler" forwards the following option ("-fPIC") to the host compiler.
    "-Xcompiler",
    "-fPIC",
    # Generate device code for each listed architecture: sm_75, sm_80,
    # sm_89 and sm_90.
    "-gencode=arch=compute_75,code=sm_75",
    "-gencode=arch=compute_80,code=sm_80",
    "-gencode=arch=compute_89,code=sm_89",
    "-gencode=arch=compute_90,code=sm_90",
    # Undefine the __CUDA_NO_HALF*_OPERATORS__ macros - presumably so the
    # CUDA half-precision operator overloads stay available; confirm
    # against the defaults that torch's build helpers inject.
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF2_OPERATORS__",
]
# Flags for the host C++ compiler (`extra_compile_args["cxx"]`).
cxx_flags = ["-O3"]
# Libraries the extension is linked against.
libraries = ["c10", "torch", "torch_python"]
# Embed a run-time library search path expressed relative to the location
# of the built shared object ($ORIGIN): two directories up, then torch/lib.
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"]
# A single CUDA extension, importable as `sgl_kernel.ops._kernels`, built
# from the listed .cu sources.
ext_modules = [
    CUDAExtension(
        name="sgl_kernel.ops._kernels",
        sources=[
            "src/sgl-kernel/csrc/warp_reduce_kernel.cu",
            "src/sgl-kernel/csrc/trt_reduce_internal.cu",
            "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
            "src/sgl-kernel/csrc/moe_align_kernel.cu",
            "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
        ],
        extra_compile_args={
            "nvcc": nvcc_flags,
            "cxx": cxx_flags,
        },
        libraries=libraries,
        extra_link_args=extra_link_args,
    ),
]

94
95
# Package definition: the version is read from pyproject.toml, the Python
# package `sgl_kernel` is looked up under `src/`, and the CUDA extension(s)
# in `ext_modules` are compiled through torch's `BuildExtension` command.
setup(
    name="sgl-kernel",
    version=get_version(),
    packages=["sgl_kernel"],
    package_dir={"": "src"},
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
    install_requires=["torch"],
)

# Post-build step: runs after every setup() invocation and re-tags the
# wheel found in dist/ with the manylinux2014 platform tag.
update_wheel_platform_tag()