cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
project(sgl-kernel LANGUAGES CXX CUDA)

# CMake
cmake_policy(SET CMP0169 OLD)
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)

set(CMAKE_COLOR_DIAGNOSTICS ON)
set(CMAKE_VERBOSE_MAKEFILE ON CACHE BOOL "ON")
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_SHARED_LIBRARY_PREFIX "")

# Python
find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)

# CXX
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")

# CUDA
enable_language(CUDA)
find_package(CUDAToolkit REQUIRED)
set_property(GLOBAL PROPERTY CUDA_SEPARABLE_COMPILATION ON)

message(STATUS "Detected CUDA_VERSION=${CUDA_VERSION}")
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0")
    message("CUDA_VERSION ${CUDA_VERSION} >= 13.0")
elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8")
    message("CUDA_VERSION ${CUDA_VERSION} >= 12.8")
elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4")
    message("CUDA_VERSION ${CUDA_VERSION} >= 12.4")
elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.1")
    message("CUDA_VERSION ${CUDA_VERSION} >= 12.1")
elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "11.8")
    message("CUDA_VERSION ${CUDA_VERSION} >= 11.8")
endif()

# Torch
find_package(Torch REQUIRED)

# Clean the Torch-provided CUDA arch flags
clear_cuda_arches(CMAKE_FLAG)

include(FetchContent)

# cutlass
FetchContent_Declare(
    repo-cutlass
    GIT_REPOSITORY https://github.com/NVIDIA/cutlass
    GIT_TAG        a49a78ffefc86a87160dfe0ccc3a3a2d1622c918
    GIT_SHALLOW    OFF
)
FetchContent_Populate(repo-cutlass)

# DeepGEMM
if("${CUDA_VERSION}" VERSION_EQUAL "12.8")
  set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM")
  set(DeepGEMM_TAG "blackwell")
elseif("${CUDA_VERSION}" VERSION_EQUAL "12.9")
  set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM")
  set(DeepGEMM_TAG "blackwell")
elseif("${CUDA_VERSION}" VERSION_EQUAL "13.0")
  set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM")
  set(DeepGEMM_TAG "blackwell")
else()
  set(DeepGEMM_REPO "https://github.com/deepseek-ai/DeepGEMM")
  set(DeepGEMM_TAG "391755ada0ffefa9a6a52b6f14dcaf22d1a463e0")
endif()

FetchContent_Declare(
    repo-deepgemm
    GIT_REPOSITORY ${DeepGEMM_REPO}
    GIT_TAG        ${DeepGEMM_TAG}
    GIT_SHALLOW    OFF
)
FetchContent_Populate(repo-deepgemm)

# Triton
FetchContent_Declare(
    repo-triton
    GIT_REPOSITORY "https://github.com/triton-lang/triton"
    GIT_TAG        8f9f695ea8fde23a0c7c88e4ab256634ca27789f
    GIT_SHALLOW    OFF
)
FetchContent_Populate(repo-triton)

# flashinfer
FetchContent_Declare(
    repo-flashinfer
    GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
    GIT_TAG        018b551825c8e5579206e6eb9d3229fa679202b3
    GIT_SHALLOW    OFF
)
FetchContent_Populate(repo-flashinfer)

# flash-attention
FetchContent_Declare(
    repo-flash-attention
    GIT_REPOSITORY https://github.com/sgl-project/sgl-attn
    GIT_TAG        sgl-kernel
    GIT_SHALLOW    OFF
)
FetchContent_Populate(repo-flash-attention)

# mscclpp
FetchContent_Declare(
    repo-mscclpp
    GIT_REPOSITORY https://github.com/microsoft/mscclpp.git
    GIT_TAG        51eca89d20f0cfb3764ccd764338d7b22cd486a6
    GIT_SHALLOW    OFF
)
FetchContent_Populate(repo-mscclpp)

# ccache option
option(ENABLE_CCACHE "Whether to use ccache" ON)
find_program(CCACHE_FOUND ccache)
if(CCACHE_FOUND AND ENABLE_CCACHE AND DEFINED ENV{CCACHE_DIR})
    message(STATUS "Building with CCACHE enabled")
    set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "ccache")
    set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK "ccache")
endif()
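
# Example (assumption, not part of the original file): ccache is only wired in when
# CCACHE_DIR is exported in the environment at configure time, e.g.
#   CCACHE_DIR=$HOME/.cache/ccache cmake -S . -B build
# With ccache installed but CCACHE_DIR unset, the compile/link launchers above stay disabled.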

# Enable gencode below SM90
option(ENABLE_BELOW_SM90 "Enable below SM90" ON)

if (CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
    set(ENABLE_BELOW_SM90 OFF)
    message(STATUS "On aarch64, gencode below SM90 is disabled by default")
endif()

include_directories(
    ${PROJECT_SOURCE_DIR}/include
    ${PROJECT_SOURCE_DIR}/csrc
    ${repo-cutlass_SOURCE_DIR}/include
    ${repo-cutlass_SOURCE_DIR}/tools/util/include
    ${repo-flashinfer_SOURCE_DIR}/include
    ${repo-flashinfer_SOURCE_DIR}/csrc
    ${repo-mscclpp_SOURCE_DIR}/include
)

set(SGL_KERNEL_CUDA_FLAGS
    "-DNDEBUG"
    "-DOPERATOR_NAMESPACE=sgl-kernel"
    "-O3"
    "-Xcompiler"
    "-fPIC"
    "-gencode=arch=compute_90,code=sm_90"
    "-std=c++17"
    "-DFLASHINFER_ENABLE_F16"
    "-DCUTE_USE_PACKED_TUPLE=1"
    "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
    "-DCUTLASS_VERSIONS_GENERATED"
    "-DCUTLASS_TEST_LEVEL=0"
    "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
    "-DCUTLASS_DEBUG_TRACE_LEVEL=0"
    "--expt-relaxed-constexpr"
    "--expt-extended-lambda"
    "--threads=32"

    # Suppress warnings
    "-Xcompiler=-Wconversion"
    "-Xcompiler=-fno-strict-aliasing"

    # uncomment to debug
    # "--ptxas-options=-v"
    # "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage"
)

option(SGL_KERNEL_ENABLE_SM100A           "Enable SM100A"           OFF)
option(SGL_KERNEL_ENABLE_SM90A            "Enable SM90A"            OFF)
option(SGL_KERNEL_ENABLE_BF16             "Enable BF16"             ON)
option(SGL_KERNEL_ENABLE_FP8              "Enable FP8"              ON)
option(SGL_KERNEL_ENABLE_FP4              "Enable FP4"              OFF)
option(SGL_KERNEL_ENABLE_FA3              "Enable FA3"              OFF)
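
# Example (sketch, assumption about the build front end): the options above can be
# overridden at configure time, e.g.
#   cmake -S . -B build -DSGL_KERNEL_ENABLE_SM90A=ON -DSGL_KERNEL_ENABLE_FP4=ON
# or, when building the wheel with scikit-build-core,
#   pip install . --config-settings=cmake.define.SGL_KERNEL_ENABLE_FP4=ON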

if (ENABLE_BELOW_SM90)
    list(APPEND SGL_KERNEL_CUDA_FLAGS
        "-gencode=arch=compute_80,code=sm_80"
        "-gencode=arch=compute_89,code=sm_89"
    )
endif()

if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
    list(APPEND SGL_KERNEL_CUDA_FLAGS
        "-gencode=arch=compute_100,code=sm_100"
        "-gencode=arch=compute_100a,code=sm_100a"
        "-gencode=arch=compute_103,code=sm_103"
        "-gencode=arch=compute_103a,code=sm_103a"
        "-gencode=arch=compute_120,code=sm_120"
        "-gencode=arch=compute_120a,code=sm_120a"
    )

    # See https://github.com/pytorch/pytorch/pull/156176 for the description of sm_121, sm_110 and sm_101.
    if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0")
        list(APPEND SGL_KERNEL_CUDA_FLAGS
            "-gencode=arch=compute_110,code=sm_110"
            "-gencode=arch=compute_110a,code=sm_110a"
            "-gencode=arch=compute_121,code=sm_121"
            "-gencode=arch=compute_121a,code=sm_121a"
            "--compress-mode=size"
        )
    else()
        list(APPEND SGL_KERNEL_CUDA_FLAGS
            "-gencode=arch=compute_101,code=sm_101"
            "-gencode=arch=compute_101a,code=sm_101a"
        )
    endif()

else()
    list(APPEND SGL_KERNEL_CUDA_FLAGS
        "-use_fast_math"
    )
endif()

if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
    set(SGL_KERNEL_ENABLE_FA3 ON)
    list(APPEND SGL_KERNEL_CUDA_FLAGS
        "-gencode=arch=compute_90a,code=sm_90a"
    )
endif()

if (SGL_KERNEL_ENABLE_BF16)
    list(APPEND SGL_KERNEL_CUDA_FLAGS
        "-DFLASHINFER_ENABLE_BF16"
    )
endif()

if (SGL_KERNEL_ENABLE_FP8)
    list(APPEND SGL_KERNEL_CUDA_FLAGS
        "-DFLASHINFER_ENABLE_FP8"
        "-DFLASHINFER_ENABLE_FP8_E4M3"
        "-DFLASHINFER_ENABLE_FP8_E5M2"
    )
endif()

if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_FP4)
    list(APPEND SGL_KERNEL_CUDA_FLAGS
        "-DENABLE_NVFP4=1"
    )
endif()

string(REPLACE "-D__CUDA_NO_HALF_OPERATORS__"       "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
string(REPLACE "-D__CUDA_NO_HALF_CONVERSIONS__"     "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__"      "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")

set(SOURCES
    "csrc/allreduce/custom_all_reduce.cu"
    "csrc/allreduce/mscclpp_allreduce.cu"
    "csrc/attention/cascade.cu"
    "csrc/attention/cutlass_mla_kernel.cu"
    "csrc/attention/lightning_attention_decode_kernel.cu"
    "csrc/attention/merge_attn_states.cu"
    "csrc/attention/vertical_slash_index.cu"
    "csrc/elementwise/activation.cu"
    "csrc/elementwise/cast.cu"
    "csrc/elementwise/fused_add_rms_norm_kernel.cu"
    "csrc/elementwise/rope.cu"
    "csrc/common_extension.cc"

    "csrc/gemm/awq_kernel.cu"
    "csrc/gemm/bmm_fp8.cu"
    "csrc/gemm/dsv3_fused_a_gemm.cu"
    "csrc/gemm/dsv3_router_gemm_bf16_out.cu"
    "csrc/gemm/dsv3_router_gemm_entry.cu"
    "csrc/gemm/dsv3_router_gemm_float_out.cu"
    "csrc/gemm/fp8_blockwise_gemm_kernel.cu"
    "csrc/gemm/fp8_gemm_kernel.cu"
    "csrc/gemm/int8_gemm_kernel.cu"
    "csrc/gemm/nvfp4_expert_quant.cu"
    "csrc/gemm/nvfp4_quant_entry.cu"
    "csrc/gemm/nvfp4_quant_kernels.cu"
    "csrc/gemm/nvfp4_scaled_mm_entry.cu"
    "csrc/gemm/nvfp4_scaled_mm_kernels.cu"
    "csrc/gemm/per_tensor_quant_fp8.cu"
    "csrc/gemm/per_token_group_quant_8bit.cu"
    "csrc/gemm/per_token_quant_fp8.cu"
    "csrc/gemm/qserve_w4a8_per_chn_gemm.cu"
    "csrc/gemm/qserve_w4a8_per_group_gemm.cu"
    "csrc/gemm/marlin/gptq_marlin.cu"
    "csrc/gemm/marlin/gptq_marlin_repack.cu"
    "csrc/gemm/marlin/awq_marlin_repack.cu"
    "csrc/gemm/gptq/gptq_kernel.cu"

    "csrc/grammar/apply_token_bitmask_inplace_cuda.cu"

    "csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu"
    "csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
    "csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
    "csrc/moe/marlin_moe_wna16/ops.cu"
    "csrc/moe/moe_align_kernel.cu"
    "csrc/moe/moe_fused_gate.cu"
    "csrc/moe/moe_topk_softmax_kernels.cu"
    "csrc/moe/nvfp4_blockwise_moe.cu"
    "csrc/moe/fp8_blockwise_moe_kernel.cu"
    "csrc/moe/prepare_moe_input.cu"
    "csrc/moe/ep_moe_reorder_kernel.cu"
    "csrc/moe/ep_moe_silu_and_mul_kernel.cu"

    "csrc/memory/store.cu"
    "csrc/kvcacheio/transfer.cu"

    "csrc/speculative/eagle_utils.cu"
    "csrc/speculative/packbit.cu"
    "csrc/speculative/speculative_sampling.cu"

    "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
    "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
    "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"

    "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu"
    "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu"
    "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu"
    "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_sm80.cu"
    "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/flash_sparse_api.cpp"
)

Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})

target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
target_include_directories(common_ops PRIVATE
    ${repo-cutlass_SOURCE_DIR}/examples/77_blackwell_fmha
    ${repo-cutlass_SOURCE_DIR}/examples/common
    ${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
)
set_source_files_properties("csrc/gemm/per_token_group_quant_8bit.cu" PROPERTIES COMPILE_OPTIONS "--use_fast_math")


find_package(Python3 COMPONENTS Interpreter REQUIRED)
execute_process(
    COMMAND ${Python3_EXECUTABLE} -c "import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))"
    OUTPUT_VARIABLE TORCH_CXX11_ABI
    OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(TORCH_CXX11_ABI STREQUAL "0")
    message(STATUS "Using old C++ ABI (-D_GLIBCXX_USE_CXX11_ABI=0)")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
    set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
else()
    message(STATUS "Using new C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI=1)")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
    set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
endif()
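
# Example (sketch): the probe above is equivalent to running
#   python -c "import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))"
# by hand; the printed 0/1 selects the -D_GLIBCXX_USE_CXX11_ABI value passed to both the
# C++ and CUDA compilers so the extension links against Torch with a matching ABI.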

# mscclpp
set(MSCCLPP_USE_CUDA ON)
set(MSCCLPP_BYPASS_GPU_CHECK ON)
set(MSCCLPP_BUILD_TESTS OFF)
add_subdirectory(
    ${repo-mscclpp_SOURCE_DIR}
    ${CMAKE_CURRENT_BINARY_DIR}/mscclpp-build
)
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static)

# flash attention
target_compile_definitions(common_ops PRIVATE
    FLASHATTENTION_DISABLE_BACKWARD
    FLASHATTENTION_DISABLE_DROPOUT
    FLASHATTENTION_DISABLE_UNEVEN_K
)

install(TARGETS common_ops LIBRARY DESTINATION sgl_kernel)

# ============================ Optional Install ============================= #
# Set flash-attention source files
# FA3 currently supports sm80/sm86/sm90
if (SGL_KERNEL_ENABLE_FA3)
    set(SGL_FLASH_KERNEL_CUDA_FLAGS
        "-DNDEBUG"
        "-DOPERATOR_NAMESPACE=sgl-kernel"
        "-O3"
        "-Xcompiler"
        "-fPIC"
        "-gencode=arch=compute_90a,code=sm_90a"
        "-std=c++17"
        "-DCUTE_USE_PACKED_TUPLE=1"
        "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
        "-DCUTLASS_VERSIONS_GENERATED"
        "-DCUTLASS_TEST_LEVEL=0"
        "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
        "-DCUTLASS_DEBUG_TRACE_LEVEL=0"
        "--expt-relaxed-constexpr"
        "--expt-extended-lambda"
        "--use_fast_math"
        "-Xcompiler=-Wconversion"
        "-Xcompiler=-fno-strict-aliasing"
    )

    if (ENABLE_BELOW_SM90)
        list(APPEND SGL_FLASH_KERNEL_CUDA_FLAGS
            "-gencode=arch=compute_80,code=sm_80"
            "-gencode=arch=compute_86,code=sm_86"
        )
        # SM8X Logic
        file(GLOB FA3_SM8X_GEN_SRCS
            "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdim*_sm80.cu")
    endif()

    # BF16 source files
    file(GLOB FA3_BF16_GEN_SRCS
        "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
    file(GLOB FA3_BF16_GEN_SRCS_
        "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
    list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})

    # FP16 source files
    file(GLOB FA3_FP16_GEN_SRCS
        "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
    file(GLOB FA3_FP16_GEN_SRCS_
        "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
    list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})

    # FP8 source files
    file(GLOB FA3_FP8_GEN_SRCS
        "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
    file(GLOB FA3_FP8_GEN_SRCS_
        "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
    list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})

    set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS} ${FA3_SM8X_GEN_SRCS})

    set(FLASH_SOURCES
        "csrc/flash_extension.cc"
        "${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu"
        "${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp"
        "${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu"
        "${FA3_GEN_SRCS}"
    )

    Python_add_library(flash_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${FLASH_SOURCES})

    target_compile_options(flash_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_FLASH_KERNEL_CUDA_FLAGS}>)
    target_include_directories(flash_ops PRIVATE
        ${repo-flash-attention_SOURCE_DIR}/hopper
    )
    target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)

    install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")
    set(FLASH_OPS_COMPILE_DEFS
        FLASHATTENTION_DISABLE_BACKWARD
        FLASHATTENTION_DISABLE_DROPOUT
        FLASHATTENTION_DISABLE_UNEVEN_K
        FLASHATTENTION_VARLEN_ONLY
    )

    if(NOT ENABLE_BELOW_SM90)
        list(APPEND FLASH_OPS_COMPILE_DEFS FLASHATTENTION_DISABLE_SM8x)
    endif()
    target_compile_definitions(flash_ops PRIVATE ${FLASH_OPS_COMPILE_DEFS})
endif()

# Build spatial_ops as a separate, optional extension for green contexts
set(SPATIAL_SOURCES
    "csrc/spatial/greenctx_stream.cu"
    "csrc/spatial_extension.cc"
)

Python_add_library(spatial_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SPATIAL_SOURCES})
target_compile_options(spatial_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
target_link_libraries(spatial_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
install(TARGETS spatial_ops LIBRARY DESTINATION sgl_kernel)


# ============================ DeepGEMM (JIT) ============================= #
# Create a separate library for DeepGEMM's Python API.
# This keeps its compilation isolated from the main common_ops.
set(DEEPGEMM_SOURCES
    "${repo-deepgemm_SOURCE_DIR}/csrc/python_api.cpp"
)

# JIT logic: DeepGEMM kernels are JIT-compiled at runtime, so install the DeepGEMM
# Python package (and the CuTe/CUTLASS headers it needs) rather than prebuilding kernels.
install(DIRECTORY "${repo-deepgemm_SOURCE_DIR}/deep_gemm/"
        DESTINATION "deep_gemm"
        PATTERN ".git*" EXCLUDE
        PATTERN "__pycache__" EXCLUDE)

install(DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/cute/"
        DESTINATION "deep_gemm/include/cute")

install(DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/cutlass/"
        DESTINATION "deep_gemm/include/cutlass")

# triton_kernels
install(DIRECTORY "${repo-triton_SOURCE_DIR}/python/triton_kernels/triton_kernels/"
        DESTINATION "triton_kernels"
        PATTERN ".git*" EXCLUDE
        PATTERN "__pycache__" EXCLUDE)