change / sglang · Commits

Commit b8b6008f (unverified)
Authored Apr 04, 2025 by yinfan98; committed by GitHub Apr 03, 2025

[Fix] fix fa3 build at cu118 (#5036)

Parent: 8e10fec9

Showing 8 changed files with 288 additions and 142 deletions:

    sgl-kernel/CMakeLists.txt                    +85  -50
    sgl-kernel/cmake/utils.cmake                 +21  -0
    sgl-kernel/csrc/common_extension.cc          +1   -40
    sgl-kernel/csrc/flash_extension.cc           +62  -0
    sgl-kernel/include/sgl_flash_kernel_ops.h    +85  -0
    sgl-kernel/include/sgl_kernel_ops.h          +0   -47
    sgl-kernel/python/sgl_kernel/flash_attn.py   +15  -4
    sgl-kernel/tests/test_flash_attention.py     +19  -1
sgl-kernel/CMakeLists.txt

cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
project(sgl-kernel LANGUAGES CXX CUDA)

# we only want to download 3rd, but not build them.
# FetchContent_MakeAvailable will build it.
cmake_policy(SET CMP0169 OLD)

include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)

set(BUILD_FA3 OFF)

find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)

enable_language(CUDA)

...
@@ -22,6 +24,8 @@ elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "11.8")
endif()

find_package(Torch REQUIRED)
+# clean Torch Flag
+clear_cuda_arches(CMAKE_FLAG)

include(FetchContent)

...
@@ -53,8 +57,8 @@ FetchContent_Populate(repo-flashinfer)
FetchContent_Declare(
  repo-flash-attention
  GIT_REPOSITORY https://github.com/sgl-project/sgl-attn
  GIT_TAG        sgl-kernel
  GIT_SHALLOW    OFF
)
FetchContent_Populate(repo-flash-attention)

...
@@ -92,14 +96,13 @@ set(SGL_KERNEL_CUDA_FLAGS
  "-gencode=arch=compute_90,code=sm_90"
  "-std=c++17"
  "-DFLASHINFER_ENABLE_F16"
  "-DCUTE_USE_PACKED_TUPLE=1"
  "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
  "-DCUTLASS_VERSIONS_GENERATED"
  "-DCUTLASS_TEST_LEVEL=0"
  "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
  "-DCUTLASS_DEBUG_TRACE_LEVEL=0"
  "--expt-relaxed-constexpr"
  "--use_fast_math"
  "-Xcompiler=-Wconversion"
  "-Xcompiler=-fno-strict-aliasing"
)

...
@@ -122,6 +125,7 @@ else()
endif()

if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
+  set(BUILD_FA3 ON)
  list(APPEND SGL_KERNEL_CUDA_FLAGS
    "-gencode=arch=compute_90a,code=sm_90a"
  )

...
@@ -152,30 +156,6 @@ string(REPLACE "-D__CUDA_NO_HALF_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")

-# set flash-attention sources file
-# BF16 source files
-file(GLOB FA3_BF16_GEN_SRCS
-  "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
-file(GLOB FA3_BF16_GEN_SRCS_
-  "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
-list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
-
-# FP16 source files
-file(GLOB FA3_FP16_GEN_SRCS
-  "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
-file(GLOB FA3_FP16_GEN_SRCS_
-  "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
-list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
-
-# FP8 source files
-file(GLOB FA3_FP8_GEN_SRCS
-  "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
-file(GLOB FA3_FP8_GEN_SRCS_
-  "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
-list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})
-
-set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})

set(SOURCES
  "csrc/allreduce/trt_reduce_internal.cu"
  "csrc/allreduce/trt_reduce_kernel.cu"

...
@@ -202,39 +182,94 @@ set(SOURCES
  "csrc/speculative/eagle_utils.cu"
  "csrc/speculative/speculative_sampling.cu"
  "csrc/speculative/packbit.cu"
-  "csrc/torch_extension.cc"
+  "csrc/common_extension.cc"
  "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
  "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
  "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
-  "${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu"
-  "${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp"
-  "${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu"
-  "${FA3_GEN_SRCS}"
)

# Support abi3 for build
+if (BUILD_FA3)
+  set(SGL_FLASH_KERNEL_CUDA_FLAGS
+    "-DNDEBUG"
+    "-DOPERATOR_NAMESPACE=sgl-kernel"
+    "-O3"
+    "-Xcompiler"
+    "-fPIC"
+    "-gencode=arch=compute_90a,code=sm_90a"
+    "-std=c++17"
+    "-DCUTE_USE_PACKED_TUPLE=1"
+    "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
+    "-DCUTLASS_VERSIONS_GENERATED"
+    "-DCUTLASS_TEST_LEVEL=0"
+    "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
+    "-DCUTLASS_DEBUG_TRACE_LEVEL=0"
+    "--expt-relaxed-constexpr"
+    "--expt-extended-lambda"
+    "--use_fast_math"
+    "-Xcompiler=-Wconversion"
+    "-Xcompiler=-fno-strict-aliasing"
+  )
+
+  # set flash-attention sources file
+  # BF16 source files
+  file(GLOB FA3_BF16_GEN_SRCS
+    "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
+  file(GLOB FA3_BF16_GEN_SRCS_
+    "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
+  list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
+
+  # FP16 source files
+  file(GLOB FA3_FP16_GEN_SRCS
+    "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
+  file(GLOB FA3_FP16_GEN_SRCS_
+    "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
+  list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
+
+  # FP8 source files
+  file(GLOB FA3_FP8_GEN_SRCS
+    "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
+  file(GLOB FA3_FP8_GEN_SRCS_
+    "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
+  list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})
+
+  set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})
+
+  set(FLASH_SOURCES
+    "csrc/flash_extension.cc"
+    "${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu"
+    "${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp"
+    "${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu"
+    "${FA3_GEN_SRCS}"
+  )
+
+  Python_add_library(flash_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${FLASH_SOURCES})
+
+  target_compile_options(flash_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_FLASH_KERNEL_CUDA_FLAGS}>)
+  target_include_directories(flash_ops PRIVATE ${TORCH_INCLUDE_DIRS})
+  target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
+
+  install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")
+
+  target_compile_definitions(flash_ops PRIVATE
+    FLASHATTENTION_DISABLE_SM8x
+    FLASHATTENTION_DISABLE_BACKWARD
+    FLASHATTENTION_DISABLE_DROPOUT
+    # FLASHATTENTION_DISABLE_ALIBI
+    # FLASHATTENTION_DISABLE_SOFTCAP
+    FLASHATTENTION_DISABLE_UNEVEN_K
+    # FLASHATTENTION_DISABLE_LOCAL
+    FLASHATTENTION_VARLEN_ONLY
+  )
+endif()

Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})

target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS})
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)

install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")

# Add some flash-attention custom flag for inference
target_compile_definitions(common_ops PRIVATE
  FLASHATTENTION_DISABLE_SM8x
  FLASHATTENTION_DISABLE_BACKWARD
  FLASHATTENTION_DISABLE_DROPOUT
  # FLASHATTENTION_DISABLE_ALIBI
  # FLASHATTENTION_DISABLE_SOFTCAP
  FLASHATTENTION_DISABLE_UNEVEN_K
  # FLASHATTENTION_DISABLE_LOCAL
  FLASHATTENTION_VARLEN_ONLY
)

# JIT Logic
# DeepGEMM
...
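The net effect of the CMake changes above is that all FA3 sources now compile into a separate flash_ops extension, and that extension is only built when BUILD_FA3 is ON, i.e. when CUDA >= 12.4 or SGL_KERNEL_ENABLE_SM90A is set; a CUDA 11.8 build simply skips it instead of failing. A minimal sketch (not part of this commit; has_flash_ops is a hypothetical helper, and it assumes sgl-kernel is installed) for checking whether the optional module was built:

    # Hypothetical helper, for illustration only: probe for the optional flash_ops module
    # that CMake installs into the sgl_kernel package only when BUILD_FA3 is ON.
    import importlib.util

    def has_flash_ops() -> bool:
        return importlib.util.find_spec("sgl_kernel.flash_ops") is not None

    if __name__ == "__main__":
        print("flash_ops built:", has_flash_ops())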
sgl-kernel/cmake/utils.cmake (new file, 0 → 100644)

# Adapt from: https://github.com/neuralmagic/vllm-flash-attention/blob/main/cmake/utils.cmake
#
# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in
# `CUDA_ARCH_FLAGS`.
#
# Example:
#   CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
#   clear_cuda_arches(CUDA_ARCH_FLAGS)
#   CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75"
#   CMAKE_CUDA_FLAGS="-Wall"
#
macro(clear_cuda_arches CUDA_ARCH_FLAGS)
  # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
  string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS ${CMAKE_CUDA_FLAGS})

  # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
  # and passed back via the `CUDA_ARCHITECTURES` property.
  string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS ${CMAKE_CUDA_FLAGS})
endmacro()
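Since clear_cuda_arches is just string surgery on CMAKE_CUDA_FLAGS, its behavior can be mirrored in a few lines of Python for anyone who wants to sanity-check the regexes; this is purely an illustration of the same transformation, not code from the repository:

    # Mirror of clear_cuda_arches: pull the -gencode flags out of the flag string.
    import re

    cmake_cuda_flags = "-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
    cuda_arch_flags = re.findall(r"-gencode arch=[^ ]+", cmake_cuda_flags)
    cmake_cuda_flags = re.sub(r"-gencode arch=[^ ]+ *", "", cmake_cuda_flags).strip()

    print(cuda_arch_flags)   # ['-gencode arch=compute_70,code=sm_70', '-gencode arch=compute_75,code=sm_75']
    print(cmake_cuda_flags)  # -Wall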
sgl-kernel/csrc/torch_extension.cc → sgl-kernel/csrc/common_extension.cc (renamed)

@@ -18,7 +18,7 @@ limitations under the License.
#include "sgl_kernel_ops.h"

-TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
+TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  /*
   * From csrc/allreduce
   */

...

@@ -202,45 +202,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
      "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
      "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
  m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
-
-  /*
-   * From flash-attention
-   */
-  m.def(
-      "fwd(Tensor! q,"
-      " Tensor k,"
-      " Tensor v,"
-      " Tensor? k_new,"
-      " Tensor? v_new,"
-      " Tensor? q_v,"
-      " Tensor!? out,"
-      " Tensor? cu_seqlens_q,"
-      " Tensor? cu_seqlens_k,"
-      " Tensor? cu_seqlens_k_new,"
-      " Tensor? seqused_q,"
-      " Tensor? seqused_k,"
-      " int? max_seqlen_q,"
-      " int? max_seqlen_k,"
-      " Tensor? page_table,"
-      " Tensor? kv_batch_idx,"
-      " Tensor? leftpad_k,"
-      " Tensor? rotary_cos,"
-      " Tensor? rotary_sin,"
-      " Tensor? seqlens_rotary,"
-      " Tensor? q_descale,"
-      " Tensor? k_descale,"
-      " Tensor? v_descale,"
-      " float softmax_scale,"
-      " bool is_causal,"
-      " int window_size_left,"
-      " int window_size_right,"
-      " float softcap,"
-      " bool is_rotary_interleaved,"
-      " Tensor? scheduler_metadata,"
-      " int num_splits,"
-      " bool? pack_gqa,"
-      " int sm_margin) -> Tensor[]");
-  m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
}

REGISTER_EXTENSION(common_ops)
sgl-kernel/csrc/flash_extension.cc (new file, 0 → 100644)

/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/all.h>
#include <torch/library.h>

#include "sgl_flash_kernel_ops.h"

TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  /*
   * From flash-attention
   */
  m.def(
      "fwd(Tensor! q,"
      " Tensor k,"
      " Tensor v,"
      " Tensor? k_new,"
      " Tensor? v_new,"
      " Tensor? q_v,"
      " Tensor!? out,"
      " Tensor? cu_seqlens_q,"
      " Tensor? cu_seqlens_k,"
      " Tensor? cu_seqlens_k_new,"
      " Tensor? seqused_q,"
      " Tensor? seqused_k,"
      " int? max_seqlen_q,"
      " int? max_seqlen_k,"
      " Tensor? page_table,"
      " Tensor? kv_batch_idx,"
      " Tensor? leftpad_k,"
      " Tensor? rotary_cos,"
      " Tensor? rotary_sin,"
      " Tensor? seqlens_rotary,"
      " Tensor? q_descale,"
      " Tensor? k_descale,"
      " Tensor? v_descale,"
      " float softmax_scale,"
      " bool is_causal,"
      " int window_size_left,"
      " int window_size_right,"
      " float softcap,"
      " bool is_rotary_interleaved,"
      " Tensor? scheduler_metadata,"
      " int num_splits,"
      " bool? pack_gqa,"
      " int sm_margin) -> Tensor[]");
  m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
}

REGISTER_EXTENSION(flash_ops)
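Because this new translation unit registers fwd into the same sgl_kernel library via TORCH_LIBRARY_FRAGMENT, importing the built module exposes the op under torch.ops.sgl_kernel. A small sketch (assuming a build with BUILD_FA3=ON; normal callers should use the flash_attn_with_kvcache wrapper rather than the raw op):

    # Sketch: confirm the op registered by flash_extension.cc is visible after import.
    import torch
    from sgl_kernel import flash_ops  # importing runs the TORCH_LIBRARY_FRAGMENT registration

    # Prints the overload packet for the "fwd(...) -> Tensor[]" schema defined above.
    print(torch.ops.sgl_kernel.fwd)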
sgl-kernel/include/sgl_flash_kernel_ops.h (new file, 0 → 100644)

/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <Python.h>
#include <torch/library.h>
#include <torch/torch.h>

#include <vector>

#include "sgl_kernel_torch_shim.h"

#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)

#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)

#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)

#define REGISTER_EXTENSION(NAME)                                                                      \
  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                                            \
    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \
    return PyModule_Create(&module);                                                                  \
  }

/*
 * From flash-attention
 */
std::vector<at::Tensor> mha_fwd(
    at::Tensor& q,        // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
    const at::Tensor& k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
                          // h_k, d) if there is page_table.
    const at::Tensor& v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
                          // page_size, h_k, dv) if there is page_table.
    std::optional<const at::Tensor>& k_new_,  // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
    std::optional<const at::Tensor>& v_new_,  // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
    std::optional<const at::Tensor>& q_v_,    // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
    std::optional<at::Tensor>& out_,          // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
    std::optional<const at::Tensor>& cu_seqlens_q_,      // b+1
    std::optional<const at::Tensor>& cu_seqlens_k_,      // b+1
    std::optional<const at::Tensor>& cu_seqlens_k_new_,  // b+1
    std::optional<const at::Tensor>& seqused_q_,  // b. If given, only this many elements of each batch element's queries and outputs are used.
    std::optional<const at::Tensor>& seqused_k_,  // b. If given, only this many elements of each batch element's keys are used.
    std::optional<int> max_seqlen_q_,
    // TODO: check if we need max_seqlen_k
    std::optional<int> max_seqlen_k_,
    std::optional<const at::Tensor>& page_table_,    // (b_k, max_num_pages_per_seq)
    std::optional<const at::Tensor>& kv_batch_idx_,  // b. indices to index into the KV cache
    std::optional<const at::Tensor>& leftpad_k_,     // b
    std::optional<const at::Tensor>& rotary_cos_,    // seqlen_ro x (rotary_dim / 2)
    std::optional<const at::Tensor>& rotary_sin_,    // seqlen_ro x (rotary_dim / 2)
    std::optional<const at::Tensor>& seqlens_rotary_,  // b
    std::optional<at::Tensor>& q_descale_,  // (b, h_k), not (b, h)
    std::optional<at::Tensor>& k_descale_,  // (b, h_k)
    std::optional<at::Tensor>& v_descale_,  // (b, h_k)
    float const softmax_scale,
    bool is_causal,
    int window_size_left,
    int window_size_right,
    float const softcap,
    bool const is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
    std::optional<at::Tensor>& scheduler_metadata_,  // (b + 1)
    int num_splits,
    std::optional<bool> pack_gqa_,
    int const sm_margin);
sgl-kernel/include/sgl_kernel_ops.h

...

@@ -23,8 +23,6 @@ limitations under the License.
#include <vector>

#include "sgl_kernel_torch_shim.h"

#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)

...

@@ -293,48 +291,3 @@ void top_p_sampling_from_probs(
    double top_p_val,
    bool deterministic,
    int64_t cuda_stream);
-
-/*
- * From flash-attention
- */
-std::vector<at::Tensor> mha_fwd(
-    at::Tensor& q,        // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
-    const at::Tensor& k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
-                          // h_k, d) if there is page_table.
-    const at::Tensor& v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
-                          // page_size, h_k, dv) if there is page_table.
-    std::optional<const at::Tensor>& k_new_,  // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
-    std::optional<const at::Tensor>& v_new_,  // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
-    std::optional<const at::Tensor>& q_v_,    // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
-    std::optional<at::Tensor>& out_,          // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
-    std::optional<const at::Tensor>& cu_seqlens_q_,      // b+1
-    std::optional<const at::Tensor>& cu_seqlens_k_,      // b+1
-    std::optional<const at::Tensor>& cu_seqlens_k_new_,  // b+1
-    std::optional<const at::Tensor>& seqused_q_,  // b. If given, only this many elements of each batch element's queries and outputs are used.
-    std::optional<const at::Tensor>& seqused_k_,  // b. If given, only this many elements of each batch element's keys are used.
-    std::optional<int> max_seqlen_q_,
-    // TODO: check if we need max_seqlen_k
-    std::optional<int> max_seqlen_k_,
-    std::optional<const at::Tensor>& page_table_,    // (b_k, max_num_pages_per_seq)
-    std::optional<const at::Tensor>& kv_batch_idx_,  // b. indices to index into the KV cache
-    std::optional<const at::Tensor>& leftpad_k_,     // b
-    std::optional<const at::Tensor>& rotary_cos_,    // seqlen_ro x (rotary_dim / 2)
-    std::optional<const at::Tensor>& rotary_sin_,    // seqlen_ro x (rotary_dim / 2)
-    std::optional<const at::Tensor>& seqlens_rotary_,  // b
-    std::optional<at::Tensor>& q_descale_,  // (b, h_k), not (b, h)
-    std::optional<at::Tensor>& k_descale_,  // (b, h_k)
-    std::optional<at::Tensor>& v_descale_,  // (b, h_k)
-    float const softmax_scale,
-    bool is_causal,
-    int window_size_left,
-    int window_size_right,
-    float const softcap,
-    bool const is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
-    std::optional<at::Tensor>& scheduler_metadata_,  // (b + 1)
-    int num_splits,
-    std::optional<bool> pack_gqa_,
-    int const sm_margin);
sgl-kernel/python/sgl_kernel/flash_attn.py

...

@@ -3,15 +3,22 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn

+try:
+    from sgl_kernel import flash_ops
+except:
+    raise ImportError("Can not import sgl_kernel. Please check your installation.")
+

def is_fa3_supported(device=None) -> bool:
    # FA3 can fail without a enough shared memory for a some shapes, currently
    # only 8.0 and 8.7 have enough shared memory for all shapes
    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
-    return FA3_AVAILABLE and (
-        torch.cuda.get_device_capability(device)[0] >= 9
-        or torch.cuda.get_device_capability(device) == (8, 0)
-        or torch.cuda.get_device_capability(device) == (8, 7)
-    )
+    # now sgl-kernel only build fa3 for sm90a && cuda >= 12.4
+    return (
+        (torch.cuda.get_device_capability(device)[0] >= 9)
+        and (torch.version.cuda >= "12.4")
+        # or torch.cuda.get_device_capability(device) == (8, 0)
+        # or torch.cuda.get_device_capability(device) == (8, 7)
+    )

...

@@ -135,6 +142,10 @@ def flash_attn_with_kvcache(
        logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
        normalization factor).
    """
+    if not is_fa3_supported():
+        raise NotImplementedError(
+            "flash_attn at sgl-kernel is only supported on sm90 and above"
+        )
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    if softmax_scale is None:

...
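With the stricter is_fa3_supported gate, a caller on CUDA 11.8 or a pre-SM90 GPU now gets a clear NotImplementedError instead of a broken build or import. A minimal usage sketch (shapes follow the (batch, seqlen, heads, head_dim) layout documented for mha_fwd; the exact keyword arguments are an assumption based on the standard flash-attention interface, not taken from this diff):

    # Sketch: guard a call to flash_attn_with_kvcache with the new support check.
    import torch
    from sgl_kernel.flash_attn import flash_attn_with_kvcache, is_fa3_supported

    if is_fa3_supported():
        q = torch.randn(1, 1, 8, 128, dtype=torch.bfloat16, device="cuda")          # (b, s_q, h, d)
        k_cache = torch.randn(1, 256, 8, 128, dtype=torch.bfloat16, device="cuda")  # (b_k, s_k, h_k, d)
        v_cache = torch.randn(1, 256, 8, 128, dtype=torch.bfloat16, device="cuda")  # (b_k, s_k, h_k, dv)
        out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=True)             # keyword name assumed
        print(out.shape)
    else:
        print("FA3 kernels unavailable: requires SM90+ and CUDA >= 12.4")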
sgl-kernel/tests/test_flash_attention.py

...

@@ -10,7 +10,19 @@ from einops import rearrange, repeat
    apply_rotary_emb = None

-from sgl_kernel.flash_attn import flash_attn_with_kvcache
+
+def is_fa3_supported(device=None) -> bool:
+    # FA3 can fail without a enough shared memory for a some shapes, currently
+    # only 8.0 and 8.7 have enough shared memory for all shapes
+    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
+    # now sgl-kernel only build fa3 for sm90a && cuda >= 12.4
+    return (
+        (torch.cuda.get_device_capability(device)[0] >= 9)
+        and (torch.version.cuda >= "12.4")
+        # or torch.cuda.get_device_capability(device) == (8, 0)
+        # or torch.cuda.get_device_capability(device) == (8, 7)
+    )
+

DISABLE_BACKWARD = True
# For CI test, we close them to True.

...

@@ -284,6 +296,10 @@ def attention_ref(
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)


+@pytest.mark.skipif(
+    not is_fa3_supported(),
+    reason="flash_attn at sgl-kernel is only supported on sm90 and above",
+)
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize(
    "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])

...

@@ -372,6 +388,8 @@ def test_flash_attn_kvcache(
    mha_type,
    dtype,
):
+    from sgl_kernel.flash_attn import flash_attn_with_kvcache
+
    if page_size is not None and seqlen_k % page_size != 0:
        pytest.skip()
    if seqlen_q > seqlen_k and new_kv:

...