Unverified Commit 37c66ec8 authored by yinfan98's avatar yinfan98 Committed by GitHub
Browse files

[feat] add fa3 in sgl-kernel (#4902)


Co-authored-by: default avatarSleepcoo <Sleepcoo@gmail.com>
parent 9adf178c
......@@ -25,6 +25,7 @@ find_package(Torch REQUIRED)
include(FetchContent)
# cutlass
FetchContent_Declare(
repo-cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
......@@ -32,6 +33,7 @@ FetchContent_Declare(
GIT_SHALLOW ON
)
FetchContent_Populate(repo-cutlass)
# DeepGEMM
FetchContent_Declare(
repo-deepgemm
GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM
......@@ -39,6 +41,7 @@ FetchContent_Declare(
GIT_SHALLOW ON
)
FetchContent_Populate(repo-deepgemm)
# flashinfer
FetchContent_Declare(
repo-flashinfer
GIT_REPOSITORY https://github.com/sgl-project/flashinfer
......@@ -46,6 +49,15 @@ FetchContent_Declare(
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-flashinfer)
# flash-attention
FetchContent_Declare(
repo-flash-attention
GIT_REPOSITORY https://github.com/sgl-project/sgl-attn
GIT_TAG sgl-kernel
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-flash-attention)
include_directories(
${PROJECT_SOURCE_DIR}/include
......@@ -54,6 +66,7 @@ include_directories(
${repo-cutlass_SOURCE_DIR}/tools/util/include
${repo-flashinfer_SOURCE_DIR}/include
${repo-flashinfer_SOURCE_DIR}/csrc
${repo-flash-attention_SOURCE_DIR}/hopper
)
set(CMAKE_CXX_STANDARD 17)
......@@ -78,6 +91,7 @@ set(SGL_KERNEL_CUDA_FLAGS
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
"-DCUTLASS_DEBUG_TRACE_LEVEL=0"
"--expt-relaxed-constexpr"
"--use_fast_math"
"-Xcompiler=-Wconversion"
"-Xcompiler=-fno-strict-aliasing"
)
......@@ -130,6 +144,30 @@ string(REPLACE "-D__CUDA_NO_HALF_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE
string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
# set flash-attention sources file
# BF16 source files
file(GLOB FA3_BF16_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
file(GLOB FA3_BF16_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
# FP16 source files
file(GLOB FA3_FP16_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
file(GLOB FA3_FP16_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
# FP8 source files
file(GLOB FA3_FP8_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
file(GLOB FA3_FP8_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})
set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})
set(SOURCES
"csrc/allreduce/trt_reduce_internal.cu"
"csrc/allreduce/trt_reduce_kernel.cu"
......@@ -160,6 +198,10 @@ set(SOURCES
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu"
"${FA3_GEN_SRCS}"
)
# Support abi3 for build
......@@ -173,6 +215,18 @@ target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cubl
install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")
# Add some flash-attention custom flag for inference
target_compile_definitions(common_ops PRIVATE
FLASHATTENTION_DISABLE_SM8x
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
# FLASHATTENTION_DISABLE_ALIBI
# FLASHATTENTION_DISABLE_SOFTCAP
FLASHATTENTION_DISABLE_UNEVEN_K
# FLASHATTENTION_DISABLE_LOCAL
FLASHATTENTION_VARLEN_ONLY
)
# JIT Logic
# DeepGEMM
......
......@@ -92,6 +92,36 @@ Steps to add a new kernel:
)
```
### Integrating Third-Party Libraries with Data Type Conversion
When integrating new third-party libraries like flash-attention, you may encounter data type compatibility issues between the C++ interface and PyTorch bindings. For example, the third-party code might use `float` or `int` types, while PyTorch requires `double` and `int64_t`.
To address this issue, we provide the `make_pytorch_shim` function in [sgl_kernel_torch_shim](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/include/sgl_kernel_torch_shim.h) that handles data type conversions automatically.
When you need to support new data type conversions, you can easily add conversion functions like this:
```cpp
// Map `int` -> `int64_t`
template <>
struct pytorch_library_compatible_type<int> {
using type = int64_t;
static int convert_from_type(int64_t arg) {
TORCH_CHECK(arg <= std::numeric_limits<int>::max(), "int64_t value is too large to be converted to int");
TORCH_CHECK(arg >= std::numeric_limits<int>::min(), "int64_t value is too small to be converted to int");
return arg;
}
};
```
To use this with your library functions, simply wrap them with make_pytorch_shim:
```cpp
/*
* From flash-attention
*/
m.def("fwd", make_pytorch_shim(mha_fwd));
```
### Build & Install
Development build:
......
......@@ -91,6 +91,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
m.def("top_p_renorm_probs", top_p_renorm_probs);
m.def("top_k_top_p_sampling_from_probs", top_k_top_p_sampling_from_probs);
m.def("top_p_sampling_from_probs", top_p_sampling_from_probs);
/*
* From flash-attention
*/
m.def("fwd", make_pytorch_shim(mha_fwd));
}
REGISTER_EXTENSION(common_ops)
......@@ -23,6 +23,8 @@ limitations under the License.
#include <vector>
#include "sgl_kernel_torch_shim.h"
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)
......@@ -291,3 +293,48 @@ void top_p_sampling_from_probs(
double top_p_val,
bool deterministic,
int64_t cuda_stream);
/*
* From flash-attention
*/
std::vector<at::Tensor> mha_fwd(
at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
// h_k, d) if there is page_table.
const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
// page_size, h_k, dv) if there is page_table.
std::optional<const at::Tensor>&
k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
std::optional<const at::Tensor>&
v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
std::optional<const at::Tensor>& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
std::optional<at::Tensor>& out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
std::optional<const at::Tensor>& cu_seqlens_q_, // b+1
std::optional<const at::Tensor>& cu_seqlens_k_, // b+1
std::optional<const at::Tensor>& cu_seqlens_k_new_, // b+1
std::optional<const at::Tensor>&
seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
std::optional<const at::Tensor>&
seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
std::optional<int> max_seqlen_q_,
// TODO: check if we need max_seqlen_k
std::optional<int> max_seqlen_k_,
std::optional<const at::Tensor>& page_table_, // (b_k, max_num_pages_per_seq)
std::optional<const at::Tensor>& kv_batch_idx_, // b. indices to index into the KV cache
std::optional<const at::Tensor>& leftpad_k_, // b
std::optional<const at::Tensor>& rotary_cos_, // seqlen_ro x (rotary_dim / 2)
std::optional<const at::Tensor>& rotary_sin_, // seqlen_ro x (rotary_dim / 2)
std::optional<const at::Tensor>& seqlens_rotary_, // b
std::optional<at::Tensor>& q_descale_, // (b, h_k), not (b, h)
std::optional<at::Tensor>& k_descale_, // (b, h_k)
std::optional<at::Tensor>& v_descale_, // (b, h_k)
float const softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
float const softcap,
bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
std::optional<at::Tensor>& scheduler_metadata_, // (b + 1)
int num_splits,
std::optional<bool> pack_gqa_,
int const sm_margin);
/*Adapt from:
https://github.com/neuralmagic/vllm-flash-attention/blob/90eacc1af2a7c3de62ea249e929ed5faccf38954/csrc/common/pytorch_shim.h
Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#pragma once
#include <torch/library.h>
/**
* Unforunately, the type signatures of the flash_attn ops are not compatible
* with the PyTorch library bindings. To get around that we use
* `make_pytorch_shim` which creates a lambda that exponses the API using
* PyTorch compatible types to the types, then converts them to the types
* expected by the flash_attn ops. This shims allows us to make minimal changes
* to `flash_api.cpp` making it easier to synchronize with upstream changes.
*
* The `pytorch_library_compatible_type` struct is used to map from the
* flash_attn ops types to a PyTorch library compatible one. The main issues is
* that the following types are not support by PyTorch libary bindings:
* - `int`
* - `float`
* - `std::optional<T> &`
* - `std::optional<const at::Tensor> &`
* So we convert them to (respectively):
* - `int64_t`
* - `double`
* - `const std::optional<T>&`
* - `const std::optional<at::Tensor>&`
*/
template <typename T>
struct pytorch_library_compatible_type {
using type = T;
static T convert_from_type(T arg) {
return arg;
}
};
template <typename T>
using pytorch_library_compatible_type_t = typename pytorch_library_compatible_type<T>::type;
template <typename T>
T convert_from_pytorch_compatible_type(pytorch_library_compatible_type_t<T> arg) {
return pytorch_library_compatible_type<T>::convert_from_type(arg);
}
// Map `c10::optional<T> &` -> `const c10::optional<T>&`
// (NOTE: this is bit unsafe but non of the ops in flash_attn mutate
// the optional container)
template <typename T>
struct pytorch_library_compatible_type<c10::optional<T>&> {
using type = const c10::optional<T>&;
static c10::optional<T>& convert_from_type(const c10::optional<T>& arg) {
return const_cast<c10::optional<T>&>(arg);
}
};
// Map `c10::optional<T>` ->
// `c10::optional<pytorch_library_compatible_type_t<T>>`
// (NOTE: tested for `c10::optional<int>` -> `c10::optional<int64_t>`)
template <typename T>
struct pytorch_library_compatible_type<c10::optional<T>> {
using type = c10::optional<pytorch_library_compatible_type_t<T>>;
static c10::optional<pytorch_library_compatible_type_t<T>> convert_from_type(c10::optional<T> arg) {
return arg;
}
};
// Map `c10::optional<const at::Tensor>&` -> `const c10::optional<at::Tensor>&`
template <>
struct pytorch_library_compatible_type<c10::optional<const at::Tensor>&> {
using type = const c10::optional<at::Tensor>&;
static c10::optional<const at::Tensor>& convert_from_type(const c10::optional<at::Tensor>& arg) {
return const_cast<c10::optional<const at::Tensor>&>(reinterpret_cast<const c10::optional<const at::Tensor>&>(arg));
}
};
// Map `int` -> `int64_t`
template <>
struct pytorch_library_compatible_type<int> {
using type = int64_t;
static int convert_from_type(int64_t arg) {
TORCH_CHECK(arg <= std::numeric_limits<int>::max(), "int64_t value is too large to be converted to int");
TORCH_CHECK(arg >= std::numeric_limits<int>::min(), "int64_t value is too small to be converted to int");
return arg;
}
};
// Map `float` -> `double`
template <>
struct pytorch_library_compatible_type<float> {
using type = double;
static float convert_from_type(double arg) {
TORCH_CHECK(
std::abs(arg) <= std::numeric_limits<float>::max(), "double value is too large to be converted to float");
return arg;
}
};
//
// Shim Utils
//
template <typename Ret, typename... Args>
auto make_pytorch_shim(Ret (*fun)(Args... args)) {
return [fun](pytorch_library_compatible_type_t<Args>... args) {
return fun(convert_from_pytorch_compatible_type<Args>(args)...);
};
}
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
def is_fa3_supported(device=None) -> bool:
# FA3 can fail without a enough shared memory for a some shapes, currently
# only 8.0 and 8.7 have enough shared memory for all shapes
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
return FA3_AVAILABLE and (
torch.cuda.get_device_capability(device)[0] >= 9
or torch.cuda.get_device_capability(device) == (8, 0)
or torch.cuda.get_device_capability(device) == (8, 7)
)
def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
def flash_attn_with_kvcache(
q,
k_cache,
v_cache,
k=None,
v=None,
qv=None,
rotary_cos=None,
rotary_sin=None,
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
cache_leftpad: Optional[torch.Tensor] = None,
page_table: Optional[torch.Tensor] = None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k_new: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
rotary_seqlens: Optional[torch.Tensor] = None,
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
rotary_interleaved=True,
scheduler_metadata=None,
num_splits=0, # Can be tuned for speed
pack_gqa=None, # Can be tuned for speed
sm_margin=0, # Can be tuned if some SMs are used for communication
return_softmax_lse=False,
):
"""
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
the previous step, and update them with the new keys/values from the current step, and do
attention with the updated cache, all in 1 kernel.
If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
For example, the KV cache could be pre-allocated with the max sequence length, and you can use
cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Note: Does not support backward pass.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
page_block_size must be a multiple of 256.
v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
k with k_cache, starting at the indices specified by cache_seqlens.
v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
qv [optional]: (batch_size, seqlen, nheads, headdim_v)
rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
KV cache.
cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
If the indices are not distinct, and k and v are provided, the values updated in the cache
might come from any of the duplicate indices.
cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. Anything > 0 activates softcapping attention.
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
(i.e. GPT-NeoX style).
num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
to automatically determine the number of splits.
Don't change this unless you know what you are doing.
return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
if softmax_scale is None:
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
-0.5
)
if cache_seqlens is not None and isinstance(cache_seqlens, int):
cache_seqlens = torch.full(
(k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
)
cache_seqlens = maybe_contiguous(cache_seqlens)
q, k_cache, k, v = [maybe_contiguous(x) for x in (q, k_cache, k, v)]
v_cache = (
v_cache.contiguous()
if v_cache.stride(-1) != 1 and v_cache.stride(-3) != 1
else v_cache
)
cu_seqlens_q, cu_seqlens_k_new = [
maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new)
]
page_table, cache_batch_idx, cache_leftpad = [
maybe_contiguous(x) for x in (page_table, cache_batch_idx, cache_leftpad)
]
rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
rotary_seqlens = maybe_contiguous(rotary_seqlens)
out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
q,
k_cache,
v_cache,
k,
v,
qv,
None, # out
cu_seqlens_q,
None, # cu_seqlens_k
cu_seqlens_k_new,
None, # seqused_q
cache_seqlens,
max_seqlen_q,
None, # max_seqlen_k
page_table,
cache_batch_idx,
cache_leftpad,
rotary_cos,
rotary_sin,
rotary_seqlens,
q_descale,
k_descale,
v_descale,
softmax_scale,
causal,
window_size[0],
window_size[1],
softcap,
rotary_interleaved,
scheduler_metadata,
num_splits,
pack_gqa,
sm_margin,
)
# return (out, softmax_lse) if return_softmax_lse else out
return (out, softmax_lse, *rest) if return_softmax_lse else out
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment