import multiprocessing
import os
from pathlib import Path

import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Absolute path to the repository root (the directory containing this setup.py).
root = Path(__file__).parent.resolve()


12
def _update_wheel_platform_tag():
Yineng Zhang's avatar
Yineng Zhang committed
13
    wheel_dir = Path("dist")
14
15
16
17
18
19
    if wheel_dir.exists() and wheel_dir.is_dir():
        old_wheel = next(wheel_dir.glob("*.whl"))
        new_wheel = wheel_dir / old_wheel.name.replace(
            "linux_x86_64", "manylinux2014_x86_64"
        )
        old_wheel.rename(new_wheel)
Yineng Zhang's avatar
Yineng Zhang committed
20
21


22
def _get_cuda_version():
23
24
25
26
27
    if torch.version.cuda:
        return tuple(map(int, torch.version.cuda.split(".")))
    return (0, 0)


28
def _get_device_sm():
29
30
31
32
33
34
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        return major * 10 + minor
    return 0


35
36
37
38
39
40
def _get_version():
    with open(root / "pyproject.toml") as f:
        for line in f:
            if line.startswith("version"):
                return line.split("=")[1].strip().strip('"')

41

42
# Namespace under which the compiled operators are registered.
operator_namespace = "sgl_kernels"

# Vendored third-party source trees; CUSTOM_CUTLASS_SRC_DIR lets a build
# point at an external CUTLASS checkout instead of the submodule.
cutlass_default = root / "3rdparty" / "cutlass"
cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
flashinfer = root / "3rdparty" / "flashinfer"
turbomind = root / "3rdparty" / "turbomind"

# Resolve each tree once instead of per entry below.
_cutlass = cutlass.resolve()
_flashinfer = flashinfer.resolve()
_turbomind = turbomind.resolve()

# Header search paths handed to the extension build.
include_dirs = [
    _cutlass / "include",
    _cutlass / "tools" / "util" / "include",
    root / "src" / "sgl-kernel" / "include",
    root / "src" / "sgl-kernel" / "csrc",
    _flashinfer / "include",
    _flashinfer / "include" / "gemm",
    _flashinfer / "csrc",
    # NOTE(review): the two bare names below look like library names rather
    # than include directories — kept as-is to preserve behavior; verify
    # they are intentional.
    "cublas",
    "cublasLt",
    _turbomind,
    _turbomind / "src",
]


# Base nvcc flags: release defines, optimization, PIC host code, C++17,
# and one SASS target per supported compute capability (SM75..SM90).
nvcc_flags = [
    "-DNDEBUG",
    f"-DOPERATOR_NAMESPACE={operator_namespace}",
    "-O3",
    "-Xcompiler",
    "-fPIC",
]
# Generate the per-architecture gencode entries from the SM list.
nvcc_flags += [
    f"-gencode=arch=compute_{sm},code=sm_{sm}" for sm in (75, 80, 89, 90)
]
nvcc_flags += [
    "-std=c++17",
    "-use_fast_math",
    "-DFLASHINFER_ENABLE_F16",
    # Silence warnings from vendored third-party sources.
    "-Xcompiler",
    "-w",
]


# Extra defines that turn on the FP8 (e4m3 / e5m2) kernel paths.
nvcc_flags_fp8 = [
    "-DFLASHINFER_ENABLE_FP8" + suffix for suffix in ("", "_E4M3", "_E5M2")
]


# Translation units compiled into the extension: sgl-kernel's own
# CUDA/C++ sources plus the vendored flashinfer kernels it re-exports.
_sgl_sources = [
    "src/sgl-kernel/torch_extension.cc",
    "src/sgl-kernel/csrc/trt_reduce_internal.cu",
    "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
    "src/sgl-kernel/csrc/moe_align_kernel.cu",
    "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
    "src/sgl-kernel/csrc/fp8_gemm_kernel.cu",
    "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
    "src/sgl-kernel/csrc/rotary_embedding.cu",
]
_flashinfer_sources = [
    "3rdparty/flashinfer/csrc/activation.cu",
    "3rdparty/flashinfer/csrc/bmm_fp8.cu",
    "3rdparty/flashinfer/csrc/norm.cu",
    "3rdparty/flashinfer/csrc/sampling.cu",
    "3rdparty/flashinfer/csrc/renorm.cu",
    "3rdparty/flashinfer/csrc/rope.cu",
]
sources = _sgl_sources + _flashinfer_sources


# Optional-feature toggles.  When a GPU is visible, capabilities are
# detected from the device; otherwise (e.g. CPU-only CI cross-building)
# they come from environment variables.
enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1"
enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1"
cuda_version = _get_cuda_version()
sm_version = _get_device_sm()

if torch.cuda.is_available():
    # sm_90a codegen additionally requires CUDA >= 12.
    use_sm90a = cuda_version >= (12, 0) and sm_version >= 90
    use_fp8 = sm_version >= 90
    use_bf16 = sm_version >= 80
else:
    # Compilation environment without a GPU: trust the env toggles.
    use_sm90a = enable_sm90a
    use_fp8 = enable_fp8
    use_bf16 = enable_bf16

if use_sm90a:
    nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
if use_fp8:
    nvcc_flags.extend(nvcc_flags_fp8)
if use_bf16:
    nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")


# torch injects these "-D__CUDA_NO_*" defines by default, which disable
# the half/bfloat16 operator overloads the kernels rely on; strip any
# that are present from torch's common nvcc flags.
_unwanted_flags = (
    "-D__CUDA_NO_HALF_OPERATORS__",
    "-D__CUDA_NO_HALF_CONVERSIONS__",
    "-D__CUDA_NO_BFLOAT16_CONVERSIONS__",
    "-D__CUDA_NO_HALF2_OPERATORS__",
)
for _flag in _unwanted_flags:
    if _flag in torch.utils.cpp_extension.COMMON_NVCC_FLAGS:
        torch.utils.cpp_extension.COMMON_NVCC_FLAGS.remove(_flag)


# Host-compiler flags, libraries to link against, and linker args.  The
# $ORIGIN rpath lets the installed extension locate torch's shared
# libraries relative to its own location.
cxx_flags = ["-O3"]
libraries = "c10 torch torch_python cuda cublas cublasLt".split()
extra_link_args = [
    "-Wl,-rpath,$ORIGIN/../../torch/lib",
    "-L/usr/lib/x86_64-linux-gnu",
]


# Single CUDA extension bundling all sgl-kernel operators.  py_limited_api
# pairs with the bdist_wheel option below to produce an abi3 wheel.
_extension_kwargs = {
    "sources": sources,
    "include_dirs": include_dirs,
    "extra_compile_args": {"nvcc": nvcc_flags, "cxx": cxx_flags},
    "libraries": libraries,
    "extra_link_args": extra_link_args,
    "py_limited_api": True,
}
ext_modules = [CUDAExtension("sgl_kernel.ops._kernels", **_extension_kwargs)]


# Run the build; afterwards retag the wheel so PyPI accepts it.
_setup_kwargs = {
    "name": "sgl-kernel",
    "version": _get_version(),
    "packages": find_packages(),
    "package_dir": {"": "src"},
    "ext_modules": ext_modules,
    # Build with ninja and use every available core.
    "cmdclass": {
        "build_ext": BuildExtension.with_options(
            use_ninja=True, max_jobs=multiprocessing.cpu_count()
        )
    },
    # Emit an abi3 (limited-API) wheel targeting CPython 3.9+.
    "options": {"bdist_wheel": {"py_limited_api": "cp39"}},
}
setup(**_setup_kwargs)

_update_wheel_platform_tag()