# Copyright 2025 SGLang Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os
import shutil
import sys

from pathlib import Path

# Set up flash_mla at the top level so tests can find it.
# This makes the module importable without installing the package.
root_dir = Path(__file__).parent.resolve()
module_src = root_dir / "3rdparty" / "flashmla" / "flash_mla"
module_dest = root_dir / "flash_mla"

if module_src.exists() and not module_dest.exists():
    try:
        os.symlink(module_src, module_dest, target_is_directory=True)
        print(f"Created symbolic link from {module_src} to {module_dest}")
    except (OSError, NotImplementedError):
        if module_src.exists():
            shutil.copytree(module_src, module_dest)
            print(f"Copied directory from {module_src} to {module_dest}")

import torch
from setuptools import find_packages, setup
from setuptools.command.build_py import build_py
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

root = Path(__file__).parent.resolve()


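# Force the manylinux2014_x86_64 platform tag when building wheels, unless a
# platform name is given explicitly on the command line.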
if "bdist_wheel" in sys.argv and "--plat-name" not in sys.argv:
    sys.argv.extend(["--plat-name", "manylinux2014_x86_64"])


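# CUDA toolkit version that the installed PyTorch build was compiled against.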
def _get_cuda_version():
    if torch.version.cuda:
        return tuple(map(int, torch.version.cuda.split(".")))
    return (0, 0)


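# Compute capability of the current CUDA device (e.g. 90 for SM90),
# or 0 when no GPU is visible.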
def _get_device_sm():
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        return major * 10 + minor
    return 0


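# Read the package version from the "version" field in pyproject.toml.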
def _get_version():
    with open(root / "pyproject.toml") as f:
        for line in f:
            if line.startswith("version"):
                return line.split("=")[1].strip().strip('"')


operator_namespace = "sgl_kernel"
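# Bundled third-party sources; the CUTLASS checkout can be overridden with the
# CUSTOM_CUTLASS_SRC_DIR environment variable.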
cutlass_default = root / "3rdparty" / "cutlass"
cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
flashinfer = root / "3rdparty" / "flashinfer"
deepgemm = root / "3rdparty" / "deepgemm"
flashmla = root / "3rdparty" / "flashmla"
include_dirs = [
    root / "include",
    root / "csrc",
    cutlass.resolve() / "include",
    cutlass.resolve() / "tools" / "util" / "include",
    flashinfer.resolve() / "include",
    flashinfer.resolve() / "include" / "gemm",
    flashinfer.resolve() / "csrc",
    flashmla.resolve() / "csrc",
    "cublas",
]


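# Custom build_py step: bundle the DeepGEMM and FlashMLA Python packages into
# the build tree and create the include symlinks used by DeepGEMM's JIT build.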
class CustomBuildPy(build_py):
    def run(self):
        self.copy_deepgemm_to_build_lib()
        self.copy_flashmla_to_build_lib()
        self.make_jit_include_symlinks()
        build_py.run(self)

    def make_jit_include_symlinks(self):
        # Create symbolic links to third-party include directories under deep_gemm/include
        build_include_dir = os.path.join(self.build_lib, "deep_gemm/include")
        os.makedirs(build_include_dir, exist_ok=True)

        third_party_include_dirs = [
            cutlass.resolve() / "include" / "cute",
            cutlass.resolve() / "include" / "cutlass",
        ]

        for d in third_party_include_dirs:
            dirname = str(d).split("/")[-1]
            src_dir = d
            dst_dir = f"{build_include_dir}/{dirname}"
            assert os.path.exists(src_dir)
            if os.path.exists(dst_dir):
                assert os.path.islink(dst_dir)
                os.unlink(dst_dir)
            os.symlink(src_dir, dst_dir, target_is_directory=True)

        # Create symbolic links for FlashMLA
        flash_mla_include_dir = os.path.join(self.build_lib, "flash_mla/include")
        os.makedirs(flash_mla_include_dir, exist_ok=True)

        # Create empty directories for FlashMLA's include paths
        # This is safer than creating symlinks as the targets might not exist in CI
        for dirname in ["cute", "cutlass"]:
            dst_dir = f"{flash_mla_include_dir}/{dirname}"
            if not os.path.exists(dst_dir):
                os.makedirs(dst_dir, exist_ok=True)

    def copy_deepgemm_to_build_lib(self):
        """
        This function copies DeepGemm to python's site-packages
        """
        dst_dir = os.path.join(self.build_lib, "deep_gemm")
        os.makedirs(dst_dir, exist_ok=True)

        # Copy deepgemm/deep_gemm to the build directory
        src_dir = os.path.join(str(deepgemm.resolve()), "deep_gemm")

        # Remove existing directory if it exists
        if os.path.exists(dst_dir):
            shutil.rmtree(dst_dir)

        # Copy the directory
        shutil.copytree(src_dir, dst_dir)

    def copy_flashmla_to_build_lib(self):
        """
        This function copies FlashMLA to python's site-packages
        """
        dst_dir = os.path.join(self.build_lib, "flash_mla")
        os.makedirs(dst_dir, exist_ok=True)

        src_dir = os.path.join(str(flashmla.resolve()), "flash_mla")

        if not os.path.exists(src_dir):
            print(
                f"Warning: source directory {src_dir} does not exist; "
                "the FlashMLA submodule may not be initialized properly"
            )
            return

        if os.path.exists(dst_dir):
            shutil.rmtree(dst_dir)

        shutil.copytree(src_dir, dst_dir)


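# Base nvcc flags; architecture-specific gencode and feature flags are appended
# below depending on the detected GPU or the SGL_KERNEL_ENABLE_* switches.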
nvcc_flags = [
    "-DNDEBUG",
    f"-DOPERATOR_NAMESPACE={operator_namespace}",
    "-O3",
    "-Xcompiler",
    "-fPIC",
    "-gencode=arch=compute_75,code=sm_75",
    "-gencode=arch=compute_80,code=sm_80",
    "-gencode=arch=compute_89,code=sm_89",
    "-gencode=arch=compute_90,code=sm_90",
    "-std=c++17",
    "-DFLASHINFER_ENABLE_F16",
    "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
    "-DCUTLASS_VERSIONS_GENERATED",
    "-DCUTE_USE_PACKED_TUPLE=1",
    "-DCUTLASS_TEST_LEVEL=0",
    "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1",
    "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
    "--ptxas-options=-v",
    "--expt-relaxed-constexpr",
    "-Xcompiler=-Wconversion",
    "-Xcompiler=-fno-strict-aliasing",
]
nvcc_flags_fp8 = [
    "-DFLASHINFER_ENABLE_FP8",
    "-DFLASHINFER_ENABLE_FP8_E4M3",
    "-DFLASHINFER_ENABLE_FP8_E5M2",
]

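# CUDA/C++ translation units compiled into the sgl_kernel.common_ops extension.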
sources = [
    "csrc/allreduce/trt_reduce_internal.cu",
    "csrc/allreduce/trt_reduce_kernel.cu",
    "csrc/attention/lightning_attention_decode_kernel.cu",
    "csrc/elementwise/activation.cu",
    "csrc/elementwise/fused_add_rms_norm_kernel.cu",
    "csrc/elementwise/rope.cu",
    "csrc/gemm/bmm_fp8.cu",
    "csrc/gemm/cublas_grouped_gemm.cu",
    "csrc/gemm/awq_kernel.cu",
    "csrc/gemm/fp8_gemm_kernel.cu",
    "csrc/gemm/fp8_blockwise_gemm_kernel.cu",
    "csrc/gemm/int8_gemm_kernel.cu",
    "csrc/gemm/per_token_group_quant_fp8.cu",
    "csrc/gemm/per_token_quant_fp8.cu",
    "csrc/gemm/per_tensor_quant_fp8.cu",
    "csrc/moe/moe_align_kernel.cu",
    "csrc/moe/moe_topk_softmax_kernels.cu",
    "csrc/speculative/eagle_utils.cu",
    "csrc/speculative/speculative_sampling.cu",
    "csrc/speculative/packbit.cu",
    "csrc/torch_extension.cc",
    "3rdparty/flashinfer/csrc/norm.cu",
    "3rdparty/flashinfer/csrc/renorm.cu",
    "3rdparty/flashinfer/csrc/sampling.cu",
]

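# Feature switches for builds on machines without a GPU; when a GPU is present
# the corresponding flags are derived from the detected device instead.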
enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1"
enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1"
enable_sm100a = os.getenv("SGL_KERNEL_ENABLE_SM100A", "0") == "1"
cuda_version = _get_cuda_version()
sm_version = _get_device_sm()

if torch.cuda.is_available():
    if cuda_version >= (12, 0) and sm_version >= 90:
        nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
    if cuda_version >= (12, 8) and sm_version >= 100:
        nvcc_flags.append("-gencode=arch=compute_100,code=sm_100")
        nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
    else:
        nvcc_flags.append("-use_fast_math")
    if sm_version >= 90:
        nvcc_flags.extend(nvcc_flags_fp8)
    if sm_version >= 80:
        nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
else:
    # compilation environment without GPU
    if enable_sm90a:
        nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
    if enable_sm100a:
        nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
    else:
        nvcc_flags.append("-use_fast_math")
    if enable_fp8:
        nvcc_flags.extend(nvcc_flags_fp8)
    if enable_bf16:
        nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")

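# PyTorch adds these flags by default; remove them so half/bfloat16 operators
# and conversions remain usable inside the kernels.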
for flag in [
    "-D__CUDA_NO_HALF_OPERATORS__",
    "-D__CUDA_NO_HALF_CONVERSIONS__",
    "-D__CUDA_NO_BFLOAT16_CONVERSIONS__",
    "-D__CUDA_NO_HALF2_OPERATORS__",
]:
    try:
        torch.utils.cpp_extension.COMMON_NVCC_FLAGS.remove(flag)
    except ValueError:
        pass

cxx_flags = ["-O3"]
libraries = ["c10", "torch", "torch_python", "cuda", "cublas"]
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]

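# The extension is built against the limited Python C API (abi3), so a single
# wheel works across Python versions >= 3.9.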
ext_modules = [
    CUDAExtension(
        name="sgl_kernel.common_ops",
        sources=sources,
        include_dirs=include_dirs,
        extra_compile_args={
            "nvcc": nvcc_flags,
            "cxx": cxx_flags,
        },
        libraries=libraries,
        extra_link_args=extra_link_args,
        py_limited_api=True,
    ),
]

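# Python sources live under python/; CustomBuildPy bundles the third-party
# packages and BuildExtension compiles the CUDA extension with ninja.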
setup(
    name="sgl-kernel",
    version=_get_version(),
    packages=find_packages(where="python"),
    package_dir={"": "python"},
    ext_modules=ext_modules,
    cmdclass={
        "build_ext": BuildExtension.with_options(use_ninja=True),
        "build_py": CustomBuildPy,
    },
    options={"bdist_wheel": {"py_limited_api": "cp39"}},
)