setup.py 2.32 KB
Newer Older
1
2
from pathlib import Path

lukec's avatar
lukec committed
3
from setuptools import find_packages, setup
4
5
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

6
7
8
9
10
11
12
13
14
15
root = Path(__file__).parent.resolve()


def get_version():
    """Return the package version declared in ``pyproject.toml``.

    The file at ``root / "pyproject.toml"`` is scanned line by line and the
    value of the first ``version = ...`` assignment is returned with the
    surrounding quotes removed.

    Returns:
        The version string, e.g. ``"0.0.2"``.

    Raises:
        RuntimeError: if the file contains no ``version`` assignment.
    """
    pyproject = root / "pyproject.toml"
    with open(pyproject, encoding="utf-8") as f:
        for line in f:
            key, sep, value = line.partition("=")
            # Require an exact key match so that keys which merely start with
            # "version" (e.g. "version_file") are not mistaken for it.
            if sep and key.strip() == "version":
                return value.strip().strip("\"'")
    # Failing loudly beats silently building a wheel with no version.
    raise RuntimeError(f"no 'version' entry found in {pyproject}")


Yineng Zhang's avatar
Yineng Zhang committed
16
17
def update_wheel_platform_tag():
    """Retag wheels in ``dist/`` from ``linux_x86_64`` to ``manylinux2014_x86_64``.

    Only the file name is rewritten; the wheel contents are left untouched.
    The function is a no-op when ``dist`` does not exist, is not a directory,
    or contains no wheel whose name carries the ``linux_x86_64`` tag.
    """
    wheel_dir = Path("dist")
    if not wheel_dir.is_dir():
        return

    # Materialize the listing before renaming so the renames cannot disturb
    # the directory scan.
    for old_wheel in sorted(wheel_dir.glob("*.whl")):
        new_name = old_wheel.name.replace("linux_x86_64", "manylinux2014_x86_64")
        if new_name != old_wheel.name:
            old_wheel.rename(wheel_dir / new_name)
Yineng Zhang's avatar
Yineng Zhang committed
24
25


26
27
28
29
# Header search paths handed to the CUDA extension: the CUTLASS tree under
# 3rdparty/ plus this package's own kernel sources.
cutlass = root / "3rdparty" / "cutlass"
include_dirs = [
    cutlass.resolve() / "include",
    cutlass.resolve() / "tools" / "util" / "include",
    root / "src" / "sgl-kernel" / "csrc",
]
Ke Bao's avatar
Ke Bao committed
32
# Flags passed to nvcc for every .cu source.
nvcc_flags = [
    "-DNDEBUG",
    "-O3",
    # "-Xcompiler" forwards the following option (-fPIC) to the host compiler.
    "-Xcompiler",
    "-fPIC",
    # One -gencode entry per GPU architecture the binary is built for.
    "-gencode=arch=compute_75,code=sm_75",
    "-gencode=arch=compute_80,code=sm_80",
    "-gencode=arch=compute_89,code=sm_89",
    "-gencode=arch=compute_90,code=sm_90",
    "-gencode=arch=compute_90a,code=sm_90a",
    # Undefine the macros that disable the half/half2 operator overloads.
    # NOTE(review): presumably these are set by the PyTorch extension
    # defaults -- confirm.
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF2_OPERATORS__",
]
# Flags passed to the host C++ compiler.
cxx_flags = ["-O3"]
# Libraries the extension is linked against.
libraries = ["c10", "torch", "torch_python", "cuda"]
# The rpath is relative to the extension's own location ($ORIGIN) and points
# at a sibling torch/lib directory; -L adds a system library search path.
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
Ke Bao's avatar
Ke Bao committed
48
49
50
51
52
53
54
# A single CUDA extension, importable as ``sgl_kernel.ops._kernels``, built
# from the kernel sources below with the include paths, compiler flags and
# link settings defined above.
ext_modules = [
    CUDAExtension(
        name="sgl_kernel.ops._kernels",
        sources=[
            "src/sgl-kernel/csrc/trt_reduce_internal.cu",
            "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
            "src/sgl-kernel/csrc/moe_align_kernel.cu",
            "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
            "src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
            "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
            "src/sgl-kernel/csrc/rotary_embedding.cu",
        ],
        include_dirs=include_dirs,
        extra_compile_args={
            "nvcc": nvcc_flags,
            "cxx": cxx_flags,
        },
        libraries=libraries,
        extra_link_args=extra_link_args,
    ),
]

70
71
# Package definition. The version is read from pyproject.toml and the CUDA
# extension is compiled through PyTorch's BuildExtension command.
# NOTE(review): find_packages() scans the current directory while package_dir
# maps the root package to "src" -- confirm the intended packages are found.
setup(
    name="sgl-kernel",
    version=get_version(),
    packages=find_packages(),
    package_dir={"": "src"},
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
    install_requires=["torch"],
)

# Runs after every setup command: retag any wheel left in dist/.
update_wheel_platform_tag()