Commit 26d2ab19 authored by zhanghj2's avatar zhanghj2
Browse files

支持nmz qkvfp8

parent 3eb7071c
#pragma once
#include <cutlass/half.h>
#include <cutlass/fast_math.h>
#include "common.h"
#include "params.h"
// #include "sm90/decode/dense/splitkv_mla.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h"
static std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>>
dense_attn_decode_kvfp8_interface(
at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True)
const int head_size_v,
const at::Tensor &seqlens_k, // batch_size
const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
const float softmax_scale,
bool is_causal,
std::optional<at::Tensor> &tile_scheduler_metadata, // num_sm_parts x (DecodingSchedMetaSize/4)
std::optional<at::Tensor> &num_splits // batch_size + 1
) {
// Check arch
Arch arch = Arch();
if (!arch.is_sm90a()) {
TORCH_CHECK(false, "Dense decode MLA is only supported on SM90a architecture");
}
// Check data types
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf);
TORCH_CHECK(kcache.dtype() == torch::kFloat8_e5m2fn, "key must have torch::kFloat8_e5m2fn");
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
// Check device
KU_CHECK_DEVICE(q);
KU_CHECK_DEVICE(kcache);
KU_CHECK_DEVICE(seqlens_k);
KU_CHECK_DEVICE(block_table);
KU_CHECK_DEVICE(tile_scheduler_metadata);
KU_CHECK_DEVICE(num_splits);
// Check layout
TORCH_CHECK(q.stride(-1) == 1, "q must have contiguous last dimension");
TORCH_CHECK(kcache.stride(-1) == 1, "kcache must have contiguous last dimension");
KU_CHECK_CONTIGUOUS(seqlens_k);
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
KU_CHECK_CONTIGUOUS(num_splits);
const auto sizes = q.sizes();
const int batch_size = sizes[0];
const int seqlen_q_ori = sizes[1];
const int num_heads_q = sizes[2];
const int head_size_k = sizes[3];
TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 or 512 is supported");
TORCH_CHECK(head_size_v == 512, "Only head_size_v == 576 is supported");
const int max_num_blocks_per_seq = block_table.size(1);
const int num_blocks = kcache.size(0);
const int page_block_size = kcache.size(1);
const int num_heads_k = kcache.size(2);
TORCH_CHECK(page_block_size == 64, "Currently page_block_size must be 64");
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (seqlen_q_ori == 1) { is_causal = false; }
const int num_q_heads_per_hk = num_heads_q / num_heads_k;
const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk;
const int num_heads = num_heads_k;
q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3)
.reshape({batch_size, q_seq_per_hk, num_heads, head_size_k});
int num_sm_parts = std::max(arch.num_sms / num_heads_k / cutlass::ceil_div(seqlen_q_ori*num_heads_q/num_heads_k, 16), 1);
KU_CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k);
KU_CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
KU_CHECK_SHAPE(seqlens_k, batch_size);
KU_CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
KU_CHECK_SHAPE(tile_scheduler_metadata, num_sm_parts, DecodingSchedMetaSize/sizeof(int));
KU_CHECK_SHAPE(num_splits, batch_size+1);
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
at::Tensor out = torch::empty({batch_size, num_heads, q_seq_per_hk, head_size_v}, opts);
at::Tensor lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
KU_CHECK_CONTIGUOUS(out);
KU_CHECK_CONTIGUOUS(lse);
if (!tile_scheduler_metadata.has_value()) {
tile_scheduler_metadata = torch::empty({num_sm_parts, sizeof(DecodingSchedMeta)/4}, opts.dtype(torch::kInt32));
num_splits = torch::empty({batch_size+1}, opts.dtype(torch::kInt32));
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
KU_CHECK_CONTIGUOUS(num_splits);
GetDecodeSchedMetaParams get_sched_meta_params = {
batch_size, seqlen_q_ori,
64,
5,
-1, -1,
nullptr, nullptr,
seqlens_k.data_ptr<int>(),
(DecodingSchedMeta*)tile_scheduler_metadata->data_ptr(),
num_splits->data_ptr<int>(),
num_sm_parts,
at::cuda::getCurrentCUDAStream().stream()
};
smxx::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);
} else {
KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32);
KU_CHECK_DTYPE(num_splits, torch::kInt32);
KU_CHECK_DEVICE(tile_scheduler_metadata);
KU_CHECK_DEVICE(num_splits);
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
KU_CHECK_CONTIGUOUS(num_splits);
KU_CHECK_SHAPE(tile_scheduler_metadata, num_sm_parts, sizeof(DecodingSchedMeta)/sizeof(int));
KU_CHECK_SHAPE(num_splits, batch_size+1);
}
// Set the sizes
DenseAttnDecodeParams params;
params.b = batch_size;
params.s_q = seqlen_q_ori;
params.q_seq_per_hk = q_seq_per_hk;
params.seqlens_k_ptr = seqlens_k.data_ptr<int>();
params.h_q = num_heads_q;
params.h_k = num_heads_k;
params.num_blocks = num_blocks;
params.q_head_per_hk = num_q_heads_per_hk;
params.is_causal = is_causal;
params.d = head_size_k;
params.d_v = head_size_v;
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
// Set the pointers and strides.
params.q_ptr = q.data_ptr();
params.k_ptr = kcache.data_ptr();
params.o_ptr = out.data_ptr();
params.softmax_lse_ptr = lse.data_ptr<float>();
// All stride are in elements, not bytes.
params.q_batch_stride = q.stride(0);
params.k_batch_stride = kcache.stride(0);
params.o_batch_stride = out.stride(0);
params.q_row_stride = q.stride(1);
params.k_row_stride = kcache.stride(1);
params.o_row_stride = out.stride(2);
params.q_head_stride = q.stride(2);
params.k_head_stride = kcache.stride(2);
params.o_head_stride = out.stride(1);
params.block_table = block_table.data_ptr<int>();
params.block_table_batch_stride = block_table.stride(0);
params.page_block_size = page_block_size;
params.tile_scheduler_metadata_ptr = (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr();
params.num_sm_parts = num_sm_parts;
params.num_splits_ptr = num_splits->data_ptr<int>();
const int total_num_splits = batch_size + params.num_sm_parts;
at::Tensor lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat));
KU_CHECK_CONTIGUOUS(lse_accum);
KU_CHECK_CONTIGUOUS(out_accum);
params.total_num_splits = total_num_splits;
params.softmax_lseaccum_ptr = lse_accum.data_ptr<float>();
params.oaccum_ptr = out_accum.data_ptr<float>();
params.stream = at::cuda::getCurrentCUDAStream().stream();
if (q_dtype == torch::kBFloat16) {
sm90::run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(params);
} else {
TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90");
}
CombineParams combine_params = {
batch_size, seqlen_q_ori,
num_heads_q, head_size_v,
params.softmax_lse_ptr,
params.o_ptr,
num_heads*q_seq_per_hk, num_heads_q,
num_heads_q*seqlen_q_ori*head_size_v, num_heads_q*head_size_v, head_size_v,
params.softmax_lseaccum_ptr,
params.oaccum_ptr,
num_heads*q_seq_per_hk, num_heads_q,
num_heads_q*seqlen_q_ori*head_size_v, num_heads_q*head_size_v, head_size_v,
params.tile_scheduler_metadata_ptr,
params.num_splits_ptr,
params.num_sm_parts,
nullptr,
at::cuda::getCurrentCUDAStream().stream()
};
if (q_dtype == torch::kBFloat16) {
smxx::decode::run_flash_mla_combine_kernel<cutlass::bfloat16_t>(combine_params);
} else if (q_dtype == torch::kHalf) {
#ifndef FLASH_MLA_DISABLE_FP16
smxx::decode::run_flash_mla_combine_kernel<cutlass::half_t>(combine_params);
#endif
} else {
TORCH_CHECK(false, "Unsupported tensor dtype for query");
}
out = out.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk, head_size_v}).transpose(1, 2)
.reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v});
lse = lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3)
.reshape({batch_size, num_heads_q, seqlen_q_ori});
return {out, lse, tile_scheduler_metadata, num_splits};
}
......@@ -20,7 +20,9 @@ dense_attn_decode_qkvfp8_interface(
const float softmax_scale,
bool is_causal,
std::optional<at::Tensor> &tile_scheduler_metadata, // num_sm_parts x (DecodingSchedMetaSize/4)
std::optional<at::Tensor> &num_splits // batch_size + 1
std::optional<at::Tensor> &num_splits, // batch_size + 1
std::optional<const at::Tensor> &descale_q,
std::optional<const at::Tensor> &descale_k
) {
// Check arch
Arch arch = Arch();
......@@ -30,7 +32,10 @@ dense_attn_decode_qkvfp8_interface(
// Check data types
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf);
TORCH_CHECK(q_dtype == torch::kFloat8_e4m3fn);
TORCH_CHECK(descale_q.has_value() && descale_k.has_value(), "descale is required when input dtype is fp8");
auto descale_q_ = descale_q.value();
auto descale_k_ = descale_k.value();
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
......@@ -43,6 +48,14 @@ dense_attn_decode_qkvfp8_interface(
KU_CHECK_DEVICE(block_table);
KU_CHECK_DEVICE(tile_scheduler_metadata);
KU_CHECK_DEVICE(num_splits);
KU_CHECK_DEVICE(descale_q_);
KU_CHECK_DEVICE(descale_k_);
TORCH_CHECK(descale_q_.stride(-1) == 1);
TORCH_CHECK(descale_k_.stride(-1) == 1);
TORCH_CHECK(descale_q_.dtype() == torch::kFloat);
TORCH_CHECK(descale_k_.dtype() == torch::kFloat);
KU_CHECK_SHAPE(descale_q_, 1);
KU_CHECK_SHAPE(descale_k_, 1);
// Check layout
TORCH_CHECK(q.stride(-1) == 1, "q must have contiguous last dimension");
......@@ -75,7 +88,7 @@ dense_attn_decode_qkvfp8_interface(
const int num_heads = num_heads_k;
q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3)
.reshape({batch_size, q_seq_per_hk, num_heads, head_size_k});
int num_sm_parts = std::max(arch.num_sms / num_heads_k / cutlass::ceil_div(seqlen_q_ori*num_heads_q/num_heads_k, 16), 1);
int num_sm_parts = std::max(arch.num_sms / num_heads_k / cutlass::ceil_div(seqlen_q_ori*num_heads_q/num_heads_k, 16) * 2, 1);
KU_CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k);
KU_CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
......@@ -87,7 +100,7 @@ dense_attn_decode_qkvfp8_interface(
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
at::Tensor out = torch::empty({batch_size, num_heads, q_seq_per_hk, head_size_v}, opts);
at::Tensor out = torch::empty({batch_size, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(torch::kBFloat16));
at::Tensor lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
KU_CHECK_CONTIGUOUS(out);
KU_CHECK_CONTIGUOUS(lse);
......@@ -123,7 +136,7 @@ dense_attn_decode_qkvfp8_interface(
}
// Set the sizes
DenseAttnDecodeParams params;
DenseAttnDecodeParams_fp8 params;
params.b = batch_size;
params.s_q = seqlen_q_ori;
params.q_seq_per_hk = q_seq_per_hk;
......@@ -161,6 +174,8 @@ dense_attn_decode_qkvfp8_interface(
params.num_sm_parts = num_sm_parts;
params.num_splits_ptr = num_splits->data_ptr<int>();
params.descale_q_ptr = descale_q_.data_ptr<float>();
params.descale_k_ptr = descale_k_.data_ptr<float>();;
const int total_num_splits = batch_size + params.num_sm_parts;
at::Tensor lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat));
......@@ -200,11 +215,8 @@ dense_attn_decode_qkvfp8_interface(
at::cuda::getCurrentCUDAStream().stream()
};
if (q_dtype == torch::kBFloat16) {
smxx::decode::run_flash_mla_combine_kernel<cutlass::bfloat16_t>(combine_params);
} else {
TORCH_CHECK(false, "Unsupported tensor dtype for query");
}
out = out.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk, head_size_v}).transpose(1, 2)
.reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v});
......
......@@ -60,6 +60,11 @@ struct DenseAttnDecodeParams { // TODO Change name to DenseAttnDecodeParams
cudaStream_t stream;
};
struct DenseAttnDecodeParams_fp8 : public DenseAttnDecodeParams {
float* __restrict__ descale_q_ptr = nullptr;
float* __restrict__ descale_k_ptr = nullptr;
};
struct SparseAttnDecodeParams {
int b, s_q;
int h_q, h_kv;
......
......@@ -12,7 +12,7 @@ namespace sm90 {
template<typename T>
__device__ void
compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams params,
compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
const int bidb, const int bidh, const int m_block,
const int n_split_idx, const int seqlen_k,
const int n_block_min, const int n_block_max, const bool NoSplit)
......
......@@ -3,6 +3,6 @@
namespace sm90 {
template void run_flash_splitkv_mla_qkvfp8_kernel<cutlass::float_e4m3_t>(DenseAttnDecodeParams &params);
template void run_flash_splitkv_mla_qkvfp8_kernel<cutlass::float_e4m3_t>(DenseAttnDecodeParams_fp8 &params);
}
......@@ -5,6 +5,6 @@
namespace sm90 {
template<typename InputT>
void run_flash_splitkv_mla_qkvfp8_kernel(DenseAttnDecodeParams &params);
void run_flash_splitkv_mla_qkvfp8_kernel(DenseAttnDecodeParams_fp8 &params);
}
......@@ -31,6 +31,13 @@ struct Traits {
using SmemLayoutRow = Layout<Shape<_128>, Stride<_1>>;
static constexpr int kSwizzle = 3;
using SmemLayoutAtomQ =
Layout<Shape<Int<16>, Int<64>>, Stride<Int<64>, _1>>;
using SmemLayoutQ = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemLayoutAtomK = decltype(composition(
Swizzle<kSwizzle, 4, 3>{},
Layout<Shape<Int<8>, Int<64>>, Stride<Int<64>, _1>>{}));
......@@ -39,8 +46,57 @@ struct Traits {
SmemLayoutAtomK{},
Shape<Int<kBlockN>, Int<8 * 64>>{}));
using SmemLayoutAtomV = SmemLayoutAtomK;
using SmemLayoutV = decltype(tile_to_shape(
SmemLayoutAtomV{},
Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
using SmemLayoutAtomP = Layout<Shape<Int<4*16*16>>, Stride<Int<1>>>;
using SmemLayoutP = decltype(tile_to_shape(
SmemLayoutAtomP{},
Shape<Int<4*16*16>>{}));
using SmemLayoutVtransposed = decltype(
composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
using SmemLayoutAtomK_place_holder = Layout<Shape<Int<kBlockN>, Int<64>>, Stride<_64, _1>>;
using SmemLayoutK_place_holder = decltype(tile_to_shape(
SmemLayoutAtomK_place_holder{},
Shape<Int<kBlockN>, Int<7*64>>{}));
using MMA_Atom_Arch = MMA_Atom<GFX938_16x16x64_F32F8F8F32E4M3E4M3_NN_LIT>;
using MMA_Atom_Arch_16x32 = MMA_Atom<GFX938_16x32x32_F32F8F8F32E4M3E4M3_NT_LIT>;
using ValLayoutMNK = Layout<Shape<_1, _1, _1>>;
using TiledMma = TiledMMA<
MMA_Atom_Arch,
Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
ValLayoutMNK>;//
using TiledMma_O = TiledMMA<
MMA_Atom_Arch_16x32,
Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
ValLayoutMNK>;
struct SharedMemoryPlan {
union {
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v; // Double buffer
};
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK_place_holder>> smem_temp; // Double buffer
cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_sum;
cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_max;
};
struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
};
};
};
......
......@@ -735,6 +735,213 @@ __forceinline__ __device__ auto convert_layout_acc_Aregs_dense(const TiledMma& t
return tSrACC;
}
template <
bool Is_even_MN=true,
bool Is_even_K=true,
bool Is_load_Q=false,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy_qkvfp8(
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst,
int k_idx_, const int row_stride,
const int max_MN=0)
{
if constexpr (Is_load_Q) {
constexpr int warp_size = 64;
int tidx = threadIdx.x;
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 1;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
const int offset_s = 0;
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
uint32x4_t global_addr = {0};
global_addr[0] = (glob_ptr.former);
global_addr[1] = (glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 16;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
int mma_k = 16*256;
int row = lane % 16;
int col = lane / 16;
int row_offset = row ;
int col_offset = (col + warp_id * 4) * elements_per_thread + k_idx * 256;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
if (!Is_even_K && col_offset >= 576) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
} else {
constexpr int warp_size = 64;
int tidx = threadIdx.x;//0-256
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;//0-63
constexpr int element_size = 1;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);//576
const int offset_s = 0;
// global addr
// uint32x4_t global_addr = {0};
// *(uint64_t*)&global_addr = reinterpret_cast<uint64_t>(src.data().get());
// global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
// global_addr[2] = 0xfffffffe;
// global_addr[3] = 0x00020000;
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 16;
constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;//64*16*1
int mma_k = 64*64;
// int row = lane / 4;
// int col = lane % 4;
// int swizzle_col = ((row / 2) ^ (col )) * 4 + (col % 4);
// 此处待优化,后8行,行号需要交换
int virtual_row = lane / 8;//0
int virtual_col = lane % 8;//0
int swizzle_col = virtual_row ^ virtual_col;
int row = lane / 4;//0
// 8->9 9->8
row = (row >= 8 ) ^ row;
// row = row >= 8 ? (swizzle_col / 4) > 0 ? row + 1 : row - 1 : row;
int col = swizzle_col % 4;
int row_offset = row + (warp_id * 16) ;
int col_offset = col * elements_per_thread + k_idx * 64;
int offset_v = row_offset * row_stride + (col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
//int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + (k_idx % 2) * mma_k * element_size;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + (k_idx) * mma_k * element_size;
#if defined(__gfx938__)
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
#endif
}
}
template <
bool Is_even_MN=true,
bool Is_even_K=true,
bool mma_layout = false,
bool use_asm = false,
class SrcEngine, class SrcLayout
>
CUTE_HOST_DEVICE
void
buffer_load_copy_qkvfp8(
Tensor<SrcEngine, SrcLayout> const& src,
uint128_t & dst,
int k_idx_, const int row_stride,
int offset_k,
const int max_MN=0)
{
constexpr int warp_size = 64;
int tidx = threadIdx.x;
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 1;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
constexpr int elements_per_thread = 16;
if constexpr (mma_layout)
{
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
int mma_k = 32*64;
int row = tidx % 16;
int col = lane / 16;
int row_offset = row + (warp_id * 16) ;
int col_offset = col * elements_per_thread + k_idx * 64;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
if constexpr(use_asm) {
asm volatile(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
" \n\t" :"=v"(dst),
"+v"(offset_v), "+s"(global_addr)
);
}
else {
auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
dst = *reinterpret_cast<uint128_t*>(&res);
}
}
}
template<int row, int col, int r_row, typename Tensor0>
__forceinline__ __device__ void __ds_read_m32x32_row_col_rrow(Tensor0& src, intx4_t& dst)
{
auto lds = reinterpret_cast<int *>(src.data().get());
auto layout = src.layout();
constexpr short offset = layout(0, row, col) * 1;
auto d = __builtin_amdgcn_ds_read_m32x32u8((__attribute__((address_space(3))) int*)(lds), offset);
dst = d;
}
}
\ No newline at end of file
......@@ -3,11 +3,13 @@ __version__ = "1.0.0"
from flash_mla.flash_mla_interface import (
get_mla_metadata,
flash_mla_with_kvcache,
flash_mla_sparse_fwd
flash_mla_sparse_fwd,
flash_mla_with_kvcache_qkvfp8
)
__all__ = [
"get_mla_metadata",
"flash_mla_with_kvcache",
"flash_mla_sparse_fwd"
"flash_mla_sparse_fwd",
"flash_mla_with_kvcache_qkvfp8"
]
......@@ -211,3 +211,82 @@ def flash_mla_sparse_fwd(
return results
def flash_mla_with_kvcache_qkvfp8(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: Optional[torch.Tensor],
cache_seqlens: Optional[torch.Tensor],
head_dim_v: int,
tile_scheduler_metadata: FlashMLASchedMeta,
num_splits: None = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details.
The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks.
block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
head_dim_v: Head_dim of v. Must be 512
sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
num_splits_placeholder: must be "None" (to be compatible with the old interface).
softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
causal: bool. Whether to apply causal attention mask. Only valid for dense attention
descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
Return:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v), only support bf16 output
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
sched_meta = tile_scheduler_metadata
assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
assert num_splits is None, "num_splits must be None"
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if not sched_meta.have_initialized:
# Initialize the tile scheduler metadata during the first invocation.
sched_meta.have_initialized = True
sched_meta.config = FlashMLASchedMeta.Config(
q.shape[0],
q.shape[1],
q.shape[2],
k_cache.shape[1],
k_cache.shape[2],
causal,
False,
0,
0,
0
)
else:
# Check whether the input arguments are consistent with sched_meta
helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
assert sched_meta.config is not None
assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
# Dense attention
assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd_qkvfp8(
q, k_cache, head_dim_v,
cache_seqlens, block_table,
softmax_scale, causal,
sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
descale_q, descale_k
)
sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
sched_meta.num_splits = new_num_splits
return (out, lse)
\ No newline at end of file
import argparse
import math
import random
import torch
import triton
from flash_mla import flash_mla_with_kvcache_qkvfp8, get_mla_metadata
torch.set_printoptions(precision=4, profile="default", sci_mode=False)
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False, k_scale=1.0):
query = query.float()
key = key.float() * k_scale
value = value.float() * k_scale
key = key.repeat_interleave(h_q // h_kv, dim=0)
value = value.repeat_interleave(h_q // h_kv, dim=0)
tmp = query @ key.transpose(-2, -1)
# print("tmp s ", tmp[0, :4, :10])
attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
if is_causal:
s_q = query.shape[-2]
s_k = key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
attn_weight += attn_bias
lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
return attn_weight @ value, lse
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) -> None:
torch_dtype = x.dtype
x, y = x.double(), y.double()
RMSE = ((x - y) * (x - y)).mean().sqrt().item()
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
amax_diff = (x - y).abs().max().item()
print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
if use_fp8:
assert cos_diff < 1e-3
else:
assert cos_diff < (1e-4 if torch_dtype == torch.bfloat16 else 1e-5)
@torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=False,torch_dtype=torch.float16):
print(
f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}"
)
use_fp8 = torch_dtype == torch.float8_e4m3fn
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
if varlen:
for i in range(b):
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
total_seqlens = cache_seqlens.sum().item()
mean_seqlens = cache_seqlens.float().mean().int().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}, {max_seqlen_pad=}")
q = torch.ones(b, s_q, h_q, d)
q = torch.randn(b, s_q, h_q, d)
# for i in range(576):
# q[:, :, :, i] = i
# q[:, :, 1:, :] = 0
# q = torch.ones(b, s_q, h_q, d)
# print("q ", q[0, 0, 0:3, :10])
block_size = 64
block_table = torch.arange(
b * max_seqlen_pad // block_size, dtype=torch.int32
).view(b, max_seqlen_pad // block_size)
blocked_k = torch.ones(block_table.numel(), block_size, h_kv, d)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
# blocked_k[:, :, :, 32:] = 0.0
# blocked_k[:, 32:, :, :] = 0
# blocked_k[:, :, :, 4:] = 0
# blocked_k[:, :32, :, :] = 0
# blocked_k[:, 16:, :, :] = 0
for i in range(b):
blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = (
float("nan")
)
blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_metadata()
init_dtype = q.dtype
def prepare_fp8_input():
q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = None, None, None, None, None
if use_fp8:
nonlocal q, blocked_k, blocked_v
fp8_dtype = torch.float8_e4m3fn
descale_q = torch.ones((1), dtype=torch.float32)
descale_k = torch.ones((1), dtype=torch.float32)
q_fp8 = q.to(fp8_dtype)
blocked_k_fp8 = blocked_k.to(fp8_dtype)
blocked_v_fp8 = blocked_k_fp8[..., :dv]
return q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k
q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = prepare_fp8_input()
# print(blocked_v_fp8[0, 32:36, 0, :4])
if use_fp8:
q = q_fp8
blocked_k = blocked_k_fp8
blocked_v = blocked_v_fp8
# print(" descale_q ", descale_q.shape, descale_q.stride())
# print(" blocked_k ", blocked_k.shape)
def flash_mla():
return flash_mla_with_kvcache_qkvfp8(
q,
blocked_k,
block_table,
cache_seqlens,
dv,
tile_scheduler_metadata,
num_splits,
causal=causal,
descale_q=descale_q,
descale_k=descale_k,
)
def ref_mla():
q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
blocked_k_ = (blocked_k.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_k
blocked_v_ = (blocked_v.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_v
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
begin = i * max_seqlen_pad
end = begin + cache_seqlens[i]
O, LSE = scaled_dot_product_attention(
q_[i].transpose(0, 1),
blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
h_q=h_q,
h_kv=h_kv,
is_causal=causal,
)
out[i] = O.transpose(0, 1)
lse[i] = LSE
return out, lse
out_flash, lse_flash = flash_mla()
out_torch, lse_torch = ref_mla()
# print(" ", out_flash.shape, lse_flash.shape, q.shape)
# print("out max_diff ", (out_flash - out_torch).abs().max())
# print("lse max_diff ", (lse_flash - lse_torch).abs().max())
# print(" diff ", torch.nonzero((lse_flash - lse_torch).abs() > 0.1))
# print(" diff ", torch.nonzero((out_flash - out_torch).abs() > 0.1))
# print(" nan ", torch.nonzero(torch.isnan(out_flash)))
cal_diff(out_flash, out_torch, "out", use_fp8)
cal_diff(lse_flash, lse_torch, "lse")
if is_prof: return
t = triton.testing.do_bench(flash_mla)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
print(
f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
)
def main(torch_dtype, is_prof=False):
device = torch.device("cuda:0")
init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype
torch.set_default_dtype(init_dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.manual_seed(0)
random.seed(0)
'''
h_kv = 1
d, dv = 576, 512
causal = True
for b in [128]:
for s in [4096, 8192]:
for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1
for s_q in [1, 2]: # MTP = 1, 2
for varlen in [False, True]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
# b, s_q, s, h_q, h_kv, d, dv, causal, varlen'''
# test_flash_mla( 1, 1, 64, 16, 1, 576, 512, True, False, is_prof=is_prof)
# test_flash_mla_fp8( 1, 1, 1000, 1, 1, 576, 512, True, False, is_prof=is_prof)
# test_flash_mla_fp8( 1, 1, 4096, 8, 1, 576, 512, True, False, is_prof=is_prof)
# test_flash_mla_fp8(32, 1, 4096, 16, 1, 576, 512, False, False, is_prof=is_prof)
# '''
h_kv = 1
d, dv = 576, 512
causal = True
# for b in [40, 80]:
# for s in [3500, 4000, 8192, 16384]:
# for h_q in [16]:
# for s_q in [1]: # MTP = 1, 2
# for varlen in [False]:
# test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen,False,torch_dtype)
# 压测
for b in [3, 6, 9, 12, 15, 18, 21, 24, 40, 41, 79, 80]:
for s in [111, 112, 123, 1234, 432, 4325, 4000, 8192, 12345, 45321]:
for h_q in [16]:
for s_q in [1, 2, 3]: # MTP = 1, 2
for varlen in [False, True]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen,True,torch_dtype)
for b in [3, 6, 9, 12, 15, 18, 21, 24, 32, 64, 128, 256]:
for s in [4000]:
for h_q in [16]:
for s_q in [1]: # MTP = 1, 2
for varlen in [False]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen,False,torch_dtype)
# for b in [1]:
# for s in [128]:
# for h_q in [128]:
# for s_q in [2]: # MTP = 1, 2
# for varlen in [False]:
# test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen,False,torch_dtype)
# for b in [1, 32]:
# for s in [200, 1002, 2002, 1024, 2000, 4000, 32768, 65536]:
# for h_q in [4, 16, 32, 64]:
# for s_q in [1, 2]: # MTP = 1, 2
# for varlen in [False]:
# test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen,False,torch_dtype)
# '''
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--dtype",
type=str,
choices=["bf16", "fp16","e4m3"],
default="bf16",
help="Data type to use for testing (bf16/fp16/e4m3)",
)
parser.add_argument('--prof', default=False, action='store_true', help='prof or not')
args = parser.parse_args()
torch_dtype = torch.float8_e4m3fn
if args.dtype == "fp16":
torch_dtype = torch.float16
elif args.dtype == "e4m3":
torch_dtype = torch.float8_e4m3fn
main(torch_dtype, args.prof)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment