"test/vscode:/vscode.git/clone" did not exist on "5c8541b77dd399b276c919c2a8969e5b973191f0"
setup.py 3.64 KB
Newer Older
1
2
from pathlib import Path

3
import torch
lukec's avatar
lukec committed
4
from setuptools import find_packages, setup
5
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
6
from version import __version__
7

8
9
10
# Absolute path of the directory containing this setup.py; the 3rdparty/ and
# src/ paths used below are built relative to it.
root = Path(__file__).parent.resolve()


Yineng Zhang's avatar
Yineng Zhang committed
11
12
def update_wheel_platform_tag(wheel_dir=Path("dist")):
    """Retag built ``linux_x86_64`` wheels as ``manylinux2014_x86_64``.

    Every ``*-linux_x86_64.whl`` file in *wheel_dir* is renamed in place so
    that its trailing platform tag reads ``manylinux2014_x86_64``. Wheels that
    already carry a different tag are left untouched. A missing directory, or
    one with no matching wheels, is a no-op.

    Args:
        wheel_dir: Directory holding the built wheels. Defaults to ``dist``,
            resolved against the current working directory at call time.
    """
    wheel_dir = Path(wheel_dir)
    if not wheel_dir.is_dir():
        return

    old_tag = "-linux_x86_64.whl"
    new_tag = "-manylinux2014_x86_64.whl"
    # sorted() snapshots the listing before any rename touches the directory.
    for wheel in sorted(wheel_dir.glob("*" + old_tag)):
        # Rewrite only the trailing platform tag, never the name/version part.
        retagged = wheel.with_name(wheel.name[: -len(old_tag)] + new_tag)
        # replace() overwrites a stale retagged wheel left by an earlier build.
        wheel.replace(retagged)


21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def get_cuda_version():
    """Return the CUDA version torch reports, as a tuple of ints.

    ``torch.version.cuda`` (e.g. ``"12.1"``) is split on dots. ``(0, 0)`` is
    returned when torch reports no CUDA version (e.g. a CPU-only build).
    """
    reported = torch.version.cuda
    if not reported:
        return (0, 0)
    return tuple(int(component) for component in reported.split("."))


def get_device_sm():
    """Return the current CUDA device's compute capability as one integer.

    The value is ``major * 10 + minor`` (e.g. 90 for capability 9.0). It is
    0 when ``torch.cuda.is_available()`` is false.
    """
    if not torch.cuda.is_available():
        return 0
    major, minor = torch.cuda.get_device_capability()
    return 10 * major + minor


# Probed once when this script runs; the nvcc flag selection below keys off
# both values.
cuda_version, sm_version = get_cuda_version(), get_device_sm()

37
cutlass = root / "3rdparty" / "cutlass"
flashinfer = root / "3rdparty" / "flashinfer"

# Header search paths: the CUTLASS and FlashInfer trees vendored under
# 3rdparty/, plus this package's own csrc directory. setuptools documents
# Extension.include_dirs as a list of strings, so each Path is converted.
include_dirs = [
    str(path)
    for path in (
        cutlass.resolve() / "include",
        cutlass.resolve() / "tools" / "util" / "include",
        root / "src" / "sgl-kernel" / "csrc",
        flashinfer.resolve() / "include",
        flashinfer.resolve() / "include" / "gemm",
        flashinfer.resolve() / "csrc",
    )
]
# Flags passed to nvcc for every .cu source. sm_75, sm_80, sm_89 and sm_90 are
# always targeted; the conditional blocks below add options that depend on the
# CUDA version torch was built with and on the GPU of the build machine.
nvcc_flags = [
    "-DNDEBUG",
    "-O3",
    "-Xcompiler",
    "-fPIC",
    "-gencode=arch=compute_75,code=sm_75",
    "-gencode=arch=compute_80,code=sm_80",
    "-gencode=arch=compute_89,code=sm_89",
    "-gencode=arch=compute_90,code=sm_90",
    "-std=c++17",
    "-use_fast_math",
    "-DFLASHINFER_ENABLE_F16",
]

# The sm_90a (Hopper arch-specific) target requires CUDA 12+. It is emitted
# only when the build machine itself reports an sm_90+ GPU.
# NOTE(review): sm_version comes from the local GPU (0 when none is visible),
# so a GPU-less build host silently drops sm_90a, FP8 and BF16 support --
# confirm this is intended for CI/wheel builds.
if cuda_version >= (12, 0) and sm_version >= 90:
    nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")

if sm_version >= 90:
    nvcc_flags.extend(
        [
            "-DFLASHINFER_ENABLE_FP8",
            "-DFLASHINFER_ENABLE_FP8_E4M3",
            "-DFLASHINFER_ENABLE_FP8_E5M2",
        ]
    )
if sm_version >= 80:
    nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")

75
76
77
78
79
80
81
82
83
84
# torch's default nvcc flags define macros that disable the implicit
# half/bfloat16 operators and conversions; strip them (presumably because the
# kernels rely on those operators). This mutates torch's module-level list, so
# it affects every extension built in this process.
for flag in (
    "-D__CUDA_NO_HALF_OPERATORS__",
    "-D__CUDA_NO_HALF_CONVERSIONS__",
    "-D__CUDA_NO_BFLOAT16_CONVERSIONS__",
    "-D__CUDA_NO_HALF2_OPERATORS__",
):
    # A torch version may not ship every flag; absent ones are skipped.
    if flag in torch.utils.cpp_extension.COMMON_NVCC_FLAGS:
        torch.utils.cpp_extension.COMMON_NVCC_FLAGS.remove(flag)
Ke Bao's avatar
Ke Bao committed
85
# Host-compiler flags and link settings for the extension.
cxx_flags = ["-O3"]
libraries = ["c10", "torch", "torch_python", "cuda"]
extra_link_args = [
    # Relative rpath so the built library can locate torch's shared libraries
    # at runtime. This assumes it is installed as sgl_kernel/ops/_kernels*.so
    # alongside the torch package -- confirm if the install layout changes.
    "-Wl,-rpath,$ORIGIN/../../torch/lib",
    "-L/usr/lib/x86_64-linux-gnu",
]
# A single CUDA extension, sgl_kernel.ops._kernels. It is built from this
# package's csrc/ kernels plus selected FlashInfer sources. Source paths are
# relative, so setup.py is expected to run from its own directory.
ext_modules = [
    CUDAExtension(
        name="sgl_kernel.ops._kernels",
        sources=[
            "src/sgl-kernel/csrc/trt_reduce_internal.cu",
            "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
            "src/sgl-kernel/csrc/moe_align_kernel.cu",
            "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
            "src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
            "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
            "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
            "src/sgl-kernel/csrc/rotary_embedding.cu",
            "3rdparty/flashinfer/csrc/activation.cu",
            "3rdparty/flashinfer/csrc/bmm_fp8.cu",
            "3rdparty/flashinfer/csrc/group_gemm.cu",
            "3rdparty/flashinfer/csrc/group_gemm_sm90.cu",
            "3rdparty/flashinfer/csrc/norm.cu",
            "3rdparty/flashinfer/csrc/sampling.cu",
        ],
        include_dirs=include_dirs,
        extra_compile_args={
            "nvcc": nvcc_flags,
            "cxx": cxx_flags,
        },
        libraries=libraries,
        extra_link_args=extra_link_args,
    ),
]

117
118
setup(
    name="sgl-kernel",
    version=__version__,
    # NOTE(review): find_packages() scans the current directory, while
    # package_dir maps the root package namespace to "src". Confirm the shipped
    # packages are actually resolved as intended (e.g. via pyproject.toml).
    packages=find_packages(),
    package_dir={"": "src"},
    ext_modules=ext_modules,
    # torch's BuildExtension drives nvcc for the .cu sources.
    cmdclass={"build_ext": BuildExtension},
)

# Post-build step: rename the built wheel(s) in dist/ to a manylinux
# platform tag.
update_wheel_platform_tag()