Commit 892f7274 authored by zhanghj2's avatar zhanghj2
Browse files

支持kv 软fp8 e5m2

parent 11e445c3
...@@ -4,12 +4,13 @@ ...@@ -4,12 +4,13 @@
#include "sparse_decode.h" #include "sparse_decode.h"
#include "dense_decode.h" #include "dense_decode.h"
#include "dense_decode_qkvfp8.h" #include "dense_decode_qkvfp8.h"
#include "dense_decode_kvfp8.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashMLA"; m.doc() = "FlashMLA";
m.def("sparse_decode_fwd", &sparse_attn_decode_interface); m.def("sparse_decode_fwd", &sparse_attn_decode_interface);
m.def("dense_decode_fwd", &dense_attn_decode_interface); m.def("dense_decode_fwd", &dense_attn_decode_interface);
m.def("dense_decode_fwd_qkvfp8", &dense_attn_decode_qkvfp8_interface); m.def("dense_decode_fwd_qkvfp8", &dense_attn_decode_qkvfp8_interface);
m.def("dense_decode_fwd_kvfp8", &dense_attn_decode_kvfp8_interface);
m.def("sparse_prefill_fwd", &sparse_attn_prefill_interface); m.def("sparse_prefill_fwd", &sparse_attn_prefill_interface);
} }
#pragma once
#include <cutlass/half.h>
#include <cutlass/fast_math.h>
#include "common.h"
#include "params.h"
#include "sm90/decode/dense_kvfp8/splitkv_mla.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h"
static std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>>
dense_attn_decode_kvfp8_interface(
at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True)
const int head_size_v,
const at::Tensor &seqlens_k, // batch_size
const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
const float softmax_scale,
bool is_causal,
std::optional<at::Tensor> &tile_scheduler_metadata, // num_sm_parts x (DecodingSchedMetaSize/4)
std::optional<at::Tensor> &num_splits, // batch_size + 1
std::optional<const at::Tensor> &descale_q,
std::optional<const at::Tensor> &descale_k
) {
// Check arch
Arch arch = Arch();
if (!arch.is_sm90a()) {
TORCH_CHECK(false, "Dense decode MLA is only supported on SM90a architecture");
}
// Check data types
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kBFloat16);
TORCH_CHECK(descale_q.has_value() && descale_k.has_value(), "descale is required when input dtype is fp8");
auto descale_q_ = descale_q.value();
auto descale_k_ = descale_k.value();
TORCH_CHECK(kcache.dtype() == torch::kFloat8_e5m2, "key must have the same dtype torch::kFloat8_e5m2");
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
// Check device
KU_CHECK_DEVICE(q);
KU_CHECK_DEVICE(kcache);
KU_CHECK_DEVICE(seqlens_k);
KU_CHECK_DEVICE(block_table);
KU_CHECK_DEVICE(tile_scheduler_metadata);
KU_CHECK_DEVICE(num_splits);
KU_CHECK_DEVICE(descale_q_);
KU_CHECK_DEVICE(descale_k_);
TORCH_CHECK(descale_q_.stride(-1) == 1);
TORCH_CHECK(descale_k_.stride(-1) == 1);
TORCH_CHECK(descale_q_.dtype() == torch::kFloat);
TORCH_CHECK(descale_k_.dtype() == torch::kFloat);
KU_CHECK_SHAPE(descale_q_, 1);
KU_CHECK_SHAPE(descale_k_, 1);
// Check layout
TORCH_CHECK(q.stride(-1) == 1, "q must have contiguous last dimension");
TORCH_CHECK(kcache.stride(-1) == 1, "kcache must have contiguous last dimension");
KU_CHECK_CONTIGUOUS(seqlens_k);
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
KU_CHECK_CONTIGUOUS(num_splits);
const auto sizes = q.sizes();
const int batch_size = sizes[0];
const int seqlen_q_ori = sizes[1];
const int num_heads_q = sizes[2];
const int head_size_k = sizes[3];
TORCH_CHECK(head_size_k == 576 || head_size_k == 512, "Only head_size_k == 576 or 512 is supported");
TORCH_CHECK(head_size_v == 512, "Only head_size_v == 576 is supported");
const int max_num_blocks_per_seq = block_table.size(1);
const int num_blocks = kcache.size(0);
const int page_block_size = kcache.size(1);
const int num_heads_k = kcache.size(2);
TORCH_CHECK(page_block_size == 64, "Currently page_block_size must be 64");
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (seqlen_q_ori == 1) { is_causal = false; }
const int num_q_heads_per_hk = num_heads_q / num_heads_k;
const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk;
const int num_heads = num_heads_k;
q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3)
.reshape({batch_size, q_seq_per_hk, num_heads, head_size_k});
int num_sm_parts = std::max(arch.num_sms / num_heads_k / cutlass::ceil_div(seqlen_q_ori*num_heads_q/num_heads_k, 16) * 2, 1);
KU_CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k);
KU_CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
KU_CHECK_SHAPE(seqlens_k, batch_size);
KU_CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
KU_CHECK_SHAPE(tile_scheduler_metadata, num_sm_parts, DecodingSchedMetaSize/sizeof(int));
KU_CHECK_SHAPE(num_splits, batch_size+1);
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
at::Tensor out = torch::empty({batch_size, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(torch::kBFloat16));
at::Tensor lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
KU_CHECK_CONTIGUOUS(out);
KU_CHECK_CONTIGUOUS(lse);
if (!tile_scheduler_metadata.has_value()) {
tile_scheduler_metadata = torch::empty({num_sm_parts, sizeof(DecodingSchedMeta)/4}, opts.dtype(torch::kInt32));
num_splits = torch::empty({batch_size+1}, opts.dtype(torch::kInt32));
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
KU_CHECK_CONTIGUOUS(num_splits);
GetDecodeSchedMetaParams get_sched_meta_params = {
batch_size, seqlen_q_ori,
64,
5,
-1, -1,
nullptr, nullptr,
seqlens_k.data_ptr<int>(),
(DecodingSchedMeta*)tile_scheduler_metadata->data_ptr(),
num_splits->data_ptr<int>(),
num_sm_parts,
at::cuda::getCurrentCUDAStream().stream()
};
smxx::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);
} else {
KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32);
KU_CHECK_DTYPE(num_splits, torch::kInt32);
KU_CHECK_DEVICE(tile_scheduler_metadata);
KU_CHECK_DEVICE(num_splits);
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
KU_CHECK_CONTIGUOUS(num_splits);
KU_CHECK_SHAPE(tile_scheduler_metadata, num_sm_parts, sizeof(DecodingSchedMeta)/sizeof(int));
KU_CHECK_SHAPE(num_splits, batch_size+1);
}
// Set the sizes
DenseAttnDecodeParams_fp8 params;
params.b = batch_size;
params.s_q = seqlen_q_ori;
params.q_seq_per_hk = q_seq_per_hk;
params.seqlens_k_ptr = seqlens_k.data_ptr<int>();
params.h_q = num_heads_q;
params.h_k = num_heads_k;
params.num_blocks = num_blocks;
params.q_head_per_hk = num_q_heads_per_hk;
params.is_causal = is_causal;
params.d = head_size_k;
params.d_v = head_size_v;
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
// Set the pointers and strides.
params.q_ptr = q.data_ptr();
params.k_ptr = kcache.data_ptr();
params.o_ptr = out.data_ptr();
params.softmax_lse_ptr = lse.data_ptr<float>();
// All stride are in elements, not bytes.
params.q_batch_stride = q.stride(0);
params.k_batch_stride = kcache.stride(0);
params.o_batch_stride = out.stride(0);
params.q_row_stride = q.stride(1);
params.k_row_stride = kcache.stride(1);
params.o_row_stride = out.stride(2);
params.q_head_stride = q.stride(2);
params.k_head_stride = kcache.stride(2);
params.o_head_stride = out.stride(1);
params.block_table = block_table.data_ptr<int>();
params.block_table_batch_stride = block_table.stride(0);
params.page_block_size = page_block_size;
params.tile_scheduler_metadata_ptr = (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr();
params.num_sm_parts = num_sm_parts;
params.num_splits_ptr = num_splits->data_ptr<int>();
params.descale_q_ptr = descale_q_.data_ptr<float>();
params.descale_k_ptr = descale_k_.data_ptr<float>();;
const int total_num_splits = batch_size + params.num_sm_parts;
at::Tensor lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat));
KU_CHECK_CONTIGUOUS(lse_accum);
KU_CHECK_CONTIGUOUS(out_accum);
params.total_num_splits = total_num_splits;
params.softmax_lseaccum_ptr = lse_accum.data_ptr<float>();
params.oaccum_ptr = out_accum.data_ptr<float>();
params.stream = at::cuda::getCurrentCUDAStream().stream();
if (q_dtype == torch::kBFloat16) {
sm90::run_flash_splitkv_mla_kvfp8_kernel<cutlass::bfloat16_t>(params);
} else {
TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90");
}
CombineParams combine_params = {
batch_size, seqlen_q_ori,
num_heads_q, head_size_v,
params.softmax_lse_ptr,
params.o_ptr,
num_heads*q_seq_per_hk, num_heads_q,
num_heads_q*seqlen_q_ori*head_size_v, num_heads_q*head_size_v, head_size_v,
params.softmax_lseaccum_ptr,
params.oaccum_ptr,
num_heads*q_seq_per_hk, num_heads_q,
num_heads_q*seqlen_q_ori*head_size_v, num_heads_q*head_size_v, head_size_v,
params.tile_scheduler_metadata_ptr,
params.num_splits_ptr,
params.num_sm_parts,
nullptr,
at::cuda::getCurrentCUDAStream().stream()
};
smxx::decode::run_flash_mla_combine_kernel<cutlass::bfloat16_t>(combine_params);
out = out.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk, head_size_v}).transpose(1, 2)
.reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v});
lse = lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3)
.reshape({batch_size, num_heads_q, seqlen_q_ori});
return {out, lse, tile_scheduler_metadata, num_splits};
}
...@@ -183,3 +183,10 @@ inline constexpr bool is_decode_v = std::bool_constant<FWD_MODE == SparseAttnFwd ...@@ -183,3 +183,10 @@ inline constexpr bool is_decode_v = std::bool_constant<FWD_MODE == SparseAttnFwd
template<SparseAttnFwdMode FWD_MODE> template<SparseAttnFwdMode FWD_MODE>
using SparseFwdArgT = std::conditional_t<is_decode_v<FWD_MODE>, SparseAttnDecodeParams, SparseAttnFwdParams>; using SparseFwdArgT = std::conditional_t<is_decode_v<FWD_MODE>, SparseAttnDecodeParams, SparseAttnFwdParams>;
enum class Fp8KVCacheDataType {
kAuto = 0,
kFp8E4M3 = 1,
kFp8E5M2 = 2,
kInt8 = 3,
};
#pragma once
namespace Config {
static constexpr int BLOCK_SIZE_M = 16;
static constexpr int PAGE_BLOCK_SIZE = 64;
static constexpr int HEAD_DIM_K = 576;
static constexpr int HEAD_DIM_V = 512;
}
#include "../splitkv_mla.cuh"
#include "../splitkv_mla.h"
namespace sm90 {
template void run_flash_splitkv_mla_kvfp8_kernel<cutlass::bfloat16_t>(DenseAttnDecodeParams_fp8 &params);
}
This diff is collapsed.
#pragma once
#include "params.h"
namespace sm90 {
template<typename InputT>
void run_flash_splitkv_mla_kvfp8_kernel(DenseAttnDecodeParams_fp8 &params);
}
#pragma once
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_types.h>
#include <cutlass/barrier.h>
#include "config.h"
using namespace cute;
template<typename InputT_, bool Is_causal_>
struct Traits {
using InputT = InputT_;
static constexpr bool Is_causal = Is_causal_;
static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M;
static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE;
static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K;
static constexpr int HEAD_DIM_V = Config::HEAD_DIM_V;
static constexpr int NUM_THREADS = 256;
static_assert(std::is_same_v<InputT, cutlass::bfloat16_t>);
static constexpr int kBlockM = BLOCK_SIZE_M;
static constexpr int kBlockN = PAGE_BLOCK_SIZE;
static constexpr int kHeadDim = HEAD_DIM_K;
static constexpr int kHeadDimV = HEAD_DIM_V;
static constexpr int kNWarps = 4;
static constexpr int kSwizzle = 3;
using Element = InputT;
using elem_type = Element;
using ElementAccum = float;
using ValLayoutMNK = Layout<Shape<_1, _1, _1>>;
using SmemLayoutRow = Layout<Shape<_128>, Stride<_1>>;
using SmemLayoutAtomK = decltype(composition(
Swizzle<3, 3, 3>{},
Layout<Shape<Int<8>, Int<32>>, Stride<Int<32>, _1>>{}));
using SmemLayoutK = decltype(tile_to_shape(
SmemLayoutAtomK{},
Shape<Int<kBlockN>, Int<8 * 32>>{}));
using SmemLayoutAtomV = SmemLayoutAtomK;
using SmemLayoutV = decltype(tile_to_shape(
SmemLayoutAtomV{},
Shape<Int<kBlockN>, Int<8 * 32>>{}));
using SmemLayoutAtomP = Layout<Shape<Int<4*16*16>>, Stride<Int<1>>>;
using SmemLayoutP = decltype(tile_to_shape(
SmemLayoutAtomP{},
Shape<Int<4*16*16>>{}));
using SmemLayoutVtransposed = decltype(
composition(SmemLayoutV{}, make_layout(Shape<Int<8 * 32>, Int<kBlockN>>{}, GenRowMajor{})));
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
using SmemLayoutAtomV_fp8 = Layout<Shape<Int<kBlockN>, Int<512>>, Stride<_512, _1>>;
using SmemLayoutV_fp8 = decltype(tile_to_shape(
SmemLayoutAtomV_fp8{},
Shape<Int<kBlockN>, Int<512>>{}));
using SmemLayoutVtransposed_fp8 = decltype(
composition(SmemLayoutV_fp8{}, make_layout(Shape<Int<512>, Int<kBlockN>>{}, GenRowMajor{})));
using SmemLayoutVtransposedNoSwizzle_fp8 = decltype(get_nonswizzle_portion(SmemLayoutVtransposed_fp8{}));
using SmemLayoutAtomQ = decltype(composition(
Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<Int<8>, Int<64>>, Stride<Int<64>, _1>>{}));
using SmemLayoutQ = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using MMA_Atom_Arch_16_16_32 = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<GFX928_16x16x32_F32F16F16F32_NT>,
MMA_Atom<GFX928_16x16x32_F32BF16BF16F32_NT>
>;
using TiledMma_16_16_32 = TiledMMA<
MMA_Atom_Arch_16_16_32,
Layout<Shape<_1, Int<4>, _1>>, // 1x4x1 or 1x8x1 thread group
ValLayoutMNK>;
using MMA_Atom_Arch = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<GFX928_16x16x64_F32F16F16F32_NT>,
MMA_Atom<GFX928_16x16x64_F32BF16BF16F32_NT>
>;
using MMA_Atom_Arch_16_32_16 = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<GFX928_16x32x16_F32F16F16F32_NT>,
MMA_Atom<GFX928_16x32x16_F32BF16BF16F32_NT>
>;
using TiledMma_O_16_32_16 = TiledMMA<
MMA_Atom_Arch_16_32_16,
Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
ValLayoutMNK>;
using TiledMma = TiledMMA<
MMA_Atom_Arch,
Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
ValLayoutMNK>;
using MMA_Atom_Arch_int8 = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<GFX928_16x16x64_F32F16uint8F32_NT>,
MMA_Atom<GFX928_16x16x64_F32BF16int8F32_NT>
>;
using MMA_Atom_Arch_16x64 = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<GFX928_16x64x16_FP8_F32F16F16F32_NT>,
MMA_Atom<GFX928_16x64x16_FP8_F32BF16BF16F32_NT>
>;
using TiledMma_O = TiledMMA<
MMA_Atom_Arch_16x64,
Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
ValLayoutMNK>;
using TiledMma_int8 = TiledMMA<
MMA_Atom_Arch_int8,
Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
ValLayoutMNK>;
using GmemLayoutAtomQ = Layout<Shape <_32, _8>,
Stride< _8, _1>>;
using GmemTiledCopyQ = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomQ{},
Layout<Shape<_1, _8>>{}));
struct SharedMemoryPlan {
union {
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v; // Double buffer
cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_sum;
cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_max;
};
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
};
};
};
};
This diff is collapsed.
...@@ -4,12 +4,14 @@ from flash_mla.flash_mla_interface import ( ...@@ -4,12 +4,14 @@ from flash_mla.flash_mla_interface import (
get_mla_metadata, get_mla_metadata,
flash_mla_with_kvcache, flash_mla_with_kvcache,
flash_mla_sparse_fwd, flash_mla_sparse_fwd,
flash_mla_with_kvcache_qkvfp8 flash_mla_with_kvcache_qkvfp8,
flash_mla_with_kvcache_kvfp8
) )
__all__ = [ __all__ = [
"get_mla_metadata", "get_mla_metadata",
"flash_mla_with_kvcache", "flash_mla_with_kvcache",
"flash_mla_sparse_fwd", "flash_mla_sparse_fwd",
"flash_mla_with_kvcache_qkvfp8" "flash_mla_with_kvcache_qkvfp8",
"flash_mla_with_kvcache_kvfp8"
] ]
...@@ -228,8 +228,6 @@ def flash_mla_with_kvcache_qkvfp8( ...@@ -228,8 +228,6 @@ def flash_mla_with_kvcache_qkvfp8(
Arguments: Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim). q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details.
The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks.
block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used. block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used. cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
head_dim_v: Head_dim of v. Must be 512 head_dim_v: Head_dim of v. Must be 512
...@@ -289,4 +287,82 @@ def flash_mla_with_kvcache_qkvfp8( ...@@ -289,4 +287,82 @@ def flash_mla_with_kvcache_qkvfp8(
) )
sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
sched_meta.num_splits = new_num_splits sched_meta.num_splits = new_num_splits
return (out, lse)
def flash_mla_with_kvcache_kvfp8(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: Optional[torch.Tensor],
cache_seqlens: Optional[torch.Tensor],
head_dim_v: int,
tile_scheduler_metadata: FlashMLASchedMeta,
num_splits: None = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
head_dim_v: Head_dim of v. Must be 512
sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
num_splits_placeholder: must be "None" (to be compatible with the old interface).
softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
causal: bool. Whether to apply causal attention mask. Only valid for dense attention
descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
Return:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v), only support bf16 output
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
sched_meta = tile_scheduler_metadata
assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
assert num_splits is None, "num_splits must be None"
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if not sched_meta.have_initialized:
# Initialize the tile scheduler metadata during the first invocation.
sched_meta.have_initialized = True
sched_meta.config = FlashMLASchedMeta.Config(
q.shape[0],
q.shape[1],
q.shape[2],
k_cache.shape[1],
k_cache.shape[2],
causal,
False,
0,
0,
0
)
else:
# Check whether the input arguments are consistent with sched_meta
helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
assert sched_meta.config is not None
assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
# Dense attention
assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd_kvfp8(
q, k_cache, head_dim_v,
cache_seqlens, block_table,
softmax_scale, causal,
sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
descale_q, descale_k
)
sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
sched_meta.num_splits = new_num_splits
return (out, lse) return (out, lse)
\ No newline at end of file
...@@ -61,6 +61,9 @@ ext_modules.append( ...@@ -61,6 +61,9 @@ ext_modules.append(
## sm90 dense qkvfp8 decode ## sm90 dense qkvfp8 decode
"csrc/sm90/decode/dense_qkvfp8/instantiations/fp8e4m3.cu", "csrc/sm90/decode/dense_qkvfp8/instantiations/fp8e4m3.cu",
## sm90 dense kvfp8 decode
"csrc/sm90/decode/dense_kvfp8/instantiations/kvfp8.cu",
# # sm90 sparse decode # # sm90 sparse decode
"csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h16.cu", "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h16.cu",
...@@ -97,7 +100,7 @@ ext_modules.append( ...@@ -97,7 +100,7 @@ ext_modules.append(
Path(this_dir) / "csrc", Path(this_dir) / "csrc",
Path(this_dir) / "csrc" / "kerutils" / "include", # TODO Remove me Path(this_dir) / "csrc" / "kerutils" / "include", # TODO Remove me
Path(this_dir) / "csrc" / "sm90", Path(this_dir) / "csrc" / "sm90",
Path(this_dir) / "csrc" / "cutlass" / "include" Path(this_dir) / "csrc" / "cutlass" / "include",
], ],
) )
) )
......
import argparse
import math
import random
import torch
import triton
from flash_mla import flash_mla_with_kvcache_kvfp8, get_mla_metadata
torch.set_printoptions(precision=4, profile="default", sci_mode=False)
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False, k_scale=1.0):
query = query.float()
key = key.float() * k_scale
value = value.float() * k_scale
key = key.repeat_interleave(h_q // h_kv, dim=0)
value = value.repeat_interleave(h_q // h_kv, dim=0)
# tmp = query @ key.transpose(-2, -1)
# print("tmp ", tmp.shape, tmp[0, 0, :16])
# print("tmp ", tmp.shape, tmp[0, 0, 16:32])
attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
if is_causal:
s_q = query.shape[-2]
s_k = key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
attn_weight += attn_bias
lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
return attn_weight @ value, lse
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
torch_dtype = x.dtype
x, y = x.double(), y.double()
RMSE = ((x - y) * (x - y)).mean().sqrt().item()
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
amax_diff = (x - y).abs().max().item()
print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
assert cos_diff < (1e-4 if torch_dtype == torch.bfloat16 else 1e-5)
@torch.inference_mode()
def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=False):
print(
f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}"
)
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
if varlen:
for i in range(b):
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
total_seqlens = cache_seqlens.sum().item()
mean_seqlens = cache_seqlens.float().mean().int().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}, {max_seqlen_pad=}")
q = torch.randn(b, s_q, h_q, d)
# q = torch.ones(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(
b * max_seqlen_pad // block_size, dtype=torch.int32
).view(b, max_seqlen_pad // block_size)
# blocked_k = torch.randint(low=0, high=4, size = (block_table.numel(), block_size, h_kv, d), dtype = torch.int8)
# blocked_k = torch.ones(size = (block_table.numel(), block_size, h_kv, d), dtype = torch.int8)
blocked_k = (torch.randn(block_table.numel(), block_size, h_kv, d))
# blocked_k = (torch.zeros(block_table.numel(), block_size, h_kv, d))
# blocked_k[:, 1:, :, :] = 0
# blocked_k[:, :, :, 1:] = 0
# blocked_k[0, 0:16, 0, 0] = 0
# blocked_k[0, 32:, 0, 0] = 0
# blocked_k[0, 0, 0, 1] = 2
# blocked_k[0, 0, 0, 2] = 3
# blocked_k[0, 0, 0, 3] = 4
# print(" blocked_k ", blocked_k[0, 0, 0, :])
blocked_k = blocked_k.to(torch.float8_e5m2)
# blocked_k = (torch.ones(block_table.numel(), block_size, h_kv, d)).to(torch.float8_e5m2)
# blocked_k[0, 0, 0, 56] = 1
# blocked_k[0, 1, 0, 8] = 2
# blocked_k[0, 2, 0, 8] = 5
# blocked_k[0, 3, 0, 8] = 4
# for i in range(64):
# for j in range(64):
# blocked_k[0, i, 0, j] = j
# blocked_k[0, i, 0, j] = (i * 50 + j) % 128
# print("blocked_k ", blocked_k[0, 0, 0, 0:10])
# for i in range(b):
# blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = (
# -128
# )
blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_metadata()
# print("q:", q.shape, q.dtype, q)
# print("cache_seqlens:", cache_seqlens.shape, cache_seqlens)
# print("block_table:", block_table.shape, block_table)
# print("blocked_k:", blocked_k.shape, blocked_k[0])
# print("blocked_v:", blocked_v.shape)
# torch.set_printoptions(precision=4, profile="full", sci_mode=False)
# print("tile_scheduler_metadata:", tile_scheduler_metadata.shape, tile_scheduler_metadata)
# torch.set_printoptions(precision=4, profile="default", sci_mode=False)
# print("num_splits:", num_splits.shape, num_splits)
# k_scale = torch.tensor(1.0).to(torch.float32).to("cuda:0")
# k_scale = torch.tensor(2.1).to(torch.float32).to("cuda:0")
k_scale = torch.tensor(1.0).to(torch.float32).to("cuda:0")
descale_q = torch.ones((1), dtype=torch.float32)
descale_k = torch.ones((1), dtype=torch.float32)
def flash_mla():
return flash_mla_with_kvcache_kvfp8(
q,
blocked_k,
block_table,
cache_seqlens,
dv,
tile_scheduler_metadata,
num_splits,
causal=causal,
descale_q = descale_q,
descale_k = descale_k,
)
def ref_mla():
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
begin = i * max_seqlen_pad
end = begin + cache_seqlens[i]
O, LSE = scaled_dot_product_attention(
q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
h_q=h_q,
h_kv=h_kv,
is_causal=causal,
k_scale = k_scale
)
out[i] = O.transpose(0, 1)
lse[i] = LSE
return out, lse
out_flash, lse_flash = flash_mla()
out_torch, lse_torch = ref_mla()
# print("out_flash ", out_flash[0, 0, 0, 0:14])
# print("out_torch ", out_torch[0, 0, 0, 0:14])
# print("lse_flash ", lse_flash[0, 0, 0:10])
# print("lse_torch ", lse_torch[0, 0, 0:10])
print("out max_diff ", (out_flash - out_torch).abs().max())
print("lse max_diff ", (lse_flash - lse_torch).abs().max())
# print(" out ", torch.nonzero((out_flash - out_torch).abs() > 0.1))
# print(" out_torch", out_torch)
cal_diff(lse_flash, lse_torch, "lse")
cal_diff(out_flash, out_torch, "out")
t = triton.testing.do_bench(flash_mla)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = ( b * s_q * h_q * d + b * s_q * h_q * dv) * (
torch.finfo(q.dtype).bits // 8
) + total_seqlens * h_kv * d
print(
f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
)
def main(torch_dtype, is_prof=False):
device = torch.device("cuda:0")
torch.set_default_dtype(torch_dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.manual_seed(0)
random.seed(0)
'''
h_kv = 1
d, dv = 576, 512
causal = True
for b in [128]:
for s in [4096, 8192]:
for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1
for s_q in [1, 2]: # MTP = 1, 2
for varlen in [False, True]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
# b, s_q, s, h_q, h_kv, d, dv, causal, varlen'''
# test_flash_mla( 1, 1, 64, 16, 1, 576, 512, True, False, is_prof=is_prof)
# test_flash_mla_fp8( 1, 1, 1000, 1, 1, 576, 512, True, False, is_prof=is_prof)
# test_flash_mla_fp8( 1, 1, 4096, 8, 1, 576, 512, True, False, is_prof=is_prof)
# test_flash_mla_fp8(32, 1, 4096, 16, 1, 576, 512, False, False, is_prof=is_prof)
# '''
h_kv = 1
d, dv = 576, 512
causal = True
# for b in [1, 32]:
# for s in [200, 1002, 2002, 1024, 2000, 4000, 32768, 65536]:
# for h_q in [4, 16, 32, 64]:
# for s_q in [1, 2]: # MTP = 1, 2
# for varlen in [True]:
# test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
# for b in [32]:
# for s in [16384, 32768, 65536*2]:
# for h_q in [16]:
# for s_q in [1]: # MTP = 1, 2
# for varlen in [False]:
# # for varlen in [True]:
# test_flash_mla_fp8_e5m2(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
# test_flash_mla_fp8_e4m3(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
# '''
for b in [3, 6, 9, 12, 15, 18, 21, 24]:
for s in [111, 112, 123, 1234, 432, 4325, 4000, 8192, 11111]:
for h_q in [16]:
for s_q in [1, 2, 3]: # MTP = 1, 2
for varlen in [False, True]:
test_flash_mla_fp8_e5m2(b, s_q, s, h_q, h_kv, d, dv, causal, varlen,True)
for b in [3, 6, 9, 12, 15, 18, 21, 24, 32, 64, 128, 256]:
for s in [4000]:
for h_q in [16]:
for s_q in [1]: # MTP = 1, 2
for varlen in [False]:
test_flash_mla_fp8_e5m2(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
# for b in [1]:
# for s in [64]:
# for h_q in [16]:
# for s_q in [1]: # MTP = 1, 2
# for varlen in [False]:
# test_flash_mla_fp8_e5m2(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--dtype",
type=str,
choices=["bf16", "fp16"],
default="bf16",
help="Data type to use for testing (bf16 or fp16)",
)
parser.add_argument('--prof', default=False, action='store_true', help='prof or not')
args = parser.parse_args()
torch_dtype = torch.bfloat16
if args.dtype == "fp16":
torch_dtype = torch.float16
main(torch_dtype, args.prof)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment