Unverified Commit 0b9557fc authored by Qiaolin Yu's avatar Qiaolin Yu Committed by GitHub
Browse files

Disable compiling arch below sm_90 in aarch64 by default (#6380)

parent 87068b5c
......@@ -83,6 +83,15 @@ if(CCACHE_FOUND AND ENABLE_CCACHE AND DEFINED ENV{CCACHE_DIR})
set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK "ccache")
endif()
# Enable gencode below SM90
option(ENABLE_BELOW_SM90 "Enable below SM90" ON)
if (CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
set(ENABLE_BELOW_SM90 OFF)
message(STATUS "For aarch64, disable gencode below SM90 by default")
endif()
include_directories(
${PROJECT_SOURCE_DIR}/include
${PROJECT_SOURCE_DIR}/csrc
......@@ -98,9 +107,6 @@ set(SGL_KERNEL_CUDA_FLAGS
"-O3"
"-Xcompiler"
"-fPIC"
"-gencode=arch=compute_75,code=sm_75"
"-gencode=arch=compute_80,code=sm_80"
"-gencode=arch=compute_89,code=sm_89"
"-gencode=arch=compute_90,code=sm_90"
"-std=c++17"
"-DFLASHINFER_ENABLE_F16"
......@@ -130,6 +136,14 @@ option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
option(SGL_KERNEL_ENABLE_FP4 "Enable FP4" OFF)
option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF)
if (ENABLE_BELOW_SM90)
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_75,code=sm_75"
"-gencode=arch=compute_80,code=sm_80"
"-gencode=arch=compute_89,code=sm_89"
)
endif()
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0" OR SGL_KERNEL_ENABLE_SM100A)
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_100,code=sm_110"
......@@ -253,8 +267,6 @@ if (SGL_KERNEL_ENABLE_FA3)
"-O3"
"-Xcompiler"
"-fPIC"
"-gencode=arch=compute_80,code=sm_80"
"-gencode=arch=compute_86,code=sm_86"
"-gencode=arch=compute_90a,code=sm_90a"
"-std=c++17"
"-DCUTE_USE_PACKED_TUPLE=1"
......@@ -270,9 +282,15 @@ if (SGL_KERNEL_ENABLE_FA3)
"-Xcompiler=-fno-strict-aliasing"
)
# SM8X Logic
file(GLOB FA3_SM8X_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdim*_sm80.cu")
if (ENABLE_BELOW_SM90)
list(APPEND SGL_FLASH_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_80,code=sm_80"
"-gencode=arch=compute_86,code=sm_86"
)
# SM8X Logic
file(GLOB FA3_SM8X_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdim*_sm80.cu")
endif()
file(GLOB FA3_BF16_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
......@@ -313,14 +331,17 @@ if (SGL_KERNEL_ENABLE_FA3)
target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")
target_compile_definitions(flash_ops PRIVATE
# FLASHATTENTION_DISABLE_SM8x
set(FLASH_OPS_COMPILE_DEFS
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
FLASHATTENTION_DISABLE_UNEVEN_K
FLASHATTENTION_VARLEN_ONLY
)
if(NOT ENABLE_BELOW_SM90)
list(APPEND FLASH_OPS_COMPILE_DEFS FLASHATTENTION_DISABLE_SM8x)
endif()
target_compile_definitions(flash_ops PRIVATE ${FLASH_OPS_COMPILE_DEFS})
endif()
# JIT Logic
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment