Unverified Commit 0b9557fc authored by Qiaolin Yu's avatar Qiaolin Yu Committed by GitHub
Browse files

Disable compiling arch below sm_90 in aarch64 by default (#6380)

parent 87068b5c
...@@ -83,6 +83,15 @@ if(CCACHE_FOUND AND ENABLE_CCACHE AND DEFINED ENV{CCACHE_DIR}) ...@@ -83,6 +83,15 @@ if(CCACHE_FOUND AND ENABLE_CCACHE AND DEFINED ENV{CCACHE_DIR})
set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK "ccache") set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK "ccache")
endif() endif()
# Enable gencode below SM90
option(ENABLE_BELOW_SM90 "Enable below SM90" ON)
if (CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
set(ENABLE_BELOW_SM90 OFF)
message(STATUS "For aarch64, disable gencode below SM90 by default")
endif()
include_directories( include_directories(
${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/include
${PROJECT_SOURCE_DIR}/csrc ${PROJECT_SOURCE_DIR}/csrc
...@@ -98,9 +107,6 @@ set(SGL_KERNEL_CUDA_FLAGS ...@@ -98,9 +107,6 @@ set(SGL_KERNEL_CUDA_FLAGS
"-O3" "-O3"
"-Xcompiler" "-Xcompiler"
"-fPIC" "-fPIC"
"-gencode=arch=compute_75,code=sm_75"
"-gencode=arch=compute_80,code=sm_80"
"-gencode=arch=compute_89,code=sm_89"
"-gencode=arch=compute_90,code=sm_90" "-gencode=arch=compute_90,code=sm_90"
"-std=c++17" "-std=c++17"
"-DFLASHINFER_ENABLE_F16" "-DFLASHINFER_ENABLE_F16"
...@@ -130,6 +136,14 @@ option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON) ...@@ -130,6 +136,14 @@ option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
option(SGL_KERNEL_ENABLE_FP4 "Enable FP4" OFF) option(SGL_KERNEL_ENABLE_FP4 "Enable FP4" OFF)
option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF) option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF)
if (ENABLE_BELOW_SM90)
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_75,code=sm_75"
"-gencode=arch=compute_80,code=sm_80"
"-gencode=arch=compute_89,code=sm_89"
)
endif()
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0" OR SGL_KERNEL_ENABLE_SM100A) if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0" OR SGL_KERNEL_ENABLE_SM100A)
list(APPEND SGL_KERNEL_CUDA_FLAGS list(APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_100,code=sm_110" "-gencode=arch=compute_100,code=sm_110"
...@@ -253,8 +267,6 @@ if (SGL_KERNEL_ENABLE_FA3) ...@@ -253,8 +267,6 @@ if (SGL_KERNEL_ENABLE_FA3)
"-O3" "-O3"
"-Xcompiler" "-Xcompiler"
"-fPIC" "-fPIC"
"-gencode=arch=compute_80,code=sm_80"
"-gencode=arch=compute_86,code=sm_86"
"-gencode=arch=compute_90a,code=sm_90a" "-gencode=arch=compute_90a,code=sm_90a"
"-std=c++17" "-std=c++17"
"-DCUTE_USE_PACKED_TUPLE=1" "-DCUTE_USE_PACKED_TUPLE=1"
...@@ -270,9 +282,15 @@ if (SGL_KERNEL_ENABLE_FA3) ...@@ -270,9 +282,15 @@ if (SGL_KERNEL_ENABLE_FA3)
"-Xcompiler=-fno-strict-aliasing" "-Xcompiler=-fno-strict-aliasing"
) )
# SM8X Logic if (ENABLE_BELOW_SM90)
file(GLOB FA3_SM8X_GEN_SRCS list(APPEND SGL_FLASH_KERNEL_CUDA_FLAGS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdim*_sm80.cu") "-gencode=arch=compute_80,code=sm_80"
"-gencode=arch=compute_86,code=sm_86"
)
# SM8X Logic
file(GLOB FA3_SM8X_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdim*_sm80.cu")
endif()
file(GLOB FA3_BF16_GEN_SRCS file(GLOB FA3_BF16_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu") "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
...@@ -313,14 +331,17 @@ if (SGL_KERNEL_ENABLE_FA3) ...@@ -313,14 +331,17 @@ if (SGL_KERNEL_ENABLE_FA3)
target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda) target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel") install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")
set(FLASH_OPS_COMPILE_DEFS
target_compile_definitions(flash_ops PRIVATE
# FLASHATTENTION_DISABLE_SM8x
FLASHATTENTION_DISABLE_BACKWARD FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT FLASHATTENTION_DISABLE_DROPOUT
FLASHATTENTION_DISABLE_UNEVEN_K FLASHATTENTION_DISABLE_UNEVEN_K
FLASHATTENTION_VARLEN_ONLY FLASHATTENTION_VARLEN_ONLY
) )
if(NOT ENABLE_BELOW_SM90)
list(APPEND FLASH_OPS_COMPILE_DEFS FLASHATTENTION_DISABLE_SM8x)
endif()
target_compile_definitions(flash_ops PRIVATE ${FLASH_OPS_COMPILE_DEFS})
endif() endif()
# JIT Logic # JIT Logic
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment