Unverified commit 37c66ec8, authored by yinfan98 and committed by GitHub

[feat] add fa3 in sgl-kernel (#4902)


Co-authored-by: Sleepcoo <Sleepcoo@gmail.com>
parent 9adf178c
@@ -25,6 +25,7 @@ find_package(Torch REQUIRED)
include(FetchContent)

# cutlass
FetchContent_Declare(
repo-cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
@@ -32,6 +33,7 @@ FetchContent_Declare(
GIT_SHALLOW ON
)
FetchContent_Populate(repo-cutlass)

# DeepGEMM
FetchContent_Declare(
repo-deepgemm
GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM
@@ -39,6 +41,7 @@ FetchContent_Declare(
GIT_SHALLOW ON
)
FetchContent_Populate(repo-deepgemm)

# flashinfer
FetchContent_Declare(
repo-flashinfer
GIT_REPOSITORY https://github.com/sgl-project/flashinfer
@@ -46,6 +49,15 @@ FetchContent_Declare(
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-flashinfer)

# flash-attention
FetchContent_Declare(
repo-flash-attention
GIT_REPOSITORY https://github.com/sgl-project/sgl-attn
GIT_TAG sgl-kernel
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-flash-attention)

include_directories(
${PROJECT_SOURCE_DIR}/include
@@ -54,6 +66,7 @@ include_directories(
${repo-cutlass_SOURCE_DIR}/tools/util/include
${repo-flashinfer_SOURCE_DIR}/include
${repo-flashinfer_SOURCE_DIR}/csrc
${repo-flash-attention_SOURCE_DIR}/hopper
)

set(CMAKE_CXX_STANDARD 17)
@@ -78,6 +91,7 @@ set(SGL_KERNEL_CUDA_FLAGS
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
"-DCUTLASS_DEBUG_TRACE_LEVEL=0"
"--expt-relaxed-constexpr"
"--use_fast_math"
"-Xcompiler=-Wconversion"
"-Xcompiler=-fno-strict-aliasing"
)
@@ -130,6 +144,30 @@ string(REPLACE "-D__CUDA_NO_HALF_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
# set flash-attention source files
# BF16 source files
file(GLOB FA3_BF16_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
file(GLOB FA3_BF16_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
# FP16 source files
file(GLOB FA3_FP16_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
file(GLOB FA3_FP16_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
# FP8 source files
file(GLOB FA3_FP8_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
file(GLOB FA3_FP8_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})
set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})
set(SOURCES
"csrc/allreduce/trt_reduce_internal.cu"
"csrc/allreduce/trt_reduce_kernel.cu"
@@ -160,6 +198,10 @@ set(SOURCES
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu"
"${FA3_GEN_SRCS}"
)

# Support abi3 for build
@@ -173,6 +215,18 @@ target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cubl
install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")
# Add flash-attention custom flags for inference
target_compile_definitions(common_ops PRIVATE
FLASHATTENTION_DISABLE_SM8x
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
# FLASHATTENTION_DISABLE_ALIBI
# FLASHATTENTION_DISABLE_SOFTCAP
FLASHATTENTION_DISABLE_UNEVEN_K
# FLASHATTENTION_DISABLE_LOCAL
FLASHATTENTION_VARLEN_ONLY
)
# JIT Logic
# DeepGEMM
...
@@ -92,6 +92,36 @@ Steps to add a new kernel:
)
```
### Integrating Third-Party Libraries with Data Type Conversion
When integrating new third-party libraries like flash-attention, you may encounter data type compatibility issues between the C++ interface and PyTorch bindings. For example, the third-party code might use `float` or `int` types, while PyTorch requires `double` and `int64_t`.
To address this issue, we provide the `make_pytorch_shim` function in [sgl_kernel_torch_shim](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/include/sgl_kernel_torch_shim.h) that handles data type conversions automatically.
When you need to support new data type conversions, you can easily add conversion functions like this:
```cpp
// Map `int` -> `int64_t`
template <>
struct pytorch_library_compatible_type<int> {
using type = int64_t;
static int convert_from_type(int64_t arg) {
TORCH_CHECK(arg <= std::numeric_limits<int>::max(), "int64_t value is too large to be converted to int");
TORCH_CHECK(arg >= std::numeric_limits<int>::min(), "int64_t value is too small to be converted to int");
return arg;
}
};
```
To use this with your library functions, simply wrap them with `make_pytorch_shim`:
```cpp
/*
* From flash-attention
*/
m.def("fwd", make_pytorch_shim(mha_fwd));
```
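Because the shim only wraps the binding, `flash_api.cpp` itself stays untouched, which keeps synchronizing with upstream flash-attention simple. Once the op is registered, it is reachable from Python through the dispatcher; a minimal sketch, assuming the extension is built and installed as `sgl_kernel` as in this commit:
```python
import torch
import sgl_kernel  # assumed: importing the package loads the compiled common_ops extension

# The shim-wrapped mha_fwd is exposed as torch.ops.sgl_kernel.fwd; the
# flash_attn_with_kvcache wrapper in this commit forwards its arguments to it.
fwd_op = torch.ops.sgl_kernel.fwd.default
print(fwd_op)
```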
### Build & Install
Development build:
...
@@ -91,6 +91,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
m.def("top_p_renorm_probs", top_p_renorm_probs);
m.def("top_k_top_p_sampling_from_probs", top_k_top_p_sampling_from_probs);
m.def("top_p_sampling_from_probs", top_p_sampling_from_probs);
/*
* From flash-attention
*/
m.def("fwd", make_pytorch_shim(mha_fwd));
}

REGISTER_EXTENSION(common_ops)

@@ -23,6 +23,8 @@ limitations under the License.
#include <vector>

#include "sgl_kernel_torch_shim.h"

#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)

@@ -291,3 +293,48 @@ void top_p_sampling_from_probs(
double top_p_val,
bool deterministic,
int64_t cuda_stream);
/*
* From flash-attention
*/
std::vector<at::Tensor> mha_fwd(
at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
// h_k, d) if there is page_table.
const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
// page_size, h_k, dv) if there is page_table.
std::optional<const at::Tensor>&
k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
std::optional<const at::Tensor>&
v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
std::optional<const at::Tensor>& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
std::optional<at::Tensor>& out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
std::optional<const at::Tensor>& cu_seqlens_q_, // b+1
std::optional<const at::Tensor>& cu_seqlens_k_, // b+1
std::optional<const at::Tensor>& cu_seqlens_k_new_, // b+1
std::optional<const at::Tensor>&
seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
std::optional<const at::Tensor>&
seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
std::optional<int> max_seqlen_q_,
// TODO: check if we need max_seqlen_k
std::optional<int> max_seqlen_k_,
std::optional<const at::Tensor>& page_table_, // (b_k, max_num_pages_per_seq)
std::optional<const at::Tensor>& kv_batch_idx_, // b. indices to index into the KV cache
std::optional<const at::Tensor>& leftpad_k_, // b
std::optional<const at::Tensor>& rotary_cos_, // seqlen_ro x (rotary_dim / 2)
std::optional<const at::Tensor>& rotary_sin_, // seqlen_ro x (rotary_dim / 2)
std::optional<const at::Tensor>& seqlens_rotary_, // b
std::optional<at::Tensor>& q_descale_, // (b, h_k), not (b, h)
std::optional<at::Tensor>& k_descale_, // (b, h_k)
std::optional<at::Tensor>& v_descale_, // (b, h_k)
float const softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
float const softcap,
bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
std::optional<at::Tensor>& scheduler_metadata_, // (b + 1)
int num_splits,
std::optional<bool> pack_gqa_,
int const sm_margin);
/* Adapted from:
https://github.com/neuralmagic/vllm-flash-attention/blob/90eacc1af2a7c3de62ea249e929ed5faccf38954/csrc/common/pytorch_shim.h
Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#pragma once
#include <torch/library.h>
/**
* Unfortunately, the type signatures of the flash_attn ops are not compatible
* with the PyTorch library bindings. To get around that we use
* `make_pytorch_shim`, which creates a lambda that exposes the API using
* PyTorch-compatible types and then converts them to the types expected by the
* flash_attn ops. This shim allows us to make minimal changes to
* `flash_api.cpp`, making it easier to synchronize with upstream changes.
*
* The `pytorch_library_compatible_type` struct is used to map from the
* flash_attn ops types to a PyTorch library compatible one. The main issue is
* that the following types are not supported by PyTorch library bindings:
* - `int`
* - `float`
* - `std::optional<T> &`
* - `std::optional<const at::Tensor> &`
* So we convert them to (respectively):
* - `int64_t`
* - `double`
* - `const std::optional<T>&`
* - `const std::optional<at::Tensor>&`
*/
template <typename T>
struct pytorch_library_compatible_type {
using type = T;
static T convert_from_type(T arg) {
return arg;
}
};
template <typename T>
using pytorch_library_compatible_type_t = typename pytorch_library_compatible_type<T>::type;
template <typename T>
T convert_from_pytorch_compatible_type(pytorch_library_compatible_type_t<T> arg) {
return pytorch_library_compatible_type<T>::convert_from_type(arg);
}
// Map `c10::optional<T> &` -> `const c10::optional<T>&`
// (NOTE: this is a bit unsafe but none of the ops in flash_attn mutate
// the optional container)
template <typename T>
struct pytorch_library_compatible_type<c10::optional<T>&> {
using type = const c10::optional<T>&;
static c10::optional<T>& convert_from_type(const c10::optional<T>& arg) {
return const_cast<c10::optional<T>&>(arg);
}
};
// Map `c10::optional<T>` ->
// `c10::optional<pytorch_library_compatible_type_t<T>>`
// (NOTE: tested for `c10::optional<int>` -> `c10::optional<int64_t>`)
template <typename T>
struct pytorch_library_compatible_type<c10::optional<T>> {
using type = c10::optional<pytorch_library_compatible_type_t<T>>;
static c10::optional<pytorch_library_compatible_type_t<T>> convert_from_type(c10::optional<T> arg) {
return arg;
}
};
// Map `c10::optional<const at::Tensor>&` -> `const c10::optional<at::Tensor>&`
template <>
struct pytorch_library_compatible_type<c10::optional<const at::Tensor>&> {
using type = const c10::optional<at::Tensor>&;
static c10::optional<const at::Tensor>& convert_from_type(const c10::optional<at::Tensor>& arg) {
return const_cast<c10::optional<const at::Tensor>&>(reinterpret_cast<const c10::optional<const at::Tensor>&>(arg));
}
};
// Map `int` -> `int64_t`
template <>
struct pytorch_library_compatible_type<int> {
using type = int64_t;
static int convert_from_type(int64_t arg) {
TORCH_CHECK(arg <= std::numeric_limits<int>::max(), "int64_t value is too large to be converted to int");
TORCH_CHECK(arg >= std::numeric_limits<int>::min(), "int64_t value is too small to be converted to int");
return arg;
}
};
// Map `float` -> `double`
template <>
struct pytorch_library_compatible_type<float> {
using type = double;
static float convert_from_type(double arg) {
TORCH_CHECK(
std::abs(arg) <= std::numeric_limits<float>::max(), "double value is too large to be converted to float");
return arg;
}
};
//
// Shim Utils
//
template <typename Ret, typename... Args>
auto make_pytorch_shim(Ret (*fun)(Args... args)) {
return [fun](pytorch_library_compatible_type_t<Args>... args) {
return fun(convert_from_pytorch_compatible_type<Args>(args)...);
};
}
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
def is_fa3_supported(device=None) -> bool:
# FA3 can fail without enough shared memory for some shapes; currently
# only 8.0 and 8.7 have enough shared memory for all shapes
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
return FA3_AVAILABLE and (
torch.cuda.get_device_capability(device)[0] >= 9
or torch.cuda.get_device_capability(device) == (8, 0)
or torch.cuda.get_device_capability(device) == (8, 7)
)
def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
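# --- Editor's sketch (not part of the original module) ------------------------
# How the helpers above are typically used before dispatching to the kernel:
# gate on device capability, then make sure the last dimension is contiguous.
def _example_helper_usage():
    if not is_fa3_supported():
        raise RuntimeError("FA3 kernel is not supported on this device")
    # Transposing the leading dims keeps stride(-1) == 1, so maybe_contiguous is a no-op.
    x = torch.randn(4, 8, 16, device="cuda").transpose(0, 1)
    assert maybe_contiguous(x) is x
    # A tensor whose last dimension is strided gets copied to a contiguous layout.
    y = torch.randn(4, 8, 16, device="cuda").transpose(1, 2)
    assert maybe_contiguous(y).stride(-1) == 1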
def flash_attn_with_kvcache(
q,
k_cache,
v_cache,
k=None,
v=None,
qv=None,
rotary_cos=None,
rotary_sin=None,
cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
cache_leftpad: Optional[torch.Tensor] = None,
page_table: Optional[torch.Tensor] = None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k_new: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
rotary_seqlens: Optional[torch.Tensor] = None,
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
rotary_interleaved=True,
scheduler_metadata=None,
num_splits=0, # Can be tuned for speed
pack_gqa=None, # Can be tuned for speed
sm_margin=0, # Can be tuned if some SMs are used for communication
return_softmax_lse=False,
):
"""
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
the previous step, and update them with the new keys/values from the current step, and do
attention with the updated cache, all in 1 kernel.
If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
For example, the KV cache could be pre-allocated with the max sequence length, and you can use
cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Note: Does not support backward pass.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
page_block_size must be a multiple of 256.
v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
k with k_cache, starting at the indices specified by cache_seqlens.
v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
qv [optional]: (batch_size, seqlen, nheads, headdim_v)
rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
KV cache.
cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
If the indices are not distinct, and k and v are provided, the values updated in the cache
might come from any of the duplicate indices.
cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. Anything > 0 activates softcapping attention.
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
(i.e. GPT-NeoX style).
num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
to automatically determine the number of splits.
Don't change this unless you know what you are doing.
return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
if softmax_scale is None:
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
-0.5
)
if cache_seqlens is not None and isinstance(cache_seqlens, int):
cache_seqlens = torch.full(
(k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
)
cache_seqlens = maybe_contiguous(cache_seqlens)
q, k_cache, k, v = [maybe_contiguous(x) for x in (q, k_cache, k, v)]
v_cache = (
v_cache.contiguous()
if v_cache.stride(-1) != 1 and v_cache.stride(-3) != 1
else v_cache
)
cu_seqlens_q, cu_seqlens_k_new = [
maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new)
]
page_table, cache_batch_idx, cache_leftpad = [
maybe_contiguous(x) for x in (page_table, cache_batch_idx, cache_leftpad)
]
rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
rotary_seqlens = maybe_contiguous(rotary_seqlens)
out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
q,
k_cache,
v_cache,
k,
v,
qv,
None, # out
cu_seqlens_q,
None, # cu_seqlens_k
cu_seqlens_k_new,
None, # seqused_q
cache_seqlens,
max_seqlen_q,
None, # max_seqlen_k
page_table,
cache_batch_idx,
cache_leftpad,
rotary_cos,
rotary_sin,
rotary_seqlens,
q_descale,
k_descale,
v_descale,
softmax_scale,
causal,
window_size[0],
window_size[1],
softcap,
rotary_interleaved,
scheduler_metadata,
num_splits,
pack_gqa,
sm_margin,
)
# return (out, softmax_lse) if return_softmax_lse else out
return (out, softmax_lse, *rest) if return_softmax_lse else out
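# --- Editor's sketch (not part of the original module) ------------------------
# A minimal decode-step example under assumed shapes: one new token per sequence
# is appended to a pre-allocated KV cache and attended to in a single call.
# Assumes a Hopper GPU and an sgl_kernel build that includes the FA3 sources above.
def _example_decode_step():
    batch, nheads, headdim, max_seqlen = 2, 8, 128, 1024
    dtype, device = torch.bfloat16, "cuda"
    q = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
    k_cache = torch.randn(batch, max_seqlen, nheads, headdim, device=device, dtype=dtype)
    v_cache = torch.randn(batch, max_seqlen, nheads, headdim, device=device, dtype=dtype)
    k_new = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
    v_new = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
    # Current length of each cached sequence; k_new/v_new are written in-place at these offsets.
    cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device=device)
    out = flash_attn_with_kvcache(
        q,
        k_cache,
        v_cache,
        k=k_new,
        v=v_new,
        cache_seqlens=cache_seqlens,
        causal=True,
    )
    return out  # (batch, 1, nheads, headdim)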
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/test_flash_attn.py
import itertools
import math
import os
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
apply_rotary_emb = None
from sgl_kernel.flash_attn import flash_attn_with_kvcache
DISABLE_BACKWARD = True
# For CI tests, we hard-code all of these to True.
# DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE"
# DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE"
# DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "FALSE") == "TRUE"
# DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "FALSE") == "TRUE"
# DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "FALSE") == "TRUE"
# DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "FALSE") == "TRUE"
# DISABLE_FP16 = os.getenv("FLASH_ATTENTION_DISABLE_FP16", "FALSE") == "TRUE"
# DISABLE_FP8 = (
# os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE"
# or torch.cuda.get_device_capability("cuda")[0] < 9
# )
DISABLE_SPLIT = True
DISABLE_PAGEDKV = True
DISABLE_APPENDKV = True
DISABLE_LOCAL = True
DISABLE_SOFTCAP = True
DISABLE_PACKGQA = True
DISABLE_FP16 = True
DISABLE_FP8 = True
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/padding.py
def unpad_input(hidden_states, attention_mask, unused_mask=None):
"""
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
"""
all_masks = (
(attention_mask + unused_mask) if unused_mask is not None else attention_mask
)
seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices.
return (
rearrange(hidden_states, "b s ... -> (b s) ...")[indices],
indices,
cu_seqlens,
max_seqlen_in_batch,
used_seqlens_in_batch,
)
def generate_random_padding_mask(
max_seqlen, batch_size, device, mode="random", zero_lengths=False
):
assert mode in ["full", "random", "third"]
if mode == "full":
lengths = torch.full(
(batch_size, 1), max_seqlen, device=device, dtype=torch.int32
)
elif mode == "random":
lengths = torch.randint(
max(0 if zero_lengths else 1, max_seqlen - 20),
max_seqlen + 1,
(batch_size, 1),
device=device,
)
elif mode == "third":
lengths = torch.randint(
max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device
)
if zero_lengths:
# Generate zero-lengths every 5 batches and the last batch.
for i in range(batch_size):
if i % 5 == 0:
lengths[i] = 0
lengths[-1] = 0
padding_mask = (
repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size)
< lengths
)
return padding_mask
def pad_input(hidden_states, indices, batch, seqlen):
"""
Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
batch: int, batch size for the padded sequence.
seqlen: int, maximum sequence length for the padded sequence.
Return:
hidden_states: (batch, seqlen, ...)
"""
dim = hidden_states.shape[1:]
output = torch.zeros(
(batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype
)
output[indices] = hidden_states
return rearrange(output, "(b s) ... -> b s ...", b=batch)
def construct_local_mask(
seqlen_q,
seqlen_k,
window_size=(-1, -1), # -1 means infinite window size
sink_token_length=0,
query_padding_mask=None,
key_padding_mask=None,
key_leftpad=None,
device=None,
):
row_idx = rearrange(
torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1"
)
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
if key_leftpad is not None:
key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
sk = (
seqlen_k
if key_padding_mask is None
else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
)
sq = (
seqlen_q
if query_padding_mask is None
else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
)
if window_size[0] < 0:
return col_idx > row_idx + sk - sq + window_size[1]
else:
sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
return torch.logical_or(
col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
torch.logical_and(
col_idx < row_idx + sk - sq - window_size[0],
col_idx >= sink_token_length,
),
)
def attention_ref(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
key_leftpad=None,
attn_bias=None,
dropout_p=0.0,
dropout_mask=None,
causal=False,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=(-1, -1), # -1 means infinite window size
sink_token_length=0,
softcap=0.0,
upcast=True,
reorder_ops=False,
intermediate_dtype=None,
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k: (batch_size, seqlen_k, nheads, head_dim)
v: (batch_size, seqlen_k, nheads, head_dim_v)
qv: (batch_size, seqlen_q, nheads, head_dim_v)
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
causal: whether to apply causal masking
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
output back to fp16/bf16.
reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
without changing the math. This is to estimate the numerical error from operation
reordering.
Output:
output: (batch_size, seqlen_q, nheads, head_dim_v)
attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
"""
if causal:
window_size = (window_size[0], 0)
dtype_og = q.dtype
if upcast:
q, k, v = q.float(), k.float(), v.float()
qv = qv.float() if qv is not None else None
if q_descale is not None:
q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2])
q = (q.float() * q_descale).to(q.dtype)
qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None
if k_descale is not None:
k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype)
if v_descale is not None:
v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype)
seqlen_q, seqlen_k = q.shape[1], k.shape[1]
k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
d = q.shape[-1]
dv = v.shape[-1]
softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv)
if not reorder_ops:
scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k)
else:
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
if qv is not None:
scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v)
if softcap > 0:
scores = torch.tanh(scores / softcap) * softcap
if key_padding_mask is not None:
scores.masked_fill_(
rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")
)
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
sink_token_length,
query_padding_mask,
key_padding_mask,
key_leftpad=key_leftpad,
device=q.device,
)
scores.masked_fill_(local_mask, float("-inf"))
if attn_bias is not None:
scores = scores + attn_bias
attention = torch.softmax(scores, dim=-1).to(v.dtype)
# We want to mask here so that the attention matrix doesn't have any NaNs
# Otherwise we'll get NaN in dV
if query_padding_mask is not None:
attention = attention.masked_fill(
rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0
)
# Without this we might get NaN in dv
if key_padding_mask is not None:
attention = attention.masked_fill(
rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0
)
# Some rows might be completely masked out so we fill them with zero instead of NaN
if window_size[0] >= 0 or window_size[1] >= 0:
attention = attention.masked_fill(
torch.all(local_mask, dim=-1, keepdim=True), 0.0
)
dropout_scaling = 1.0 / (1 - dropout_p)
# attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
# output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
if dropout_mask is not None:
attention_drop = attention.masked_fill(~dropout_mask, 0.0)
else:
attention_drop = attention
if intermediate_dtype is not None:
attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype)
output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
if query_padding_mask is not None:
output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize(
"dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])
)
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else []))
# @pytest.mark.parametrize("new_kv", [True])
# @pytest.mark.parametrize(
# "causal,local",
# [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else []),
# )
# @pytest.mark.parametrize("causal,local", [(False, False), (True, False)])
@pytest.mark.parametrize("causal,local", [(False, False)])
@pytest.mark.parametrize(
"seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]
)
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
# @pytest.mark.parametrize("has_rotary_seqlens", [False, True])
@pytest.mark.parametrize("has_rotary_seqlens", [False])
@pytest.mark.parametrize(
"rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False]
)
# @pytest.mark.parametrize("rotary_interleaved", [True])
@pytest.mark.parametrize(
"rotary_fraction",
(
[0.0, 0.5, 1.0]
if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None)
else [0.0]
),
)
# @pytest.mark.parametrize("rotary_fraction", [0.0])
@pytest.mark.parametrize(
"page_size", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else [])
)
# @pytest.mark.parametrize("page_size", [None])
# @pytest.mark.parametrize("has_leftpad", [False, True])
@pytest.mark.parametrize("has_leftpad", [False])
# @pytest.mark.parametrize("has_batch_idx", [False, True])
@pytest.mark.parametrize("has_batch_idx", [False])
# @pytest.mark.parametrize("varlen_q", [False, True])
@pytest.mark.parametrize("varlen_q", [False])
# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
@pytest.mark.parametrize("d", [64])
# @pytest.mark.parametrize("d", [192])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(1, 128),
(1, 339),
(3, 1024),
(64, 800),
(64, 256),
(3, 799),
(64, 2048),
(16, 20000),
# (1, 128 * 1024),
# (16, 128 * 1024),
(128, 128),
(256, 512), # To test appending KV with more than 1 block
(2048, 3577), # Enough tile to test persistent scheduler
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_kvcache(
seqlen_q,
seqlen_k,
d,
varlen_q,
has_batch_idx,
has_leftpad,
page_size,
rotary_fraction,
rotary_interleaved,
has_rotary_seqlens,
seqlen_new_eq_seqlen_q,
causal,
local,
new_kv,
mha_type,
dtype,
):
if page_size is not None and seqlen_k % page_size != 0:
pytest.skip()
if seqlen_q > seqlen_k and new_kv:
pytest.skip()
if not new_kv and rotary_fraction > 0.0:
pytest.skip()
if rotary_fraction == 0.0 and has_rotary_seqlens:
pytest.skip()
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 5
# batch_size = 1
batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
nheads = 6
# nheads = 1
# rotary_dim must be a multiple of 16, and must be <= d
rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
assert nheads % nheads_k == 0
dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
if dtype == torch.float8_e4m3fn:
dv_vals = [d]
for dv in dv_vals:
has_qv = d == 64 and dv >= 256
q = (
torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref)
.to(dtype)
.to(dtype_ref)
)
if has_qv:
qv = (
torch.randn(
batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref
)
.to(dtype)
.to(dtype_ref)
)
else:
qv = None
if varlen_q:
query_padding_mask = generate_random_padding_mask(
seqlen_q, batch_size, device, mode="random"
)
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(
q, query_padding_mask
)
output_pad_fn = lambda output_unpad: pad_input(
output_unpad, indices_q, batch_size, seqlen_q
)
qv_unpad = (
rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None
)
else:
query_padding_mask = None
q_unpad = q
qv_unpad = qv
cu_seqlens_q, max_seqlen_q = None, None
# Put window_size after QKV randn so that window_size changes from test to test
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
seqlen_new = (
seqlen_q
if seqlen_new_eq_seqlen_q
else torch.randint(1, seqlen_q + 1, (1,)).item()
)
cu_seqlens_k_new = None
key_new_padding_mask = None
if new_kv:
k = (
torch.randn(
batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref
)
.to(dtype)
.to(dtype_ref)
)
v = (
torch.randn(
batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref
)
.to(dtype)
.to(dtype_ref)
)
if varlen_q: # k & v are also varlen
key_new_padding_mask = generate_random_padding_mask(
seqlen_new, batch_size, device, mode="random"
)
k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(
k, key_new_padding_mask
)
v_unpad, *rest = unpad_input(v, key_new_padding_mask)
else:
k_unpad, v_unpad = k, v
else:
k, v, k_unpad, v_unpad = None, None, None, None
if page_size is None:
k_cache = (
torch.randn(
batch_size_cache,
seqlen_k,
nheads_k,
d,
device=device,
dtype=dtype_ref,
)
.to(dtype)
.to(dtype_ref)
)
v_cache = (
torch.randn(
batch_size_cache,
seqlen_k,
nheads_k,
dv,
device=device,
dtype=dtype_ref,
)
.to(dtype)
.to(dtype_ref)
)
page_table = None
else:
(
k_cache,
v_cache,
page_table,
k_cache_paged,
v_cache_paged,
num_blocks,
) = _generate_block_kvcache(
seqlen_k,
page_size,
batch_size_cache,
nheads_k,
d,
dv,
device,
dtype,
dtype_ref,
)
cache_seqlens = torch.randint(
0 if new_kv else 1,
# If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
(
(
seqlen_k
- (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new)
+ 1
)
if new_kv
else (seqlen_k + 1)
),
(batch_size,),
dtype=torch.int32,
device=device,
)
if has_leftpad:
cache_leftpad = torch.cat(
[
(
torch.randint(
0,
cache_seqlens[i].item(),
(1,),
dtype=torch.int32,
device=device,
)
if cache_seqlens[i].item() > 0
else torch.zeros(1, dtype=torch.int32, device=device)
)
for i in range(batch_size)
]
)
else:
cache_leftpad = None
if has_batch_idx:
cache_batch_idx = torch.randperm(
batch_size_cache, dtype=torch.int32, device=device
)[:batch_size]
else:
cache_batch_idx = None
arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
if not new_kv:
key_padding_mask = arange < cache_seqlens_expanded
else:
k_new_seqlens = (
key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new
)
key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens
if has_leftpad:
key_padding_mask = torch.logical_and(
key_padding_mask,
arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k),
)
# cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2
if rotary_dim > 0:
angle = (
torch.rand(
seqlen_k if page_size is None else num_blocks * page_size,
rotary_dim // 2,
device=device,
)
* 2
* math.pi
)
cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)
sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)
if causal or local:
q_ro = apply_rotary_emb(
q,
cos,
sin,
seqlen_offsets=rotary_seqlens,
interleaved=rotary_interleaved,
)
else:
q_ro = rearrange(
apply_rotary_emb(
rearrange(q, "b s h d -> b 1 (s h) d"),
cos,
sin,
seqlen_offsets=rotary_seqlens,
interleaved=rotary_interleaved,
),
"b 1 (s h) d -> b s h d",
s=seqlen_q,
)
# q_ro = q
k_ro = apply_rotary_emb(
k,
cos,
sin,
seqlen_offsets=rotary_seqlens,
interleaved=rotary_interleaved,
)
else:
cos, sin = None, None
q_ro, k_ro = q, k
# k_cache[:, 64:] = -1
k_cache_ref = (
k_cache if not has_batch_idx else k_cache[cache_batch_idx]
).clone()
v_cache_ref = (
v_cache if not has_batch_idx else v_cache[cache_batch_idx]
).clone()
if new_kv:
update_mask = torch.logical_and(
cache_seqlens_expanded <= arange,
arange < cache_seqlens_expanded + k_new_seqlens,
)
k_to_update = rearrange(k_ro, "b s ... -> (b s) ...")
v_to_update = rearrange(v, "b s ... -> (b s) ...")
if varlen_q:
k_to_update = k_to_update[indices_k]
v_to_update = v_to_update[indices_k]
k_cache_ref[update_mask] = k_to_update
v_cache_ref[update_mask] = v_to_update
k_cache_rep = repeat(
k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k
)
v_cache_rep = repeat(
v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k
)
out_ref, _ = attention_ref(
q_ro,
k_cache_rep,
v_cache_rep,
query_padding_mask,
key_padding_mask,
causal=causal,
qv=qv,
window_size=window_size,
key_leftpad=cache_leftpad,
)
out_pt, _ = attention_ref(
q_ro,
k_cache_rep,
v_cache_rep,
query_padding_mask,
key_padding_mask,
causal=causal,
qv=qv,
window_size=window_size,
upcast=False,
reorder_ops=True,
key_leftpad=cache_leftpad,
intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
)
q = q.to(dtype)
q_unpad = q_unpad.to(dtype) if varlen_q else None
k_cache = k_cache.to(dtype)
v_cache = v_cache.to(dtype)
k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None
v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None
k = k.to(dtype) if k is not None else None
v = v.to(dtype) if v is not None else None
k_unpad = k_unpad.to(dtype) if k_unpad is not None else None
v_unpad = v_unpad.to(dtype) if v_unpad is not None else None
qv = qv.to(dtype) if qv is not None else None
qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None
cos = cos.to(dtype) if cos is not None else None
sin = sin.to(dtype) if sin is not None else None
k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone()
v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone()
num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1]
precompute_metadata_vals = [False]
for num_splits, precompute_metadata in itertools.product(
num_splits_vals, precompute_metadata_vals
):
scheduler_metadata = None
# Repeat to test metadata reuse
for _ in range(1 if not precompute_metadata else 2):
if page_size is None:
k_cache.copy_(k_cache_saved)
v_cache.copy_(v_cache_saved)
else:
k_cache_paged.copy_(k_cache_saved)
v_cache_paged.copy_(v_cache_saved)
out, lse, *rest = flash_attn_with_kvcache(
q if not varlen_q else q_unpad,
k_cache if page_size is None else k_cache_paged,
v_cache if page_size is None else v_cache_paged,
k if not new_kv or not varlen_q else k_unpad,
v if not new_kv or not varlen_q else v_unpad,
qv=qv if not varlen_q else qv_unpad,
rotary_cos=cos,
rotary_sin=sin,
cache_seqlens=cache_seqlens,
cache_batch_idx=cache_batch_idx,
cache_leftpad=cache_leftpad,
page_table=page_table,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k_new,
max_seqlen_q=max_seqlen_q,
rotary_seqlens=rotary_seqlens,
causal=causal,
window_size=window_size,
rotary_interleaved=rotary_interleaved,
scheduler_metadata=scheduler_metadata,
num_splits=num_splits,
return_softmax_lse=True,
)
if varlen_q:
out = output_pad_fn(out)
# out = flash_attn_with_kvcache(
# q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
# )
# out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
# qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
# m = qk.amax(-1, keepdim=True)
# s_tmp = torch.exp((qk - m) / math.sqrt(d))
# o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
# lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
# probs = torch.softmax(qk, dim=-1)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
# breakpoint()
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
if new_kv:
if page_size is None:
k_cache_select = (
k_cache.to(dtype_ref)
if not has_batch_idx
else k_cache.to(dtype_ref)[cache_batch_idx]
)
v_cache_select = (
v_cache.to(dtype_ref)
if not has_batch_idx
else v_cache.to(dtype_ref)[cache_batch_idx]
)
else:
k_cache_select = rearrange(
k_cache_paged.to(dtype_ref)[
(
page_table
if not has_batch_idx
else page_table[cache_batch_idx]
).flatten()
],
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
b=batch_size,
)[:, :seqlen_k].to(dtype_ref)
v_cache_select = rearrange(
v_cache_paged.to(dtype_ref)[
(
page_table
if not has_batch_idx
else page_table[cache_batch_idx]
).flatten()
],
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
b=batch_size,
)[:, :seqlen_k].to(dtype_ref)
k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref)
v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref)
if dtype is not torch.float8_e4m3fn:
assert torch.equal(v_cache_select, v_cache_ref)
else:
assert torch.allclose(
v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3
)
# breakpoint()
# if rotary_dim == 0 and dtype is not torch.float8_e4m3fn:
if rotary_dim == 0:
assert torch.equal(k_cache_select, k_cache_ref)
else:
# if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3):
# breakpoint()
if dtype is not torch.float8_e4m3fn:
assert torch.allclose(
k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3
)
else:
assert torch.allclose(
k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1
)
mult = 4 if dtype == torch.float8_e4m3fn else 2
assert (out - out_ref).abs().max().item() <= mult * (
out_pt - out_ref
).abs().max().item() + 1e-5
mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5
assert (out - out_ref).abs().mean().item() <= mult_mean * (
out_pt - out_ref
).abs().mean().item()
def _generate_block_kvcache(
seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref
):
num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3
k_cache_paged = (
torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref)
.to(dtype)
.to(dtype_ref)
)
v_cache_paged = (
torch.randn(num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref)
.to(dtype)
.to(dtype_ref)
)
page_table = rearrange(
torch.randperm(num_blocks, dtype=torch.int32, device=device),
"(b nblocks) -> b nblocks",
b=batch_size,
)
k_cache = rearrange(
k_cache_paged[page_table.flatten()],
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
b=batch_size,
)[:, :seqlen_k]
v_cache = rearrange(
v_cache_paged[page_table.flatten()],
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
b=batch_size,
)[:, :seqlen_k]
return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks
if __name__ == "__main__":
pytest.main([__file__])