Commit 4976cbaa authored by zhanghj2's avatar zhanghj2
Browse files

搭建支持qkvfp8的框架

parent 06612d65
...@@ -3,11 +3,13 @@ ...@@ -3,11 +3,13 @@
#include "sparse_fwd.h" #include "sparse_fwd.h"
#include "sparse_decode.h" #include "sparse_decode.h"
#include "dense_decode.h" #include "dense_decode.h"
#include "dense_decode_qkvfp8.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashMLA"; m.doc() = "FlashMLA";
m.def("sparse_decode_fwd", &sparse_attn_decode_interface); m.def("sparse_decode_fwd", &sparse_attn_decode_interface);
m.def("dense_decode_fwd", &dense_attn_decode_interface); m.def("dense_decode_fwd", &dense_attn_decode_interface);
m.def("dense_decode_fwd_qkvfp8", &dense_attn_decode_qkvfp8_interface);
m.def("sparse_prefill_fwd", &sparse_attn_prefill_interface); m.def("sparse_prefill_fwd", &sparse_attn_prefill_interface);
} }
#pragma once
#include <cutlass/half.h>
#include <cutlass/fast_math.h>
#include "common.h"
#include "params.h"
#include "sm90/decode/dense_qkvfp8/splitkv_mla.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h"
static std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>>
dense_attn_decode_qkvfp8_interface(
at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True)
const int head_size_v,
const at::Tensor &seqlens_k, // batch_size
const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
const float softmax_scale,
bool is_causal,
std::optional<at::Tensor> &tile_scheduler_metadata, // num_sm_parts x (DecodingSchedMetaSize/4)
std::optional<at::Tensor> &num_splits // batch_size + 1
) {
// Check arch
Arch arch = Arch();
if (!arch.is_sm90a()) {
TORCH_CHECK(false, "Dense decode MLA is only supported on SM90a architecture");
}
// Check data types
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf);
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
// Check device
KU_CHECK_DEVICE(q);
KU_CHECK_DEVICE(kcache);
KU_CHECK_DEVICE(seqlens_k);
KU_CHECK_DEVICE(block_table);
KU_CHECK_DEVICE(tile_scheduler_metadata);
KU_CHECK_DEVICE(num_splits);
// Check layout
TORCH_CHECK(q.stride(-1) == 1, "q must have contiguous last dimension");
TORCH_CHECK(kcache.stride(-1) == 1, "kcache must have contiguous last dimension");
KU_CHECK_CONTIGUOUS(seqlens_k);
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
KU_CHECK_CONTIGUOUS(num_splits);
const auto sizes = q.sizes();
const int batch_size = sizes[0];
const int seqlen_q_ori = sizes[1];
const int num_heads_q = sizes[2];
const int head_size_k = sizes[3];
TORCH_CHECK(head_size_k == 576 || head_size_k == 512, "Only head_size_k == 576 or 512 is supported");
TORCH_CHECK(head_size_v == 512, "Only head_size_v == 576 is supported");
const int max_num_blocks_per_seq = block_table.size(1);
const int num_blocks = kcache.size(0);
const int page_block_size = kcache.size(1);
const int num_heads_k = kcache.size(2);
TORCH_CHECK(page_block_size == 64, "Currently page_block_size must be 64");
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (seqlen_q_ori == 1) { is_causal = false; }
const int num_q_heads_per_hk = num_heads_q / num_heads_k;
const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk;
const int num_heads = num_heads_k;
q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3)
.reshape({batch_size, q_seq_per_hk, num_heads, head_size_k});
int num_sm_parts = std::max(arch.num_sms / num_heads_k / cutlass::ceil_div(seqlen_q_ori*num_heads_q/num_heads_k, 16), 1);
KU_CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k);
KU_CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
KU_CHECK_SHAPE(seqlens_k, batch_size);
KU_CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
KU_CHECK_SHAPE(tile_scheduler_metadata, num_sm_parts, DecodingSchedMetaSize/sizeof(int));
KU_CHECK_SHAPE(num_splits, batch_size+1);
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
at::Tensor out = torch::empty({batch_size, num_heads, q_seq_per_hk, head_size_v}, opts);
at::Tensor lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
KU_CHECK_CONTIGUOUS(out);
KU_CHECK_CONTIGUOUS(lse);
if (!tile_scheduler_metadata.has_value()) {
tile_scheduler_metadata = torch::empty({num_sm_parts, sizeof(DecodingSchedMeta)/4}, opts.dtype(torch::kInt32));
num_splits = torch::empty({batch_size+1}, opts.dtype(torch::kInt32));
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
KU_CHECK_CONTIGUOUS(num_splits);
GetDecodeSchedMetaParams get_sched_meta_params = {
batch_size, seqlen_q_ori,
64,
5,
-1, -1,
nullptr, nullptr,
seqlens_k.data_ptr<int>(),
(DecodingSchedMeta*)tile_scheduler_metadata->data_ptr(),
num_splits->data_ptr<int>(),
num_sm_parts,
at::cuda::getCurrentCUDAStream().stream()
};
smxx::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);
} else {
KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32);
KU_CHECK_DTYPE(num_splits, torch::kInt32);
KU_CHECK_DEVICE(tile_scheduler_metadata);
KU_CHECK_DEVICE(num_splits);
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
KU_CHECK_CONTIGUOUS(num_splits);
KU_CHECK_SHAPE(tile_scheduler_metadata, num_sm_parts, sizeof(DecodingSchedMeta)/sizeof(int));
KU_CHECK_SHAPE(num_splits, batch_size+1);
}
// Set the sizes
DenseAttnDecodeParams params;
params.b = batch_size;
params.s_q = seqlen_q_ori;
params.q_seq_per_hk = q_seq_per_hk;
params.seqlens_k_ptr = seqlens_k.data_ptr<int>();
params.h_q = num_heads_q;
params.h_k = num_heads_k;
params.num_blocks = num_blocks;
params.q_head_per_hk = num_q_heads_per_hk;
params.is_causal = is_causal;
params.d = head_size_k;
params.d_v = head_size_v;
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
// Set the pointers and strides.
params.q_ptr = q.data_ptr();
params.k_ptr = kcache.data_ptr();
params.o_ptr = out.data_ptr();
params.softmax_lse_ptr = lse.data_ptr<float>();
// All stride are in elements, not bytes.
params.q_batch_stride = q.stride(0);
params.k_batch_stride = kcache.stride(0);
params.o_batch_stride = out.stride(0);
params.q_row_stride = q.stride(1);
params.k_row_stride = kcache.stride(1);
params.o_row_stride = out.stride(2);
params.q_head_stride = q.stride(2);
params.k_head_stride = kcache.stride(2);
params.o_head_stride = out.stride(1);
params.block_table = block_table.data_ptr<int>();
params.block_table_batch_stride = block_table.stride(0);
params.page_block_size = page_block_size;
params.tile_scheduler_metadata_ptr = (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr();
params.num_sm_parts = num_sm_parts;
params.num_splits_ptr = num_splits->data_ptr<int>();
const int total_num_splits = batch_size + params.num_sm_parts;
at::Tensor lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat));
KU_CHECK_CONTIGUOUS(lse_accum);
KU_CHECK_CONTIGUOUS(out_accum);
params.total_num_splits = total_num_splits;
params.softmax_lseaccum_ptr = lse_accum.data_ptr<float>();
params.oaccum_ptr = out_accum.data_ptr<float>();
params.stream = at::cuda::getCurrentCUDAStream().stream();
if (q_dtype == torch::kFloat8_e4m3fn) {
sm90::run_flash_splitkv_mla_qkvfp8_kernel<cutlass::float_e4m3_t>(params);
} else {
TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90");
}
CombineParams combine_params = {
batch_size, seqlen_q_ori,
num_heads_q, head_size_v,
params.softmax_lse_ptr,
params.o_ptr,
num_heads*q_seq_per_hk, num_heads_q,
num_heads_q*seqlen_q_ori*head_size_v, num_heads_q*head_size_v, head_size_v,
params.softmax_lseaccum_ptr,
params.oaccum_ptr,
num_heads*q_seq_per_hk, num_heads_q,
num_heads_q*seqlen_q_ori*head_size_v, num_heads_q*head_size_v, head_size_v,
params.tile_scheduler_metadata_ptr,
params.num_splits_ptr,
params.num_sm_parts,
nullptr,
at::cuda::getCurrentCUDAStream().stream()
};
if (q_dtype == torch::kBFloat16) {
smxx::decode::run_flash_mla_combine_kernel<cutlass::bfloat16_t>(combine_params);
} else {
TORCH_CHECK(false, "Unsupported tensor dtype for query");
}
out = out.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk, head_size_v}).transpose(1, 2)
.reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v});
lse = lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3)
.reshape({batch_size, num_heads_q, seqlen_q_ori});
return {out, lse, tile_scheduler_metadata, num_splits};
}
#pragma once
namespace Config {
static constexpr int BLOCK_SIZE_M = 16;
static constexpr int PAGE_BLOCK_SIZE = 64;
static constexpr int HEAD_DIM_K = 576;
static constexpr int HEAD_DIM_V = 512;
}
#include "../splitkv_mla.cuh"
#include "../splitkv_mla.h"
namespace sm90 {
template void run_flash_splitkv_mla_qkvfp8_kernel<cutlass::float_e4m3_t>(DenseAttnDecodeParams &params);
}
#include <cutlass/cutlass.h>
#include "utils.h"
#include "params.h"
#include "config.h"
#include "traits.h"
#include "softmax.h"
using namespace cute;
namespace sm90 {
template<typename T>
__device__ void
compute_attn_1rowblock_splitkv_mla_qkvfp8_gfx938(const DenseAttnDecodeParams params,
const int bidb, const int bidh, const int m_block,
const int n_split_idx, const int seqlen_k,
const int n_block_min, const int n_block_max, const bool NoSplit)
{
constexpr int kBlockM = T::kBlockM;
constexpr int kBlockN = T::kBlockN;
constexpr int kHeadDim = T::kHeadDim;
constexpr int kHeadDimV = T::kHeadDimV;
const int tidx = threadIdx.x;
}
template<typename T>
__global__ void __launch_bounds__(T::NUM_THREADS, 1)
flash_fwd_splitkv_mla_qkvfp8_kernel(const DenseAttnDecodeParams params) {
const int m_block = blockIdx.x;
const int bidh = blockIdx.y;
const int partition_idx = blockIdx.z;
DecodingSchedMeta sched_meta = params.tile_scheduler_metadata_ptr[partition_idx];
if (sched_meta.begin_req_idx >= params.b) return;
for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) {
constexpr int kBlockN = T::PAGE_BLOCK_SIZE;
const int n_split_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_split_idx : 0;
int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx);
const int start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0;
int end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : cute::ceil_div(seqlen_k, kBlockN);
const bool is_no_split = batch_idx == sched_meta.begin_req_idx ? !sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? !sched_meta.is_last_req_splitted : true);
if (batch_idx > sched_meta.begin_req_idx) {
__syncthreads();
}
compute_attn_1rowblock_splitkv_mla_qkvfp8_gfx938<T>(params, batch_idx, bidh, m_block, n_split_idx,
seqlen_k, start_block_idx, end_block_idx, is_no_split
);
}
}
template<typename InputT>
void run_flash_splitkv_mla_qkvfp8_kernel(DenseAttnDecodeParams &params) {
FLASH_ASSERT(params.d == Config::HEAD_DIM_K);
FLASH_ASSERT(params.d_v == Config::HEAD_DIM_V);
constexpr size_t smem_size = 65536;
// Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
using T = Traits<InputT, Is_causal>;
const int num_m_block = cute::ceil_div(params.q_seq_per_hk, T::BLOCK_SIZE_M);
auto mla_kernel = &flash_fwd_splitkv_mla_qkvfp8_kernel<T>;
mla_kernel<<<dim3(num_m_block, params.h_k, params.num_sm_parts), T::NUM_THREADS, smem_size, params.stream>>>(params);
});
// cudaLaunchConfig_t mla_kernel_config = {
// dim3(num_m_block, params.h_k, params.num_sm_parts),
// dim3(T::NUM_THREADS, 1, 1),
// smem_size,
// params.stream,
// mla_kernel_attributes,
// 1
// };
// cudaLaunchKernelEx(&mla_kernel_config, mla_kernel, params, tma_params);
CHECK_CUDA_KERNEL_LAUNCH();
}
}
#pragma once
#include "params.h"
namespace sm90 {
template<typename InputT>
void run_flash_splitkv_mla_qkvfp8_kernel(DenseAttnDecodeParams &params);
}
#pragma once
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_types.h>
#include <cutlass/barrier.h>
#include "config.h"
using namespace cute;
template<typename InputT_, bool Is_causal_>
struct Traits {
using InputT = InputT_;
static constexpr bool Is_causal = Is_causal_;
static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M;
static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE;
static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K;
static constexpr int HEAD_DIM_V = Config::HEAD_DIM_V;
static constexpr int NUM_THREADS = 256;
static_assert(std::is_same_v<InputT, cutlass::float_e4m3_t>);
static constexpr int kBlockM = BLOCK_SIZE_M;
static constexpr int kBlockN = PAGE_BLOCK_SIZE;
static constexpr int kHeadDim = HEAD_DIM_K;
static constexpr int kHeadDimV = HEAD_DIM_V;
static constexpr int kNWarps = 4;
using Element = InputT;
using elem_type = Element;
using ElementAccum = float;
using SmemLayoutRow = Layout<Shape<_128>, Stride<_1>>;
static constexpr int kSwizzle = 3;
using SmemLayoutAtomK = decltype(composition(
Swizzle<kSwizzle, 4, 3>{},
Layout<Shape<Int<8>, Int<64>>, Stride<Int<64>, _1>>{}));
using SmemLayoutK = decltype(tile_to_shape(
SmemLayoutAtomK{},
Shape<Int<kBlockN>, Int<8 * 64>>{}));
struct SharedMemoryPlan {
};
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment