cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
project(sgl-kernel LANGUAGES CXX CUDA)

# CMake
cmake_policy(SET CMP0169 OLD)
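# CMP0169 is kept at OLD so the standalone FetchContent_Populate() calls below
# remain usable: dependencies are only downloaded here, not built via
# add_subdirectory().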
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
set(CMAKE_COLOR_DIAGNOSTICS ON)
set(CMAKE_VERBOSE_MAKEFILE ON CACHE BOOL "Enable verbose build output")
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_SHARED_LIBRARY_PREFIX "")

# Python
find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)
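# SKBUILD_SABI_COMPONENT is expected to be injected by scikit-build-core; it
# expands to Development.SABIModule when a stable-ABI wheel is requested and to
# nothing otherwise.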

# CXX
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")

# CUDA
enable_language(CUDA)
find_package(CUDAToolkit REQUIRED)
# CUDA_SEPARABLE_COMPILATION is a target property, not a global one; the
# CMAKE_ variable form initializes it for every target created below.
set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)

message(STATUS "Detected CUDA_VERSION=${CUDA_VERSION}")
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0")
    message("CUDA_VERSION ${CUDA_VERSION} >= 13.0")
elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8")
    message("CUDA_VERSION ${CUDA_VERSION} >= 12.8")
elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4")
    message("CUDA_VERSION ${CUDA_VERSION} >= 12.4")
elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.1")
    message("CUDA_VERSION ${CUDA_VERSION} >= 12.1")
elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "11.8")
    message("CUDA_VERSION ${CUDA_VERSION} >= 11.8")
endif()

# Torch
find_package(Torch REQUIRED)
# Clean up the CUDA arch flags injected by Torch
clear_cuda_arches(CMAKE_FLAG)
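# clear_cuda_arches() is assumed to come from cmake/utils.cmake (included
# above); it strips the -gencode flags Torch injects so that the architecture
# list is fully controlled by this file.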

include(FetchContent)

# cutlass
FetchContent_Declare(
    repo-cutlass
    GIT_REPOSITORY https://github.com/NVIDIA/cutlass
    GIT_TAG        f115c3f85467d5d9619119d1dbeb9c03c3d73864
    GIT_SHALLOW    OFF
)
FetchContent_Populate(repo-cutlass)
# DeepGEMM
FetchContent_Declare(
    repo-deepgemm
    GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM
    GIT_TAG        d75b218b7b8f4a5dd5406ac87905039ead3ae42f
    GIT_SHALLOW    OFF
)
FetchContent_Populate(repo-deepgemm)
# flashinfer
FetchContent_Declare(
    repo-flashinfer
    GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
    GIT_TAG        9220fb3443b5a5d274f00ca5552f798e225239b7
    GIT_SHALLOW    OFF
)
FetchContent_Populate(repo-flashinfer)
# flash-attention
FetchContent_Declare(
    repo-flash-attention
    GIT_REPOSITORY https://github.com/sgl-project/sgl-attn
    GIT_TAG        sgl-kernel
    GIT_SHALLOW    OFF
)
FetchContent_Populate(repo-flash-attention)

# ccache option
option(ENABLE_CCACHE "Whether to use ccache" ON)
find_program(CCACHE_FOUND ccache)
if(CCACHE_FOUND AND ENABLE_CCACHE AND DEFINED ENV{CCACHE_DIR})
    message(STATUS "Building with CCACHE enabled")
    set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "ccache")
    set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK "ccache")
endif()
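# Hypothetical invocation: point CCACHE_DIR at a persistent cache before
# configuring, e.g.
#   CCACHE_DIR=$HOME/.cache/ccache cmake -S . -B build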

# Enable gencode for architectures below SM90
option(ENABLE_BELOW_SM90 "Enable below SM90" ON)

if (CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
    set(ENABLE_BELOW_SM90 OFF)
    message(STATUS "For aarch64, disable gencode below SM90 by default")
endif()
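# Note: set() creates a normal variable that shadows the ENABLE_BELOW_SM90
# cache option, so on aarch64 it is forced OFF for this configure run.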


include_directories(
    ${PROJECT_SOURCE_DIR}/include
    ${PROJECT_SOURCE_DIR}/csrc
    ${repo-cutlass_SOURCE_DIR}/include
    ${repo-cutlass_SOURCE_DIR}/tools/util/include
    ${repo-flashinfer_SOURCE_DIR}/include
    ${repo-flashinfer_SOURCE_DIR}/csrc
)

set(SGL_KERNEL_CUDA_FLAGS
    "-DNDEBUG"
    "-DOPERATOR_NAMESPACE=sgl-kernel"
    "-O3"
    "-Xcompiler"
    "-fPIC"
    "-gencode=arch=compute_90,code=sm_90"
    "-std=c++17"
    "-DFLASHINFER_ENABLE_F16"
    "-DCUTE_USE_PACKED_TUPLE=1"
    "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
    "-DCUTLASS_VERSIONS_GENERATED"
    "-DCUTLASS_TEST_LEVEL=0"
    "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
    "-DCUTLASS_DEBUG_TRACE_LEVEL=0"
    "--expt-relaxed-constexpr"
    "--expt-extended-lambda"
    "--threads=32"

    # Diagnostics: enable implicit-conversion warnings, disable strict aliasing
    "-Xcompiler=-Wconversion"
    "-Xcompiler=-fno-strict-aliasing"

    # uncomment to debug
    # "--ptxas-options=-v"
    # "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage"
)

option(SGL_KERNEL_ENABLE_SM100A           "Enable SM100A"           OFF)
option(SGL_KERNEL_ENABLE_SM90A            "Enable SM90A"            OFF)
option(SGL_KERNEL_ENABLE_BF16             "Enable BF16"             ON)
option(SGL_KERNEL_ENABLE_FP8              "Enable FP8"              ON)
option(SGL_KERNEL_ENABLE_FP4              "Enable FP4"              OFF)
option(SGL_KERNEL_ENABLE_FA3              "Enable FA3"              OFF)
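# These options are toggled at configure time; with a scikit-build-core build
# the usual route is the CMAKE_ARGS environment variable, e.g. (hypothetical
# invocation):
#   CMAKE_ARGS="-DSGL_KERNEL_ENABLE_FP4=ON" pip install -v .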

if (ENABLE_BELOW_SM90)
    list(APPEND SGL_KERNEL_CUDA_FLAGS
        "-gencode=arch=compute_75,code=sm_75"
        "-gencode=arch=compute_80,code=sm_80"
        "-gencode=arch=compute_89,code=sm_89"
    )
endif()
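# In each -gencode pair, arch= names the virtual PTX architecture the source is
# compiled against and code= the real SASS target embedded in the binary; the
# "a" suffix selects arch-specific features and must match between the two.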

if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0" OR SGL_KERNEL_ENABLE_SM100A)
    list(APPEND SGL_KERNEL_CUDA_FLAGS
        "-gencode=arch=compute_100,code=sm_110"
        "-gencode=arch=compute_100a,code=sm_110a"
    )
elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
    list(APPEND SGL_KERNEL_CUDA_FLAGS
        "-gencode=arch=compute_100,code=sm_100"
        "-gencode=arch=compute_100,code=sm_101"
        "-gencode=arch=compute_100,code=sm_101a"
157
        "-gencode=arch=compute_100a,code=sm_100a"
zhjunqin's avatar
zhjunqin committed
158
        "-gencode=arch=compute_120,code=sm_120"
    )
else()
    list(APPEND SGL_KERNEL_CUDA_FLAGS
        "-use_fast_math"
    )
endif()

if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
    set(SGL_KERNEL_ENABLE_FA3 ON)
    list(APPEND SGL_KERNEL_CUDA_FLAGS
        "-gencode=arch=compute_90a,code=sm_90a"
    )
endif()

if (SGL_KERNEL_ENABLE_BF16)
    list(APPEND SGL_KERNEL_CUDA_FLAGS
        "-DFLASHINFER_ENABLE_BF16"
    )
endif()

if (SGL_KERNEL_ENABLE_FP8)
    list(APPEND SGL_KERNEL_CUDA_FLAGS
        "-DFLASHINFER_ENABLE_FP8"
        "-DFLASHINFER_ENABLE_FP8_E4M3"
        "-DFLASHINFER_ENABLE_FP8_E5M2"
    )
endif()

if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_FP4)
    list(APPEND SGL_KERNEL_CUDA_FLAGS
        "-DENABLE_NVFP4=1"
    )
endif()

string(REPLACE "-D__CUDA_NO_HALF_OPERATORS__"       "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
string(REPLACE "-D__CUDA_NO_HALF_CONVERSIONS__"     "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__"      "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")

set(SOURCES
    "csrc/allreduce/custom_all_reduce.cu"
    "csrc/attention/cascade.cu"
    "csrc/attention/merge_attn_states.cu"
    "csrc/attention/cutlass_mla_kernel.cu"
    "csrc/attention/vertical_slash_index.cu"
    "csrc/attention/lightning_attention_decode_kernel.cu"
    "csrc/elementwise/activation.cu"
    "csrc/elementwise/fused_add_rms_norm_kernel.cu"
    "csrc/elementwise/rope.cu"
    "csrc/gemm/awq_kernel.cu"
    "csrc/gemm/bmm_fp8.cu"
    "csrc/gemm/fp8_blockwise_gemm_kernel.cu"
    "csrc/gemm/fp8_gemm_kernel.cu"
    "csrc/gemm/int8_gemm_kernel.cu"
    "csrc/gemm/nvfp4_expert_quant.cu"
    "csrc/gemm/nvfp4_quant_entry.cu"
    "csrc/gemm/nvfp4_quant_kernels.cu"
    "csrc/gemm/nvfp4_scaled_mm_entry.cu"
    "csrc/gemm/nvfp4_scaled_mm_kernels.cu"
    "csrc/gemm/per_tensor_quant_fp8.cu"
    "csrc/gemm/per_token_group_quant_8bit.cu"
    "csrc/gemm/per_token_quant_fp8.cu"
    "csrc/gemm/qserve_w4a8_per_chn_gemm.cu"
    "csrc/gemm/qserve_w4a8_per_group_gemm.cu"
    "csrc/moe/moe_align_kernel.cu"
    "csrc/moe/moe_fused_gate.cu"
    "csrc/moe/moe_topk_softmax_kernels.cu"
    "csrc/moe/nvfp4_blockwise_moe.cu"
    "csrc/moe/fp8_blockwise_moe_kernel.cu"
    "csrc/moe/prepare_moe_input.cu"
    "csrc/moe/ep_moe_reorder_kernel.cu"
    "csrc/speculative/eagle_utils.cu"
    "csrc/speculative/speculative_sampling.cu"
    "csrc/speculative/packbit.cu"
    "csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
    "csrc/common_extension.cc"
    "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
    "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
    "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
    "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu"
    "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu"
    "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu"
    "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_sm80.cu"
    "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/flash_sparse_api.cpp"
)

Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
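# USE_SABI builds the module against the Python stable ABI (this relies on the
# Development.SABIModule component resolved by find_package(Python) above), and
# WITH_SOABI appends the matching ABI suffix to the module file name.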

target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
target_include_directories(common_ops PRIVATE
    ${repo-cutlass_SOURCE_DIR}/examples/77_blackwell_fmha
    ${repo-cutlass_SOURCE_DIR}/examples/common
    ${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
)
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)

target_compile_definitions(common_ops PRIVATE
    FLASHATTENTION_DISABLE_BACKWARD
    FLASHATTENTION_DISABLE_DROPOUT
    FLASHATTENTION_DISABLE_UNEVEN_K
)

install(TARGETS common_ops LIBRARY DESTINATION sgl_kernel)
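# The module is installed into the sgl_kernel package directory, so it should
# be importable as sgl_kernel.common_ops (assuming the default package layout).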

# ============================ Optional Install ============================= #
# Set the flash-attention source files.
# FA3 currently supports sm80/sm86/sm90.
if (SGL_KERNEL_ENABLE_FA3)
    set(SGL_FLASH_KERNEL_CUDA_FLAGS
        "-DNDEBUG"
        "-DOPERATOR_NAMESPACE=sgl-kernel"
        "-O3"
        "-Xcompiler"
        "-fPIC"
        "-gencode=arch=compute_90a,code=sm_90a"
        "-std=c++17"
        "-DCUTE_USE_PACKED_TUPLE=1"
        "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
        "-DCUTLASS_VERSIONS_GENERATED"
        "-DCUTLASS_TEST_LEVEL=0"
        "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
        "-DCUTLASS_DEBUG_TRACE_LEVEL=0"
        "--expt-relaxed-constexpr"
        "--expt-extended-lambda"
        "--use_fast_math"
        "-Xcompiler=-Wconversion"
        "-Xcompiler=-fno-strict-aliasing"
    )

    if (ENABLE_BELOW_SM90)
        list(APPEND SGL_FLASH_KERNEL_CUDA_FLAGS
            "-gencode=arch=compute_80,code=sm_80"
            "-gencode=arch=compute_86,code=sm_86"
        )
        # SM8X Logic
        file(GLOB FA3_SM8X_GEN_SRCS
            "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdim*_sm80.cu")
    endif()

    # BF16 source files
    file(GLOB FA3_BF16_GEN_SRCS
        "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
    file(GLOB FA3_BF16_GEN_SRCS_
        "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
    list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})

    # FP16 source files
    file(GLOB FA3_FP16_GEN_SRCS
        "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
    file(GLOB FA3_FP16_GEN_SRCS_
        "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
    list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})

    # FP8 source files
    file(GLOB FA3_FP8_GEN_SRCS
        "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
    file(GLOB FA3_FP8_GEN_SRCS_
        "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
    list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})

    set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS} ${FA3_SM8X_GEN_SRCS})

    set(FLASH_SOURCES
        "csrc/flash_extension.cc"
        "${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu"
        "${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp"
        "${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu"
        "${FA3_GEN_SRCS}"
    )

    Python_add_library(flash_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${FLASH_SOURCES})

    target_compile_options(flash_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_FLASH_KERNEL_CUDA_FLAGS}>)
    target_include_directories(flash_ops PRIVATE
        ${repo-flash-attention_SOURCE_DIR}/hopper
    )
    target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)

    install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")
    set(FLASH_OPS_COMPILE_DEFS
        FLASHATTENTION_DISABLE_BACKWARD
        FLASHATTENTION_DISABLE_DROPOUT
        FLASHATTENTION_DISABLE_UNEVEN_K
        FLASHATTENTION_VARLEN_ONLY
    )

    if(NOT ENABLE_BELOW_SM90)
        list(APPEND FLASH_OPS_COMPILE_DEFS FLASHATTENTION_DISABLE_SM8x)
    endif()
    target_compile_definitions(flash_ops PRIVATE ${FLASH_OPS_COMPILE_DEFS})
endif()

# JIT Logic
# DeepGEMM

install(DIRECTORY "${repo-deepgemm_SOURCE_DIR}/deep_gemm/"
        DESTINATION "deep_gemm"
        PATTERN ".git*" EXCLUDE
        PATTERN "__pycache__" EXCLUDE)

install(DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/cute/"
        DESTINATION "deep_gemm/include/cute")

install(DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/cutlass/"
        DESTINATION "deep_gemm/include/cutlass")