Unverified Commit 17944550 authored by Shengyu Liu's avatar Shengyu Liu Committed by GitHub
Browse files

Merge pull request #98 from deepseek-ai/open-source-h

Add Sparse Attention Kernels on Hopper
parents ebf30641 3969f20b
#pragma once
// Numeric ids (0..5, stored as uint32_t) for the named barriers used by the kernel that
// includes this header.
// NOTE(review): the names suggest producer/consumer handshakes ("*_ready" / "*_free") plus
// two plain synchronisation points ("*_sync"); the actual arrive/sync pairing lives in the
// kernel source, which is not visible here — confirm before renumbering or reusing an id.
enum NamedBarriers : uint32_t {
sScale_and_sS_ready = 0,
sScale_and_sS_free = 1,
oBuf_free_and_sL_ready = 2,
epilogue_r2s_ready = 3,
batch_loop_sync = 4,
warpgroup0_sync = 5
};
This diff is collapsed.
#pragma once
#include "params.h"
namespace sm90 {
void run_flash_splitkv_mla_fp8_sparse_kernel(DecodingParams &params, cudaStream_t stream);
}
#pragma once
#include "params.h"
void run_get_mla_metadata_kernel(Mla_metadata_params &params, cudaStream_t stream);
This diff is collapsed.
#pragma once
#include "params.h"
namespace sm90 {
void run_fwd_kernel(const SparsePrefillParams& params);
}
#pragma once
#include <cutlass/bfloat16.h>
#include <cutlass/arch/barrier.h>
#include <cute/tensor.hpp>
namespace sm90 {
using bf16 = cutlass::bfloat16_t;
using transac_bar_t = cutlass::arch::ClusterTransactionBarrier;
using cutlass::arch::fence_view_async_shared;
using cutlass::arch::fence_barrier_init;
using cutlass::arch::NamedBarrier;
// Issues a single asynchronous 16-byte global -> shared copy (`cp.async.cg`) with the
// `L2::256B` prefetch qualifier.
//   src: source address, passed to the instruction as a 64-bit operand.
//   dst: destination in shared memory; converted to a 32-bit shared-space address first.
// Only the copy instruction itself is emitted: this function neither commits an async group
// nor waits for one, so the caller must do that before reading `dst`.
// NOTE(review): the PTX spec requires both addresses to be aligned to the copy size
// (16 bytes) — confirm at the call sites.
__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) {
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst);
asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], %2;\n"
:: "r"(dst_addr),
"l"(src),
// Copy size in bytes; must be a compile-time immediate ("n" constraint).
"n"(16));
}
// Predicated overload of the 16-byte asynchronous global -> shared copy that also carries
// an explicit L2 cache policy (`L2::cache_hint`).
//   src:          source address (64-bit operand).
//   dst:          destination in shared memory; converted to a 32-bit shared-space address.
//   pred:         selects the src-size operand: 16 when true, 0 when false. The cp-size stays
//                 16, and per the PTX cp.async specification the bytes beyond src-size are
//                 written as zeros, so `pred == false` zero-fills the 16 destination bytes.
//   cache_policy: 64-bit policy value, e.g. one produced by createpolicy_evict_last() or
//                 createpolicy_evict_first() in this header.
// As with the unpredicated overload, no commit/wait is issued here.
__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst, bool pred, int64_t cache_policy) {
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst);
asm volatile("cp.async.cg.shared.global.L2::cache_hint.L2::256B [%0], [%1], 16, %2, %3;\n"
:: "r"(dst_addr),
"l"(src),
// src-size: 16 copies real data, 0 requests a pure zero-fill.
"r"(pred?16:0),
"l"(cache_policy));
}
// Builds an L2 cache policy with `createpolicy.fractional.L2::evict_last` and fraction 1.0
// and returns it as a 64-bit value, suitable for the `cache_policy` argument of the
// predicated cp_async_cacheglobal_l2_prefetch_256B overload.
__forceinline__ __device__ int64_t createpolicy_evict_last() {
int64_t res;
asm volatile(
"createpolicy.fractional.L2::evict_last.b64 %0, 1.0; \n\t"
: "=l"(res)
:
);
return res;
}
// Builds an L2 cache policy with `createpolicy.fractional.L2::evict_first` and fraction 1.0
// and returns it as a 64-bit value, suitable for the `cache_policy` argument of the
// predicated cp_async_cacheglobal_l2_prefetch_256B overload.
__forceinline__ __device__ int64_t createpolicy_evict_first() {
int64_t res;
asm volatile(
"createpolicy.fractional.L2::evict_first.b64 %0, 1.0; \n\t"
: "=l"(res)
:
);
return res;
}
__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {
    // In the WGMMA layout of fragment A and fragment C, the data held by one thread sits in
    // two rows. This maps the thread-local row selector (0 or 1) to the row index inside the
    // warpgroup tile.
    // Layout reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
    const int warp_idx = idx_in_warpgroup / 32;   // each warp covers a band of 16 rows
    const int lane_idx = idx_in_warpgroup % 32;   // four consecutive lanes share one row
    return warp_idx * 16 + local_row_idx * 8 + lane_idx / 4;
}
__forceinline__ __device__ int get_AorC_col_idx(int local_elem_idx, int idx_in_warpgroup) {
    // Column counterpart of get_AorC_row_idx: maps a thread-local element index to its
    // column inside the fragment A / fragment C tile.
    const int col_group   = local_elem_idx / 4;     // every 4 local elements advance 8 columns
    const int lane_in_row = idx_in_warpgroup % 4;   // position of the thread among the 4 lanes of a row
    const int elem_in_pair = local_elem_idx & 1;    // each thread holds 2 adjacent columns
    return 8 * col_group + lane_in_row * 2 + elem_in_pair;
}
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h
// * Copyright (c) 2024, Tri Dao.
// Issues one warpgroup MMA over every K block of the given fragments.
//   zero_init: when true, the first K block overwrites the accumulator (ScaleOut::Zero);
//              when false, the first K block uses whatever `tiled_mma.accumulate_` the
//              caller left behind. Every later K block accumulates (ScaleOut::One).
//   wg_wait:   when >= 0, calls warpgroup_wait<wg_wait>() after issuing; negative skips it.
//   arrive:    when true, calls warpgroup_arrive() before issuing.
//   commit:    when true, calls warpgroup_commit_batch() after issuing.
// `tiled_mma.accumulate_` is left at ScaleOut::One whenever at least one K block ran.
template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
    using namespace cute;
    // A is a register operand unless its fragment type is a shared-memory descriptor iterator.
    constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;

    // Fence the register operands before the MMA is issued. warpgroup_fence_operand does
    // not accept const tensors, hence the const_cast on A.
    if constexpr (Is_RS) {
        cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));
    }
    warpgroup_fence_operand(tCrC);

    if constexpr (arrive) {
        warpgroup_arrive();
    }

    if constexpr (zero_init) {
        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    }
    // The K mode is unrolled by hand so that scale-D can be switched to One after the
    // first block.
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
        cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }

    if constexpr (commit) {
        warpgroup_commit_batch();
    }
    if constexpr (wg_wait >= 0) {
        warpgroup_wait<wg_wait>();
    }

    // Fence again in the reverse order once the MMA has been issued.
    warpgroup_fence_operand(tCrC);
    if constexpr (Is_RS) {
        warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));
    }
}
// A simpler version of gemm for A and B tensors that are partitioned into MMA fragments
// here (both via partition_fragment_*).
//   clear_accum:      true -> the first K block overwrites rC_frag; false -> it accumulates.
//   tiled_mma:        taken by value, so the accumulate_ changes below stay local.
//   idx_in_warpgroup: thread index used to pick this thread's MMA slice.
// This function arrives on the warpgroup and issues the MMAs, but it does not commit the
// batch or wait for it — presumably the caller does both; confirm at the call sites.
template <typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm_ss(bool clear_accum, TiledMma tiled_mma, Tensor0 const &sA, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) {
    using namespace cute;

    auto thread_mma = tiled_mma.get_slice(idx_in_warpgroup);
    auto a_frag = thread_mma.partition_fragment_A(sA);
    auto b_frag = thread_mma.partition_fragment_B(sB);
    static_assert(size<2>(a_frag) == size<2>(b_frag));

    warpgroup_fence_operand(rC_frag);
    warpgroup_arrive();

    if (clear_accum) {
        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    } else {
        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(a_frag); ++k_block) {
        cute::gemm(tiled_mma, a_frag(_, _, k_block), b_frag(_, _, k_block), rC_frag);
        // Only the first K block may overwrite; the rest always accumulate.
        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }

    warpgroup_fence_operand(rC_frag);
}
// Counterpart of gemm_ss for an A operand that is already an MMA fragment; only B is
// partitioned here.
//   clear_accum:      true -> the first K block overwrites rC_frag; false -> it accumulates.
//   tiled_mma:        taken by value, so the accumulate_ changes below stay local.
//   rA_frag:          A fragment, taken by value.
//   idx_in_warpgroup: thread index used to pick this thread's MMA slice.
// Like gemm_ss, this arrives and issues the MMAs but does not commit or wait.
// NOTE(review): rA_frag is passed by value, so the fences below apply to this function's
// copy of the tensor object — confirm that is intended for owning register tensors.
template <typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm_rs(bool clear_accum, TiledMma tiled_mma, Tensor0 rA_frag, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) {
    using namespace cute;

    auto thread_mma = tiled_mma.get_slice(idx_in_warpgroup);
    auto b_frag = thread_mma.partition_fragment_B(sB);
    static_assert(size<2>(rA_frag) == size<2>(b_frag));

    // Fence A first, then C, before the MMA is issued.
    warpgroup_fence_operand(const_cast<Tensor0 &>(rA_frag));
    warpgroup_fence_operand(rC_frag);
    warpgroup_arrive();

    if (clear_accum) {
        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    } else {
        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(rA_frag); ++k_block) {
        cute::gemm(tiled_mma, rA_frag(_, _, k_block), b_frag(_, _, k_block), rC_frag);
        // Only the first K block may overwrite; the rest always accumulate.
        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }

    // Fence again in the reverse order: C, then A.
    warpgroup_fence_operand(rC_frag);
    warpgroup_fence_operand(const_cast<Tensor0 &>(rA_frag));
}
// Returns the id of the SM the calling thread is currently running on, read from the
// PTX special register %smid.
__forceinline__ __device__ uint32_t get_sm_id() {
    uint32_t ret;
    // Inside an extended-asm template '%' introduces an operand reference, so the literal
    // register name has to be spelled "%%smid". A single '%' only works where the compiler
    // is lenient and is rejected by stricter front-ends such as clang's CUDA mode.
    asm("mov.u32 %0, %%smid;" : "=r"(ret));
    return ret;
}
// peer_addr = my_addr ^ PEER_ADDR_MASK (bit 24 of the address is flipped).
// Not sure whether this value is the same on every GPU.
static constexpr int PEER_ADDR_MASK = 1 << 24;

// Returns the "peer" counterpart of address `p`, obtained by XOR-ing it with
// PEER_ADDR_MASK. The const qualifier of the input pointer is dropped in the result.
// NOTE(review): "peer" presumably refers to the other CTA of a cluster — confirm with the callers.
template<typename T>
CUTE_DEVICE
T* get_peer_addr(const T* p) {
    const int64_t own_addr = reinterpret_cast<int64_t>(p);
    return reinterpret_cast<T*>(own_addr ^ PEER_ADDR_MASK);
}
// Launches a TMA copy from `src` to `dst` whose completion is tracked by the transaction
// barrier `bar`.
//   tma_copy:   the TMA copy object; slice 0 of it is used to partition both tensors.
//   src / dst:  source and destination tensors (taken by value).
//   bar:        transaction barrier, handed to the copy as its raw ValueType.
//   cache_hint: cache hint forwarded to the copy; defaults to EVICT_NORMAL.
// The middle argument of `with(...)` is always 0 here.
template<typename TMA, typename Tensor0, typename Tensor1>
CUTE_DEVICE
void launch_tma_copy(
    const TMA &tma_copy,
    Tensor0 src,
    Tensor1 dst,
    transac_bar_t &bar,
    const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL
) {
    using raw_bar_t = typename transac_bar_t::ValueType;

    auto tma_slice = tma_copy.get_slice(cute::_0{});
    auto src_part = tma_slice.partition_S(src);
    auto dst_part = tma_slice.partition_D(dst);
    cute::copy(
        tma_copy.with(reinterpret_cast<raw_bar_t&>(bar), 0, cache_hint),
        src_part,
        dst_part
    );
}
}
......@@ -6,7 +6,7 @@
#include "utils.h"
__global__ void __launch_bounds__(32, 1, 1)
get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
get_mla_metadata_kernel(__grid_constant__ const GetDecodingMetadataParams params) {
int *seqlens_k_ptr = params.seqlens_k_ptr;
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
int *num_splits_ptr = params.num_splits_ptr;
......@@ -18,12 +18,26 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
extern __shared__ int shared_mem[];
int* num_blocks_shared = shared_mem; // [batch_size]
int* num_splits_shared = shared_mem + batch_size; // [batch_size+1]
int* seqlens_k_shared = shared_mem + batch_size*2+1; // [batch_size]
int* first_block_idx_shared = shared_mem + batch_size*3+1; // [batch_size]
int* last_block_idx_shared = shared_mem + batch_size*4+1; // [batch_size]
int total_num_blocks = 0;
for (int i = threadIdx.x; i < batch_size; i += 32) {
int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n);
int cur_s_k = params.topk == -1 ? __ldg(seqlens_k_ptr + i) : params.topk;
seqlens_k_shared[i] = cur_s_k;
int first_token_idx = 0;
int last_token_idx = max(cur_s_k-1, 0);
int cur_first_block_idx = first_token_idx / block_size_n;
int cur_last_block_idx = last_token_idx / block_size_n;
// NOTE Should attend to tokens [first_token_idx, last_token_idx], i.e. blocks [cur_first_block_idx, cur_last_block_idx]
// NOTE Before clamping, first_token_idx <= last_token_idx always holds, so after clamping, first_token_idx <= last_token_idx still holds.
// NOTE if seqlens_k is 0, then first_token_idx == last_token_idx == cur_first_block_idx == cur_last_block_idx == 0. So the sequence will have 1 block. We will correct this later in this kernel.
int num_blocks = cur_last_block_idx - cur_first_block_idx + 1;
total_num_blocks += num_blocks + fixed_overhead_num_blocks;
num_blocks_shared[i] = num_blocks;
first_block_idx_shared[i] = cur_first_block_idx;
last_block_idx_shared[i] = cur_last_block_idx;
}
for (int offset = 16; offset >= 1; offset /= 2) {
total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);
......@@ -31,14 +45,14 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
__syncwarp();
if (threadIdx.x == 0) {
int payload = max(cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks, 2*fixed_overhead_num_blocks);
int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
num_splits_shared[0] = 0;
for (int i = 0; i < num_sm_parts; ++i) {
int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
tile_scheduler_metadata0[0] = now_idx;
tile_scheduler_metadata0[1] = now_block * block_size_n;
tile_scheduler_metadata0[1] = now_block + first_block_idx_shared[now_idx];
tile_scheduler_metadata1 = now_n_split_idx;
int remain_payload = payload;
while (now_idx < batch_size) {
......@@ -61,7 +75,7 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
}
}
tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;
tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1];
tile_scheduler_metadata0[3] = now_block > 0 ? now_block + first_block_idx_shared[now_idx] : (seqlens_k_shared[now_idx-1] == 0 ? 0 : last_block_idx_shared[now_idx-1] + 1);
*reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
}
......@@ -74,8 +88,8 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
}
}
void run_get_mla_metadata_kernel(Mla_metadata_params &params, cudaStream_t stream) {
int smem_size = sizeof(int) * (params.batch_size*2+1);
void run_get_mla_metadata_kernel(GetDecodingMetadataParams &params, cudaStream_t stream) {
int smem_size = sizeof(int) * (params.batch_size*5+1);
CHECK_CUDA(cudaFuncSetAttribute(get_mla_metadata_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
get_mla_metadata_kernel<<<1, 32, smem_size, stream>>>(params);
CHECK_CUDA_KERNEL_LAUNCH();
......
#pragma once
#include "params.h"
void run_get_mla_metadata_kernel(GetDecodingMetadataParams &params, cudaStream_t stream);
......@@ -7,13 +7,12 @@
#include "params.h"
#include "utils.h"
#include "config.h" // for BLOCK_SIZE_M and HEAD_DIM_V
using namespace cute;
template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS>
__global__ void __launch_bounds__(NUM_THREADS)
flash_fwd_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
flash_fwd_mla_combine_kernel(__grid_constant__ const DecodingParams params) {
// grid_shape: [batch_size, num_q_heads*s_q / BLOCK_SIZE_M]
// Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result
static_assert(NUM_THREADS/32 == BLOCK_SIZE_M); // The number of warps == block_size_m
......@@ -176,12 +175,14 @@ flash_fwd_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params
template<typename ElementT>
void run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t stream) {
void run_flash_mla_combine_kernel(DecodingParams &params, cudaStream_t stream) {
static constexpr int HEAD_DIM_V = 512; // Since only this head dimension is supported by Flash MLA
FLASH_ASSERT(params.d_v == HEAD_DIM_V);
MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] {
constexpr int BLOCK_SIZE_M = 8;
constexpr int NUM_THREADS = BLOCK_SIZE_M*32;
constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float);
auto combine_kernel = &flash_fwd_mla_combine_kernel<ElementT, Config::HEAD_DIM_V, BLOCK_SIZE_M, NUM_SPLITS, NUM_THREADS>;
auto combine_kernel = &flash_fwd_mla_combine_kernel<ElementT, HEAD_DIM_V, BLOCK_SIZE_M, NUM_SPLITS, NUM_THREADS>;
CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
cudaLaunchAttribute attribute[1];
......@@ -200,8 +201,8 @@ void run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t str
CHECK_CUDA_KERNEL_LAUNCH();
}
template void run_flash_mla_combine_kernel<cutlass::bfloat16_t>(Flash_fwd_mla_params &params, cudaStream_t stream);
template void run_flash_mla_combine_kernel<cutlass::bfloat16_t>(DecodingParams &params, cudaStream_t stream);
#ifndef FLASH_MLA_DISABLE_FP16
template void run_flash_mla_combine_kernel<cutlass::half_t>(Flash_fwd_mla_params &params, cudaStream_t stream);
template void run_flash_mla_combine_kernel<cutlass::half_t>(DecodingParams &params, cudaStream_t stream);
#endif
\ No newline at end of file
......@@ -3,4 +3,4 @@
#include "params.h"
template<typename ElementT>
void run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t stream);
void run_flash_mla_combine_kernel(DecodingParams &params, cudaStream_t stream);
......@@ -30,3 +30,37 @@
} while(0)
#define println(fmt, ...) { print(fmt, ##__VA_ARGS__); print("\n"); }
// Division rounded up, computed as (a + b - 1) / b.
// The formula rounds up only for a >= 0 and b > 0, and the intermediate sum a + b - 1 can
// overflow for values close to the maximum of T.
template<typename T>
__inline__ __host__ __device__ T ceil_div(const T &a, const T &b) {
    const auto biased = a + b - 1;
    return biased / b;
}
#ifndef TRAP_ONLY_DEVICE_ASSERT
#define TRAP_ONLY_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) \
asm("trap;"); \
} while (0)
#endif
// For development, we define both IS_SM100 and IS_SM90 when using CLion or VSCode IDEs so code highlighting will be correct.
#if defined(__CLION_IDE__) || defined(__VSCODE_IDE__)
#define IS_SM100 1
#define IS_SM90 1
#else
// We define the following macros to detect the CUDA architecture, so that we can enable/disable certain kernels that depend on specific architectures.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1000)
#define IS_SM100 1
#else
#define IS_SM100 0
#endif
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900)
#define IS_SM90 1
#else
#define IS_SM90 0
#endif
#endif // defined(__CLION_IDE__) || defined(__VSCODE_IDE__)
\ No newline at end of file
......@@ -6,4 +6,5 @@ from flash_mla.flash_mla_interface import (
flash_attn_varlen_func,
flash_attn_varlen_qkvpacked_func,
flash_attn_varlen_kvpacked_func,
flash_mla_sparse_fwd
)
......@@ -2,30 +2,33 @@ from typing import Optional, Tuple
import torch
import flash_mla_sm90
import flash_mla_sm100
import flash_mla.cuda as flash_mla_cuda
def get_mla_metadata(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
num_q_tokens_per_head_k: int,
num_heads_k: int,
num_heads_q: Optional[int] = None,
is_fp8_kvcache: bool = False,
topk: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
num_heads_k: num_heads_k.
num_q_tokens_per_head_k: Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
num_heads_k: The number of k heads.
num_heads_q: The number of q heads. This argument is optional when sparse attention is not enabled
is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
topk: If not None, sparse attention will be enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache_sm90` will be attended to.
Returns:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
return flash_mla_sm90.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)
return flash_mla_cuda.get_mla_decoding_metadata(cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q, is_fp8_kvcache, topk)
def flash_mla_with_kvcache_sm90(
def flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
......@@ -35,6 +38,8 @@ def flash_mla_with_kvcache_sm90(
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
is_fp8_kvcache: bool = False,
indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
......@@ -47,6 +52,8 @@ def flash_mla_with_kvcache_sm90(
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format. For the format of FP8 KV cache, please refer to README.md
indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention will be enabled, and only tokens in the `indices` array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv. For details about how to set up `indices`, please refer to README.md.
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
......@@ -54,7 +61,9 @@ def flash_mla_with_kvcache_sm90(
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse = flash_mla_sm90.fwd_kvcache_mla(
if indices is not None:
assert causal == False, "causal must be `false` if sparse attention is enabled."
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q,
k_cache,
head_dim_v,
......@@ -64,10 +73,42 @@ def flash_mla_with_kvcache_sm90(
causal,
tile_scheduler_metadata,
num_splits,
is_fp8_kvcache,
indices
)
return out, softmax_lse
def flash_mla_sparse_fwd(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Run the sparse attention prefill kernel.

    This is a thin wrapper: all arguments are forwarded unchanged to
    `flash_mla_cuda.sparse_prefill_fwd` and its result is returned as-is.

    Args:
        q: [s_q, h_q, d_qk], bfloat16
        kv: [s_kv, h_kv, d_qk], bfloat16
        indices: [s_q, h_kv, topk], int32. Invalid entries must be -1 or a number >= s_kv
        sm_scale: float
        d_v: Dimension of the value vectors. Only 512 is supported

    Returns:
        (output, max_logits, lse); see README.md for their exact definitions
        - output: [s_q, h_q, d_v], bfloat16
        - max_logits: [s_q, h_q], float
        - lse: [s_q, h_q], float, 2-based log-sum-exp
    """
    return flash_mla_cuda.sparse_prefill_fwd(q, kv, indices, sm_scale, d_v)
def _flash_attn_varlen_forward(
q: torch.Tensor,
k: torch.Tensor,
......@@ -96,7 +137,7 @@ def _flash_attn_varlen_forward(
lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device)
flash_mla_sm100.fwd(
flash_mla_cuda.dense_prefill_fwd(
workspace_buffer,
q,
k,
......@@ -159,7 +200,7 @@ def _flash_attn_varlen_backward(
if num_qo_heads != num_kv_heads:
workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc
workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device)
flash_mla_sm100.bwd(
flash_mla_cuda.dense_prefill_bwd(
workspace_buffer,
do,
q,
......@@ -195,7 +236,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
causal: bool = False,
softmax_scale: Optional[float] = None,
is_varlen: bool = True,
):
) -> Tuple[torch.Tensor, torch.Tensor]:
out, lse = _flash_attn_varlen_forward(
q, k, v,
cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv,
......@@ -290,40 +331,3 @@ def flash_attn_varlen_kvpacked_func(
cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv,
causal, softmax_scale, is_varlen,
)
def flash_mla_with_kvcache_sm100(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
softmax_scale: Optional[float] = None,
causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
# TODO
pass
def flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: Optional[torch.Tensor] = None,
num_splits: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
capability = torch.cuda.get_device_capability(q.device.index)
if capability == (9, 0):
return flash_mla_with_kvcache_sm90(
q, k_cache, block_table, cache_seqlens, head_dim_v,
tile_scheduler_metadata, num_splits,
softmax_scale, causal,
)
elif capability == (10, 0):
raise ValueError(f"Unsupported device capability: {capability}")
else:
raise ValueError(f"Unsupported device capability: {capability}")
......@@ -12,29 +12,31 @@ from torch.utils.cpp_extension import (
)
def append_nvcc_threads(nvcc_extra_args):
nvcc_threads = os.getenv("NVCC_THREADS") or "32"
return nvcc_extra_args + ["--threads", nvcc_threads]
def is_flag_set(flag: str) -> bool:
    """
    Tell whether the environment variable `flag` holds a truthy value.

    "true", "1", "y" and "yes" count as set (case-insensitive). An unset variable is
    treated as "FALSE", so it is reported as not set.
    """
    truthy_values = ["true", "1", "y", "yes"]
    raw_value = os.getenv(flag, "FALSE")
    return raw_value.lower() in truthy_values
def get_features_args():
features_args = []
DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") in ["TRUE", "1"]
if DISABLE_FP16:
if is_flag_set("FLASH_MLA_DISABLE_FP16"):
features_args.append("-DFLASH_MLA_DISABLE_FP16")
return features_args
def get_arch_flags():
    """
    Build the nvcc `-gencode` arguments for the enabled GPU architectures.

    sm_100a is emitted first and sm_90a second. Either one is dropped when its
    FLASH_MLA_DISABLE_SM100 / FLASH_MLA_DISABLE_SM90 environment flag is set, so the
    result is an empty list when both are disabled.
    """
    targets = [
        ("FLASH_MLA_DISABLE_SM100", "arch=compute_100a,code=sm_100a"),
        ("FLASH_MLA_DISABLE_SM90", "arch=compute_90a,code=sm_90a"),
    ]
    arch_flags = []
    for disable_flag, gencode in targets:
        if is_flag_set(disable_flag):
            continue
        arch_flags += ["-gencode", gencode]
    return arch_flags
def get_nvcc_thread_args():
    """
    Return the nvcc `--threads` arguments.

    The thread count comes from the NVCC_THREADS environment variable and falls back to
    "32" when the variable is unset or empty.
    """
    thread_count = os.getenv("NVCC_THREADS")
    if not thread_count:
        thread_count = "32"
    return ["--threads", thread_count]
subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
cc_flag_sm90 = []
cc_flag_sm90.append("-gencode")
cc_flag_sm90.append("arch=compute_90a,code=sm_90a")
cc_flag_sm100 = []
cc_flag_sm100.append("-gencode")
cc_flag_sm100.append("arch=compute_100a,code=sm_100a")
this_dir = os.path.dirname(os.path.abspath(__file__))
if IS_WINDOWS:
......@@ -45,17 +47,20 @@ else:
ext_modules = []
ext_modules.append(
CUDAExtension(
name="flash_mla_sm90",
name="flash_mla.cuda",
sources=[
"csrc/sm90/flash_api.cpp",
"csrc/sm90/kernels/get_mla_metadata.cu",
"csrc/sm90/kernels/mla_combine.cu",
"csrc/sm90/kernels/splitkv_mla.cu",
"csrc/pybind.cpp",
"csrc/smxx/get_mla_metadata.cu",
"csrc/smxx/mla_combine.cu",
"csrc/sm90/decode/dense/splitkv_mla.cu",
"csrc/sm90/decode/sparse_fp8/splitkv_mla.cu",
"csrc/sm90/prefill/sparse/fwd.cu",
"csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu",
"csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu",
],
extra_compile_args={
"cxx": cxx_args + get_features_args(),
"nvcc": append_nvcc_threads(
[
"nvcc": [
"-O3",
"-std=c++17",
"-DNDEBUG",
......@@ -69,55 +74,17 @@ ext_modules.append(
"--expt-extended-lambda",
"--use_fast_math",
"--ptxas-options=-v,--register-usage-level=10"
]
+ cc_flag_sm90
) + get_features_args(),
] + get_features_args() + get_arch_flags() + get_nvcc_thread_args(),
},
include_dirs=[
Path(this_dir) / "csrc",
Path(this_dir) / "csrc" / "sm90",
Path(this_dir) / "csrc" / "cutlass" / "include",
],
)
)
ext_modules.append(
CUDAExtension(
name="flash_mla_sm100",
sources=[
"csrc/sm100/pybind.cu",
"csrc/sm100/fmha_cutlass_fwd_sm100.cu",
"csrc/sm100/fmha_cutlass_bwd_sm100.cu",
],
extra_compile_args={
"cxx": ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"],
"nvcc": append_nvcc_threads(
[
"-O3",
"-std=c++17",
"-DNDEBUG",
"-Wno-deprecated-declarations",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
"-lineinfo",
"--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage",
]
+ cc_flag_sm100
),
},
include_dirs=[
Path(this_dir) / "csrc" / "sm100",
Path(this_dir) / "csrc" / "cutlass" / "include",
Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include",
],
)
)
try:
cmd = ['git', 'rev-parse', '--short', 'HEAD']
rev = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()
......
from typing import List
import torch
def cdiv(x: int, y: int):
    """
    Divide x by y and round up, computed as (x + y - 1) // y.
    The formula rounds up only for positive y.
    """
    padded = x + y - 1
    return padded // y
def check_is_allclose(name: str, ans: torch.Tensor, ref: torch.Tensor, abs_tol: float = 1e-5, rel_tol: float = 1e-2, cos_diff_tol: float = 1e-7):
    """
    Compare `ans` against `ref` and report whether they are close enough.

    The check runs on float32 copies, so the inputs are never modified:
      1. +inf, -inf and NaN must sit at exactly the same positions in both tensors.
         Those positions are zeroed before the numeric comparison.
      2. Every element must satisfy |ans - ref| < abs_tol or
         |ans - ref| / (|ref| + 1e-6) < rel_tol.
      3. The cosine diff 1 - 2*sum(ans*ref) / sum(ans^2 + ref^2) must not exceed
         cos_diff_tol in absolute value.

    Args:
        name: label used in the printed diagnostics.
        ans: tensor under test.
        ref: reference tensor; must have the same shape as `ans` (asserted).
        abs_tol: absolute tolerance.
        rel_tol: relative tolerance.
        cos_diff_tol: tolerance on the cosine diff.

    Returns:
        True when all checks pass; otherwise False, after printing a diagnostic.
    """
    assert ans.shape == ref.shape, f"`{name}` Shape mismatch: {ans.shape} vs {ref.shape}"

    ans = ans.clone().to(torch.float)
    ref = ref.clone().to(torch.float)

    def locate(t: torch.Tensor, val: float) -> torch.Tensor:
        # NaN never compares equal to itself, so it is found with `t != t`.
        return (t == val) if (val == val) else (t != t)

    # All three special values are always processed so that every one of them is zeroed
    # and every mismatch is reported, even after an earlier mismatch.
    anomalies_match = True
    for val in (float("inf"), float("-inf"), float("nan")):
        ref_mask = locate(ref, val)
        ans_mask = locate(ans, val)
        ref[ref_mask] = 0.0
        ans[ans_mask] = 0.0
        if not torch.equal(ref_mask, ans_mask):
            print(f"`{name}` Anomaly number `{val}` mismatch: {ans_mask.sum().item()} in ans but {ref_mask.sum().item()} in ref")
            anomalies_match = False
    if not anomalies_match:
        return False

    # Cosine diff, accumulated in float64.
    ans64, ref64 = ans.double(), ref.double()
    energy = (ans64 * ans64 + ref64 * ref64).sum().item()
    cos_diff = 0 if energy == 0 else 1 - 2 * (ans64 * ref64).sum().item() / energy

    raw_abs_err = torch.abs(ans - ref)
    raw_rel_err = raw_abs_err / (torch.abs(ref) + (1e-6))
    # An element that already passes one criterion is excluded from the maximum reported
    # for the other one.
    rel_err = raw_rel_err.masked_fill(raw_abs_err < abs_tol, 0)
    abs_err = raw_abs_err.masked_fill(raw_rel_err < rel_tol, 0)
    pass_mask = (abs_err < abs_tol) | (rel_err < rel_tol)

    if pass_mask.all():
        if abs(cos_diff) > cos_diff_tol:
            print(f"`{name}` mismatch: Cosine diff too large: {cos_diff} vs {cos_diff_tol})")
            return False
        return True

    print(f"`{name}` mismatch")
    worst_abs_pos: int = torch.argmax(abs_err, keepdim=True).item()   # type: ignore
    worst_rel_pos: int = torch.argmax(rel_err, keepdim=True).item()   # type: ignore

    def unravel(shape, flat_pos: int) -> List[int]:
        # Turn a flat row-major offset into per-dimension coordinates.
        coords = []
        for extent in reversed(shape):
            flat_pos, coord = divmod(flat_pos, extent)
            coords.append(coord)
        assert flat_pos == 0
        coords.reverse()
        return coords

    print(f"max abs err: {torch.max(abs_err).item()}: pos {unravel(ans.shape, worst_abs_pos)}, {ans.reshape(-1)[worst_abs_pos].item()} vs {ref.reshape(-1)[worst_abs_pos].item()}")
    print(f"max rel err: {torch.max(rel_err).item()}: pos {unravel(ans.shape, worst_rel_pos)}, {ans.reshape(-1)[worst_rel_pos].item()} vs {ref.reshape(-1)[worst_rel_pos].item()}")
    print(f"{pass_mask.sum()} out of {pass_mask.numel()} passed ({pass_mask.sum()/pass_mask.numel()*100.0:.2f}%)")
    print(f"Cosine diff: {cos_diff} (threshold: {cos_diff_tol})")
    return False
\ No newline at end of file
import enum
import torch
def quantize_k_cache(
input_k_cache: torch.Tensor, # (num_blocks, block_size, h_k, d)
dv: int,
tile_size: int = 128,
) -> torch.Tensor:
"""
Quantize the k-cache
Return a tensor with shape (num_blocks, block_size, h_k, dv + 4(dv/tile_size) + t(d-dv)) of dtype uint8_t, where t = input_k_cache.element_size()
For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py or README.md
"""
assert dv % tile_size == 0
num_tiles = dv // tile_size
num_blocks, block_size, h_k, d = input_k_cache.shape
assert h_k == 1
input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d]
input_elem_size = input_k_cache.element_size()
result = torch.empty((num_blocks, block_size, dv + num_tiles*4 + input_elem_size*(d-dv)), dtype=torch.float8_e4m3fn, device=input_k_cache.device)
result_k_nope_part = result[..., :dv]
result_k_scale_factor = result[..., dv:dv + num_tiles*4].view(torch.float32)
result_k_rope_part = result[..., dv + num_tiles*4:].view(input_k_cache.dtype)
result_k_rope_part[:] = input_k_cache[..., dv:]
for tile_idx in range(0, num_tiles):
cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size]).max(dim=-1).values / 448.0 # [num_blocks, block_size]
result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv
cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1]
cur_quantized_nope = (input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn)
result_k_nope_part[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_quantized_nope
result = result.view(num_blocks, block_size, 1, -1)
return result
def dequantize_k_cache(
quant_k_cache: torch.Tensor, # (num_blocks, block_size, 1, bytes_per_token)
dv: int = 512,
tile_size: int = 128,
d: int = 576
) -> torch.Tensor:
"""
De-quantize the k-cache
"""
assert dv % tile_size == 0
num_tiles = dv // tile_size
num_blocks, block_size, h_k, _ = quant_k_cache.shape
assert h_k == 1
result = torch.empty((num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device)
quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)
input_nope = quant_k_cache[..., :dv]
input_scale = quant_k_cache[..., dv:dv + num_tiles*4].view(torch.float32)
input_rope = quant_k_cache[..., dv + num_tiles*4:].view(torch.bfloat16)
result[..., dv:] = input_rope
for tile_idx in range(0, num_tiles):
cur_nope = input_nope[..., tile_idx*tile_size:(tile_idx+1)*tile_size].to(torch.float32)
cur_scales = input_scale[..., tile_idx].unsqueeze(-1)
result[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_nope * cur_scales
result = result.view(num_blocks, block_size, 1, d)
return result
import argparse
import math
import random
import dataclasses
from typing import Optional, Tuple, List
import torch
import triton
import quant
import flash_mla
from lib import cdiv, check_is_allclose
# One test configuration. The first six fields are required; the rest have defaults.
@dataclasses.dataclass
class TestParam:
    # --- Required ---
    b: int              # Batch size
    s_q: int            # Number of queries for one request
    s_k: int            # Seq len, or mean seq len if is_varlen == True
    is_varlen: bool     # Whether each request gets its own sequence length
    is_causal: bool
    is_fp8: bool

    # --- Optional switches ---
    topk: Optional[int] = None
    test_performance: bool = True
    is_all_indices_invalid: bool = False
    have_zero_seqlen_k: bool = False

    # --- Shape parameters ---
    block_size: int = 64
    h_q: int = 128      # Number of q heads
    h_kv: int = 1       # Number of kv heads
    d: int = 576        # Q/K head dim (= dv + RoPE dim)
    dv: int = 512       # V head dim

    # --- Reproducibility ---
    seed: int = 0       # Random seed
def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Generate test data from a given configuration

    Return: (cache_seqlens, q, block_table, blocked_k, abs_indices, indices_in_kvcache)
      - cache_seqlens: int32 [b], KV length of every request (moved to CUDA)
      - q: [b, s_q, h_q, d], clamped to [-1, 1]
      - block_table: int32 [b, max_seqlen_pad // block_size], a random
        permutation of block ids. In dense mode the entries past the last
        used block of each request are overwritten with 2147480000.
      - blocked_k: [num_blocks, block_size, h_kv, d]; every token that the
        kernel must not read is set to NaN
      - abs_indices: int32 [b, s_q, topk], selected token positions inside each
        request, padded with -1 (None when t.topk is None)
      - indices_in_kvcache: int32 [b, s_q, topk], the same selection expressed
        as flat token slots of the paged cache, padded with -1 (None when
        t.topk is None)

    Pay attention: This function changes the random seed
    """
    random.seed(t.seed)
    torch.manual_seed(t.seed)
    torch.cuda.manual_seed(t.seed)
    torch.backends.cudnn.deterministic = True

    assert t.h_q % t.h_kv == 0

    cache_seqlens_cpu = torch.full((t.b,), t.s_k, dtype=torch.int32, device='cpu')
    if t.is_varlen:
        for i in range(t.b):
            # Never shorter than s_q
            cache_seqlens_cpu[i] = max(random.normalvariate(t.s_k, t.s_k / 2), t.s_q)

    if t.have_zero_seqlen_k:
        zeros_mask = torch.randn(t.b, dtype=torch.float32, device='cpu') > 0
        cache_seqlens_cpu[zeros_mask] = 0

    max_seqlen = cache_seqlens_cpu.max().item()
    max_seqlen_pad = cdiv(max_seqlen, 256) * 256
    cache_seqlens = cache_seqlens_cpu.cuda()

    q = torch.randn(t.b, t.s_q, t.h_q, t.d)
    q.clamp_(min=-1.0, max=1.0)

    # Give every request max_seqlen_pad // block_size blocks, then shuffle the
    # block ids so that logical and physical block order differ.
    block_table = torch.arange(t.b * max_seqlen_pad // t.block_size, dtype=torch.int32).view(t.b, max_seqlen_pad // t.block_size)
    block_table = block_table.view(-1)[torch.randperm(block_table.numel())].view(t.b, -1)
    blocked_k = torch.randn(block_table.numel(), t.block_size, t.h_kv, t.d) / 10
    blocked_k.clamp_(min=-1.0, max=1.0)

    if t.topk is None:
        for i in range(t.b):
            cur_len = cache_seqlens_cpu[i].item()
            cur_num_blocks = cdiv(cur_len, t.block_size)
            # Poison whole blocks beyond the sequence ...
            blocked_k[block_table[i][cur_num_blocks:]] = float("nan")
            if cur_len % t.block_size != 0:
                # ... and the unused tail of the last, partially filled block
                blocked_k[block_table[i][cur_num_blocks-1]][cur_len % t.block_size:] = float("nan")
            # Overwrite the now-unused block ids with a large sentinel value
            block_table[i][cur_num_blocks:] = 2147480000
        return cache_seqlens, q, block_table, blocked_k, None, None
    else:
        block_table_cpu = block_table.cpu()
        abs_indices = torch.empty(t.b, t.s_q, t.topk, dtype=torch.int32, device="cpu")
        indices_in_kvcache = torch.empty(t.b, t.s_q, t.topk, dtype=torch.int32, device="cpu")
        for i in range(t.b):
            # Generate indices
            for j in range(t.s_q):
                cur_abs_indices = torch.randperm(int(cache_seqlens_cpu[i].item()), device="cpu")[:t.topk]
                cur_blocked_indices = block_table_cpu[i, cur_abs_indices//t.block_size]*t.block_size + (cur_abs_indices%t.block_size)
                if len(cur_abs_indices) < t.topk:
                    # Fewer tokens than topk: pad with the invalid index -1
                    pad_len = t.topk - len(cur_abs_indices)
                    cur_abs_indices = torch.cat([cur_abs_indices, torch.full((pad_len,), -1, device='cpu')])
                    cur_blocked_indices = torch.cat([cur_blocked_indices, torch.full((pad_len,), -1, device='cpu')])

                # Shuffle so that the padding is not clustered at the end
                perm = torch.randperm(t.topk, device='cpu')
                cur_abs_indices = cur_abs_indices[perm]
                cur_blocked_indices = cur_blocked_indices[perm]

                # Fill it with invalid indices if needed
                if t.is_all_indices_invalid:
                    cur_abs_indices.fill_(-1)
                    cur_blocked_indices.fill_(-1)

                abs_indices[i, j, :] = cur_abs_indices
                indices_in_kvcache[i, j, :] = cur_blocked_indices

        # Mask nonused KV as NaN
        used_indices = indices_in_kvcache[indices_in_kvcache != -1].long()

        blocked_k = blocked_k.view(-1, t.h_kv, t.d)
        # One flag per token slot. (The mask indexes the token axis only, so it
        # must not be multiplied by the number of kv heads.)
        nonused_indices_mask = torch.ones(blocked_k.size(0), dtype=torch.bool, device='cpu')
        nonused_indices_mask[used_indices] = False
        blocked_k[nonused_indices_mask, :, :] = float("nan")
        blocked_k = blocked_k.view(-1, t.block_size, t.h_kv, t.d)

        abs_indices = abs_indices.to(q.device)
        indices_in_kvcache = indices_in_kvcache.to(q.device)

        return cache_seqlens, q, block_table, blocked_k, abs_indices, indices_in_kvcache
def reference_torch(
    cache_seqlens: torch.Tensor,    # [batch_size]
    block_table: torch.Tensor,      # [batch_size, ?]
    q: torch.Tensor,    # [batch_size, s_q, h_q, d]
    blocked_k: torch.Tensor,    # [?, block_size, h_kv, d]
    dv: int,
    is_causal: bool,
    indices: Optional[torch.Tensor] = None   # [batch_size, s_q, topk]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    A reference implementation in PyTorch

    For every request the KV tokens are gathered from the paged cache through
    `block_table`, NaN entries are treated as zero, and plain softmax attention
    is evaluated in float32. V is the first `dv` channels of K.

    Args:
        cache_seqlens: number of valid KV tokens of every request.
        block_table: block ids of every request, in logical order.
        q: queries.
        blocked_k: paged KV cache.
        dv: number of leading K channels that are used as V.
        is_causal: apply a causal mask aligned to the end of the KV sequence;
            must not be combined with `indices`.
        indices: optional per-query token positions (inside the request) that
            may be attended; -1 marks an unused slot.

    Returns:
        (out, lse): `out` is bfloat16 [batch_size, s_q, h_q, dv] and `lse` is
        float32 [batch_size, h_q, s_q]. A query with no attendable key gets an
        all-zero output row and an lse of +inf.
    """
    def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor, device):
        # True where query i is allowed to attend to key position j
        mask = torch.zeros(s_q, s_k, dtype=torch.bool, device=device)
        for i in range(s_q):
            cur_indices = indices[i]
            valid_indices = cur_indices[cur_indices != -1]
            mask[i, valid_indices] = True
        return mask

    def scaled_dot_product_attention(
        query: torch.Tensor,    # [h_q, s_q, d]
        kv: torch.Tensor,      # [h_kv, s_k, d]
        dv: int,
        is_causal,
        indices: Optional[torch.Tensor],  # [s_q, topk]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        h_q = query.size(0)
        h_kv = kv.size(0)
        s_q = query.shape[-2]
        s_k = kv.shape[-2]
        query = query.float()
        kv = kv.float()
        if h_kv != 1:
            kv = kv.repeat_interleave(h_q // h_kv, dim=0)
        # Zero out NaN entries. Done out of place because `.float()` returns
        # the input itself when it is already float32.
        kv = kv.masked_fill(kv != kv, 0.0)
        attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k]
        if (is_causal and s_q > 1) or indices is not None:
            mask = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device)
            if is_causal:
                assert indices is None
                mask = mask.tril(diagonal=s_k - s_q)
            if indices is not None:
                mask &= get_topk_attn_mask(s_q, s_k, indices, query.device)
            attn_bias = torch.zeros(s_q, s_k, dtype=torch.float, device=query.device)
            attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
            # attn_weight is float32 already, so the bias is added as-is
            attn_weight += attn_bias
        attn_weight /= math.sqrt(query.size(-1))
        lse = attn_weight.logsumexp(dim=-1) # [h_q, s_q]
        attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
        output = attn_weight @ kv[..., :dv] # [h_q, s_q, dv]
        # Correct for q tokens which has no attendable k
        lonely_q_mask = (lse == float("-inf"))
        output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0
        lse[lonely_q_mask] = float("+inf")

        return output, lse

    b, s_q, h_q, d = q.size()
    block_size = blocked_k.size(1)
    h_kv = blocked_k.size(2)
    cache_seqlens_cpu = cache_seqlens.cpu()
    out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
    lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32)
    for i in range(b):
        cur_len = cache_seqlens_cpu[i].item()
        cur_num_blocks = cdiv(cur_len, block_size)
        cur_block_indices = block_table[i][0: cur_num_blocks]
        # Gather the blocks of this request and cut off the padding of the last block
        cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...]
        cur_out, cur_lse = scaled_dot_product_attention(
            q[i].transpose(0, 1),
            cur_kv.transpose(0, 1),
            dv,
            is_causal,
            indices[i] if indices is not None else None
        )
        out_ref[i] = cur_out.transpose(0, 1)
        lse_ref[i] = cur_lse
    out_ref = out_ref.to(torch.bfloat16)
    return out_ref, lse_ref
@torch.inference_mode()
def test_flash_mla(t: TestParam):
    """
    Run one decoding test case.

    Generates the inputs described by `t`, runs the FlashMLA kernel and the
    PyTorch reference, asserts that output and log-sum-exp agree, and, when
    `t.test_performance` is set, benchmarks the kernel and prints the time,
    TFLOPS and GB/s.

    Raises:
        AssertionError: if the kernel result is not close to the reference.
    """
    print('-------------------------------')
    print(f"Running on {t}...")

    # Generating test data
    torch.cuda.synchronize()
    cache_seqlens, q, block_table, blocked_k, abs_indices, indices_in_kvcache = generate_test_data(t)

    # Number of channels that share one scale factor in the quantized cache
    quant_tile_size = 128

    if t.is_fp8:
        # The quantization error may be too large to be distinguished from wrong kernels
        # So we quantize and de-quantize kv-cache here to mitigate quantization error
        blocked_k_quantized = quant.quantize_k_cache(blocked_k, t.dv, quant_tile_size)
        # De-quantize with the same geometry that was used for quantization
        # instead of relying on the defaults of dequantize_k_cache
        blocked_k_dequantized = quant.dequantize_k_cache(blocked_k_quantized, t.dv, quant_tile_size, t.d)
        blocked_k = blocked_k_dequantized

    # Get schedule metadata
    torch.cuda.synchronize()
    tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(
        cache_seqlens,
        t.s_q * t.h_q // t.h_kv,
        t.h_kv,
        t.h_q,
        t.is_fp8,
        t.topk
    )
    torch.cuda.synchronize()

    def run_flash_mla():
        return flash_mla.flash_mla_with_kvcache(
            q,
            blocked_k if not t.is_fp8 else blocked_k_quantized,  # type: ignore
            block_table,
            cache_seqlens,
            t.dv,
            tile_scheduler_metadata,
            num_splits,
            causal=t.is_causal,
            is_fp8_kvcache=t.is_fp8,
            indices=indices_in_kvcache
        )

    out_ans, lse_ans = run_flash_mla()
    out_ref, lse_ref = reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal, abs_indices)
    assert check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01/128, cos_diff_tol=5e-6)
    assert check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536)

    if t.test_performance:
        # do_bench reports milliseconds
        time_usage: float = triton.testing.do_bench(run_flash_mla)/1000  # type: ignore
        mean_attended_seqlens = cache_seqlens.float().mean().item() if t.topk is None else t.topk
        compute_volume_flop = t.b*t.h_q*t.s_q*sum([
            2*t.d*mean_attended_seqlens,    # Q * K^T
            2*mean_attended_seqlens*t.dv,   # attention * V
        ])
        # Take the element size from the tensor so that fp16 runs are covered too
        q_elem_size = q.element_size()
        if t.is_fp8:
            # Packed token layout: dv one-byte quantized values, one float32
            # scale per tile, and the remaining d - dv channels in bfloat16.
            # This is 656 bytes for dv=512, d=576.
            kv_token_size = t.dv + (t.dv // quant_tile_size) * 4 + (t.d - t.dv) * 2
        else:
            kv_token_size = t.d*q_elem_size
        memory_volume_B = t.b*sum([
            t.s_q*t.h_q*(t.d*q_elem_size),   # Q
            (t.s_q if t.topk is not None else 1) * mean_attended_seqlens*t.h_kv*kv_token_size,    # K/V
            t.s_q*t.h_q*(t.dv*q_elem_size),  # Output
        ])
        achieved_tflops = compute_volume_flop / time_usage / 1e12
        achieved_gBps = memory_volume_B / time_usage / 1e9
        print(f"{time_usage*1000:.3f} ms, {achieved_tflops:.0f} TFLOPS, {achieved_gBps:.0f} GB/s")
def main(torch_dtype):
    """
    Build the list of decoding test cases and run them one after another.

    `torch_dtype` becomes the default dtype; the default device is cuda:0.
    The cases run in this order: correctness, corner cases, performance.
    """
    device = torch.device("cuda:0")
    torch.set_default_dtype(torch_dtype)
    torch.set_default_device(device)
    torch.cuda.set_device(device)

    # (is_fp8, topk) combinations exercised by the correctness sweep
    kv_modes = [
        (False, None),
        (True, 128),
        (True, 2048),
    ]
    correctness_cases = []
    for b in [1, 2, 6, 64]:
        for s_q in [1, 2, 4]:
            for s_k in [20, 140, 4096]:
                for is_varlen in [False, True]:
                    for is_causal in [False, True]:
                        for is_fp8, topk in kv_modes:
                            # Causal masking is not combined with top-k selection
                            if is_causal and topk is not None:
                                continue
                            correctness_cases.append(
                                TestParam(b, s_q, s_k, is_varlen, is_causal, is_fp8, topk, test_performance=False)
                            )

    corner_cases = []
    # Cases where all topk indices are invalid
    for topk in [128, 2048, 4096]:
        corner_cases.append(
            TestParam(128, 2, 4096, is_varlen=True, is_causal=False, is_fp8=True, topk=topk, test_performance=False, is_all_indices_invalid=True)
        )
    # Cases where some kv cache have zero length
    zero_len_modes = [
        (False, False, None),
        (True, False, None),
        (False, True, 128),
        (False, True, 2048),
    ]
    for is_causal, is_fp8, topk in zero_len_modes:
        corner_cases.append(
            TestParam(128, 2, 4096, is_varlen=True, is_causal=is_causal, is_fp8=is_fp8, topk=topk, test_performance=False, have_zero_seqlen_k=True)
        )

    performance_cases = []
    perf_modes = [
        (False, False, None),
        (True, False, None),
        (False, True, 2048),
    ]
    for is_causal, is_fp8, topk in perf_modes:
        for s_q in [1, 2]:
            for s_k in [4096, 8192, 16384, 32768]:
                performance_cases.append(
                    TestParam(128, s_q, s_k, is_varlen=True, is_causal=is_causal, is_fp8=is_fp8, topk=topk, test_performance=True)
                )

    for testcase in correctness_cases + corner_cases + performance_cases:
        test_flash_mla(testcase)
if __name__ == "__main__":
    # Command line entry point: choose the dtype and run every test case.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["bf16", "fp16"],
        default="bf16",
        help="Data type to use for testing (bf16 or fp16)",
    )
    args = parser.parse_args()

    # Everything other than "fp16" runs in bfloat16
    torch_dtype = torch.float16 if args.dtype == "fp16" else torch.bfloat16

    main(torch_dtype)
import math
import time
from typing import Tuple
import random
import dataclasses
import torch
import triton
from flash_mla import flash_mla_sparse_fwd
from lib import check_is_allclose
@dataclasses.dataclass
class TestParam:
    """
    Configuration of a single sparse-prefill test case.

    Attributes:
        b: Batch size (the test driver in this file asserts that it is 1).
        s_q: Number of query tokens.
        s_kv: Number of KV tokens.
        topk: Number of KV indices generated for every query token.
        h_q: Number of q heads.
        h_kv: Number of kv heads.
        d_qk: Head dim of Q and K.
        d_v: Head dim of V (the leading channels of K).
        seed: Seed used when generating the test data.
        check_correctness: Compare the kernel against the PyTorch reference.
        benchmark: Time the kernel and print the achieved throughput.
    """
    # Problem size (required)
    b: int
    s_q: int
    s_kv: int
    topk: int

    # Head geometry
    h_q: int = 128
    h_kv: int = 1
    d_qk: int = 576
    d_v: int = 512

    # Run control
    seed: int = 0
    check_correctness: bool = True
    benchmark: bool = True
@dataclasses.dataclass
class Testcase:
    """
    The generated inputs of one test case, bundled with its configuration.

    Attributes:
        t: The configuration the tensors were generated from.
        q: Queries, created as (b, s_q, h_q, d_qk) by generate_testcase.
        kv: Keys/values, created as (b, s_kv, h_kv, d_qk) by generate_testcase.
        indices: int32 KV indices, created as (b, s_q, h_kv, topk) by
            generate_testcase.
    """
    t: TestParam
    q: torch.Tensor
    kv: torch.Tensor
    indices: torch.Tensor
def generate_testcase(t: TestParam) -> Testcase:
    """
    Generate the random inputs for one sparse-prefill test case.

    The torch, CUDA and Python random generators are re-seeded with `t.seed`.
    `q` and `kv` are bfloat16 normal samples divided by 10 and clamped to
    [-10, 10]. For every (batch, query, kv head) `topk` indices are produced;
    when `s_kv < topk` the missing slots are filled with the out-of-range
    value 2147480000, and the slots are shuffled at the end.

    Returns:
        A Testcase holding `t`, `q`, `kv` and the int32 `indices`.
    """
    torch.manual_seed(t.seed)
    torch.cuda.manual_seed(t.seed)
    random.seed(t.seed)
    q = torch.randn((t.b, t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16)/10
    kv = torch.randn((t.b, t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16)/10

    q.clamp_(-10, 10)
    kv.clamp_(-10, 10)

    indices = torch.full((t.b, t.s_q, t.h_kv, t.topk), t.s_kv, dtype=torch.int32)
    for b in range(t.b):
        for s in range(t.s_q):
            for h in range(t.h_kv):
                # NOTE We use the following method to generate indices so that most indices lies within [s_kv-20000, s_kv), which is more realistic for sparse attention
                near_mask = torch.randint(0, 32, (min(t.topk, t.s_kv),)) < 31
                cur_indices = torch.randperm(t.s_kv)[:t.topk]
                # The upper bound of torch.randint is exclusive, so it has to
                # be s_kv for the last token to be reachable. This also keeps
                # the range non-empty when s_kv == 1.
                cur_indices[near_mask] = torch.randint(max(0, t.s_kv-20000), t.s_kv, (near_mask.sum().item(),))
                if len(cur_indices) < t.topk:
                    # Fewer KV tokens than topk: pad with an out-of-range index
                    cur_indices = torch.cat([cur_indices, torch.full((t.topk - len(cur_indices),), 2147480000)])
                cur_indices = cur_indices[torch.randperm(t.topk)]
                indices[b, s, h] = cur_indices
    indices = indices.to(q.device)

    return Testcase(
        t=t,
        q=q,
        kv=kv,
        indices=indices
    )
def get_flop(p: TestParam) -> float:
    """
    Count the floating point operations of one sparse attention forward pass.

    Each query head attends to `topk` keys: one product against K of width
    `d_qk` and one against V of width `d_v`, with 2 operations per
    multiply-add, repeated for every query token of every batch entry.
    """
    qk_term = p.h_q * p.d_qk * p.topk
    pv_term = p.h_q * p.d_v * p.topk
    return 2 * sum([qk_term, pv_term]) * p.b * p.s_q
def reference_torch(p: TestParam, t: Testcase, sm_scale: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    PyTorch reference for sparse attention with a batch size of 1.

    Indices outside [0, s_kv) are treated as invalid: they gather row 0 as a
    placeholder and their score is set to -inf. Scores are multiplied by
    `sm_scale * log2(e)`, so the returned max logits and log-sum-exp are
    expressed in base 2.

    Returns:
        (max_logits, lse, result): max_logits and lse are [s_q, h_q],
        result is [s_q, h_q, d_v].
    """
    assert p.b == 1

    topk_indices = t.indices[0, :, 0, :]                          # [s_q, topk]
    is_invalid = (topk_indices < 0) | (topk_indices >= p.s_kv)
    queries = t.q[0].float()                                      # [s_q, h_q, d_qk]
    all_kv = t.kv[0, :, 0, :].float()                             # [s_kv, d_qk]

    # Gather the selected KV rows; invalid slots read row 0 and are masked below
    gather_rows = topk_indices.masked_fill(is_invalid, 0).flatten()
    gathered = torch.index_select(all_kv, 0, gather_rows).view(p.s_q, p.topk, p.d_qk)  # [s_q, topk, d_qk]

    scores = queries @ gathered.transpose(1, 2)                   # [s_q, h_q, topk]
    scores.masked_fill_(is_invalid.unsqueeze(1), float('-inf'))
    scores *= sm_scale * math.log2(math.e)

    max_logits = torch.max(scores, dim=-1)[0]                     # [s_q, h_q]
    # Base-2 log-sum-exp, computed through the natural-log version
    lse = torch.logsumexp(scores * math.log(2), dim=-1) * math.log2(math.e)  # [s_q, h_q]
    probs = torch.exp2(scores - lse.unsqueeze(-1))                # [s_q, h_q, topk]
    result = probs @ gathered[:, :, :p.d_v]
    return (max_logits, lse, result)
@torch.inference_mode()
def run_test(p: TestParam) -> bool:
    """
    Run one sparse-prefill test case.

    The kernel is executed once. It is benchmarked when `p.benchmark` is set
    and compared against the PyTorch reference when `p.check_correctness` is
    set.

    Returns:
        True when the correctness check is skipped or every comparison passed.
    """
    print("================")
    print(f"Running on {p}")
    torch.cuda.empty_cache()
    assert p.b == 1

    t = generate_testcase(p)
    sm_scale = 1 / math.sqrt(p.d_qk)
    torch.cuda.synchronize()

    def run_ans():
        # The kernel takes tensors without the leading batch axis
        return flash_mla_sparse_fwd(
            t.q.squeeze(0), t.kv.squeeze(0), t.indices.squeeze(0), sm_scale=sm_scale
        )

    ans_out, ans_max_logits, ans_lse = run_ans()
    torch.cuda.synchronize()

    if p.benchmark:
        flop = get_flop(p)
        # do_bench reports milliseconds
        prefill_ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20)/1000  # type: ignore
        prefill_flops = flop/prefill_ans_time/1e12
        print(f"Prefill: {prefill_ans_time*1e6:4.0f} us, {prefill_flops:.3f} TFlops")

    if not p.check_correctness:
        return True

    torch.cuda.synchronize()
    ref_max_logits, ref_lse, ref_out = reference_torch(p, t, sm_scale)
    torch.cuda.synchronize()

    # Run all three comparisons so that every mismatch gets evaluated
    out_ok = check_is_allclose("out", ans_out, ref_out, abs_tol=8e-4, rel_tol=2.01/128, cos_diff_tol=7e-6)
    max_logits_ok = check_is_allclose("max_logits", ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01/65536)
    lse_ok = check_is_allclose("lse", ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01/65536)
    return True & out_ok & max_logits_ok & lse_ok
if __name__ == '__main__':
    # Script entry point: run every sparse-prefill case and print a summary.
    device = torch.device("cuda:0")
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    torch.set_float32_matmul_precision('high')

    # (s_kv, topk) pairs of the correctness sweep
    correctness_shapes = [
        # Regular shapes
        (128, 128),
        (256, 256),
        (512, 512),
        # Irregular shapes
        (592, 128),
        (1840, 256),
        (1592, 384),
        (1521, 512),
        # Irregular shapes with OOB TopK
        (95, 128),
        (153, 256),
        (114, 384),
    ]
    correctness_cases = [
        TestParam(1, s_q, s_kv, topk, h_q=128, benchmark=False)
        for s_kv, topk in correctness_shapes
        for s_q in [1, 62]
    ]

    # In these cases, some blocks may not have any valid topk indices
    corner_shapes = [
        (32, 2048),
        (64, 8192),
    ]
    corner_cases = [
        TestParam(1, s_q, s_kv, topk, h_q=128, benchmark=False)
        for s_kv, topk in corner_shapes
        for s_q in [1, 1024]
    ]

    performance_cases = [
        TestParam(1, s_q, s_kv, topk, h_q=128)
        for s_q in [4096]
        for s_kv in [4096, 8192, 16384, 32768, 49152, 65536, 81920, 98304, 114688, 131072]
        for topk in [2048]
    ]

    testcases = correctness_cases + corner_cases + performance_cases

    failed_cases = []
    for test in testcases:
        if test.benchmark:
            # Short pause ahead of every benchmarked case
            time.sleep(0.2)
        is_correct = run_test(test)
        if not is_correct:
            failed_cases.append(test)

    if len(failed_cases) == 0:
        print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m")
    else:
        print(f"\033[31m\033[1m{len(failed_cases)} / {len(testcases)} cases failed:\033[0m")
        for case in failed_cases:
            print(f" {case}")
import argparse
import math
import random
import torch
import triton
from flash_mla import flash_mla_with_kvcache, get_mla_metadata
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
    """
    Reference softmax attention, evaluated in float32.

    Args:
        query: [h_q, s_q, d]
        key: [h_kv, s_k, d]
        value: [h_kv, s_k, dv]
        h_q: number of query heads.
        h_kv: number of kv heads; every kv head is repeated h_q // h_kv times.
        is_causal: apply a causal mask aligned to the end of the key sequence.

    Returns:
        (output, lse): output is [h_q, s_q, dv], lse is [h_q, s_q].
    """
    query = query.float()
    key = key.float()
    value = value.float()
    key = key.repeat_interleave(h_q // h_kv, dim=0)
    value = value.repeat_interleave(h_q // h_kv, dim=0)
    attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if is_causal:
        s_q = query.shape[-2]
        s_k = key.shape[-2]
        # Build the mask where the scores live instead of on the default device
        attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype, device=query.device)
        # Query i may see the keys up to position i + (s_k - s_q)
        temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_weight += attn_bias
    lse = attn_weight.logsumexp(dim=-1)
    attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
    return attn_weight @ value, lse
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
    """
    Assert that `x` and `y` are nearly identical.

    The check is on cos_diff = 1 - 2 * sum(x * y) / sum(x * x + y * y), which
    is 0 for identical tensors and has to stay below 1e-5. RMSE and the
    largest absolute difference are reported in the failure message.

    Args:
        x, y: tensors to compare (converted to float64 first).
        name: label used in the failure message.

    Raises:
        AssertionError: if cos_diff is not below 1e-5.
    """
    x, y = x.double(), y.double()
    RMSE = ((x - y) * (x - y)).mean().sqrt().item()
    denominator = (x * x + y * y).sum().item()
    if denominator == 0:
        # Both tensors are entirely zero, hence identical
        cos_diff = 0.0
    else:
        cos_diff = 1 - 2 * (x * y).sum().item() / denominator
    amax_diff = (x - y).abs().max().item()
    assert cos_diff < 1e-5, f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}"
@torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
    """
    Check the dense FlashMLA decoding kernel against a PyTorch reference and
    benchmark it.

    Args:
        b: batch size.
        s_q: number of query tokens per request.
        mean_sk: KV length, or the mean KV length when `varlen` is True.
        h_q: number of query heads.
        h_kv: number of kv heads.
        d: head dim of Q and K.
        dv: head dim of V (the leading channels of K).
        causal: apply causal masking.
        varlen: draw a different KV length for every request.

    Raises:
        AssertionError: if output or log-sum-exp differ from the reference.
    """
    print(
        f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}"
    )

    cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
    if varlen:
        for i in range(b):
            # Never shorter than s_q
            cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
    total_seqlens = cache_seqlens.sum().item()
    max_seqlen = cache_seqlens.max().item()
    max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256

    q = torch.randn(b, s_q, h_q, d)
    block_size = 64
    # Request i owns the contiguous block range starting at i * max_seqlen_pad // block_size
    block_table = torch.arange(
        b * max_seqlen_pad // block_size, dtype=torch.int32
    ).view(b, max_seqlen_pad // block_size)
    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
    for i in range(b):
        # Poison everything beyond the sequence so that out-of-range reads show up
        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = (
            float("nan")
        )
    blocked_v = blocked_k[..., :dv]

    tile_scheduler_metadata, num_splits = get_mla_metadata(
        cache_seqlens, s_q * h_q // h_kv, h_kv
    )

    def flash_mla():
        return flash_mla_with_kvcache(
            q,
            blocked_k,
            block_table,
            cache_seqlens,
            dv,
            tile_scheduler_metadata,
            num_splits,
            causal=causal,
        )

    def ref_mla():
        out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
        lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
        for i in range(b):
            begin = i * max_seqlen_pad
            # Plain int slice bound instead of a 0-dim tensor
            end = begin + cache_seqlens[i].item()
            O, LSE = scaled_dot_product_attention(
                q[i].transpose(0, 1),
                blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
                blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
                h_q=h_q,
                h_kv=h_kv,
                is_causal=causal,
            )
            out[i] = O.transpose(0, 1)
            lse[i] = LSE
        return out, lse

    out_flash, lse_flash = flash_mla()
    out_torch, lse_torch = ref_mla()
    cal_diff(out_flash, out_torch, "out")
    cal_diff(lse_flash, lse_torch, "lse")

    # do_bench reports milliseconds
    t = triton.testing.do_bench(flash_mla)
    FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
    # K/V read + Q read + output write, in elements, times the element size
    elem_size = torch.finfo(q.dtype).bits // 8
    total_bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * elem_size
    print(
        f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {total_bytes / 10 ** 6 / t:.0f} GB/s"
    )
def main(torch_dtype):
    """
    Run the dense decoding benchmark over a fixed grid of configurations.

    `torch_dtype` becomes the default dtype; the default device is cuda:0 and
    both torch and Python random generators are seeded with 0.
    """
    device = torch.device("cuda:0")
    torch.set_default_dtype(torch_dtype)
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    torch.manual_seed(0)
    random.seed(0)

    h_kv = 1
    d, dv = 576, 512
    causal = True

    # Same visiting order as the equivalent nested loops, outermost first
    configs = [
        (b, s, h_q, s_q, varlen)
        for b in [128]
        for s in [4096, 8192, 16384]
        for h_q in [16, 32, 64, 128]  # TP = 8, 4, 2, 1
        for s_q in [1, 2]  # MTP = 1, 2
        for varlen in [False, True]
    ]
    for b, s, h_q, s_q, varlen in configs:
        test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
if __name__ == "__main__":
    # Command line entry point: choose the dtype and run the benchmark grid.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["bf16", "fp16"],
        default="bf16",
        help="Data type to use for testing (bf16 or fp16)",
    )
    args = parser.parse_args()

    # Everything other than "fp16" runs in bfloat16
    torch_dtype = torch.float16 if args.dtype == "fp16" else torch.bfloat16

    main(torch_dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment