// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp /****************************************************************************** * Copyright (c) 2024, Tri Dao. ******************************************************************************/ #pragma once // #include #include #include #include #include #include #include "flash_mla.h" #include "static_switch.h" #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") static const bool print_param = get_env_("FLASH_MLA_PRINT_PARAM"); std::string static execCommand(const char* cmd) { std::string result; FILE* pipe = popen(cmd, "r"); // 打开管道,只读方式 if (!pipe) { return "popen failed"; } char buffer[256]; while (fgets(buffer, sizeof(buffer), pipe) != nullptr) { result += buffer; } pclose(pipe); // 关闭管道并等待子进程结束 if (!result.empty() && result.back() == '\n') { result.pop_back(); } return result; } std::vector mha_fwd_kvcache_quantization_mla( at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size std::optional &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v const int head_size_v, const at::Tensor &seqlens_k, // batch_size const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq const float softmax_scale, bool is_causal, const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize const at::Tensor &num_splits, // batch_size + 1 const at::Tensor &k_scale, const std::string &kv_cache_dtype ) { auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_gfx936 = dprops->major == 9 && dprops->minor == 3; // TORCH_CHECK(is_sm90); at::Tensor vcache = vcache_.has_value() ? 
vcache_.value() : kcache; auto q_dtype = q.dtype(); if (kv_cache_dtype == "fp8_e4m3" || kv_cache_dtype == "fp8_e5m2") { TORCH_CHECK(kcache.dtype() != q_dtype, "非量化情况下, query and key must have not the same dtype"); CHECK_DEVICE(k_scale); TORCH_CHECK(k_scale.dtype() == torch::kFloat32, "非量化情况下, query and key must have the same dtype"); TORCH_CHECK(is_gfx936, "fp8_e4m3 and fp8_e5m2 Attention Forward Kernel (mha_fwd_kvcache_quantization_mla) is only supported on gfx936 architectures"); } else { TORCH_CHECK(false, "Unsupported kv cache dtype"); } CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); CHECK_DEVICE(block_table); TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); const auto sizes = q.sizes(); const int batch_size = sizes[0]; const int seqlen_q_ori = sizes[1]; const int num_heads_ori = sizes[2]; const int head_size = sizes[3]; TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32"); const int max_num_blocks_per_seq = block_table.size(1); const int num_blocks = kcache.size(0); const int page_block_size = kcache.size(1); const int num_heads_k = kcache.size(2); TORCH_CHECK(batch_size > 0, "batch size must be postive"); TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); if (seqlen_q_ori == 1) { is_causal = false; } const int ngroups = num_heads_ori / num_heads_k; const int seqlen_q = seqlen_q_ori * ngroups; const int num_heads = num_heads_k; q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, 
head_size}).transpose(2, 3) .reshape({batch_size, seqlen_q, num_heads, head_size}); int head_size_k = head_size; CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k); if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); } CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); CHECK_DEVICE(seqlens_k); CHECK_CONTIGUOUS(seqlens_k); CHECK_SHAPE(seqlens_k, batch_size); at::cuda::CUDAGuard device_guard{(char)q.get_device()}; auto opts = q.options(); at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts); at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); Flash_fwd_mla_params params = {}; // Set the sizes. params.b = batch_size; params.seqlen_q = seqlen_q; params.cu_seqlens_k = seqlens_k.data_ptr(); params.h = num_heads; params.h_h_k_ratio = num_heads / num_heads_k; params.ngroups = ngroups; params.is_causal = is_causal; params.d = head_size; params.d_v = head_size_v; params.scale_softmax = softmax_scale; params.scale_softmax_log2 = float(softmax_scale * M_LOG2E); // Set the pointers and strides. params.q_ptr = q.data_ptr(); params.k_ptr = kcache.data_ptr(); params.v_ptr = vcache.data_ptr(); params.o_ptr = out.data_ptr(); params.softmax_lse_ptr = softmax_lse.data_ptr(); // All stride are in elements, not bytes. 
params.q_batch_stride = q.stride(0); params.k_batch_stride = kcache.stride(0); params.v_batch_stride = vcache.stride(0); params.o_batch_stride = out.stride(0); params.q_row_stride = q.stride(-3); params.k_row_stride = kcache.stride(-3); params.v_row_stride = vcache.stride(-3); params.o_row_stride = out.stride(-3); params.q_head_stride = q.stride(-2); params.k_head_stride = kcache.stride(-2); params.v_head_stride = vcache.stride(-2); params.o_head_stride = out.stride(-2); params.block_table = block_table.data_ptr(); params.block_table_batch_stride = block_table.stride(0); params.page_block_size = page_block_size; params.k_scale_ptr = k_scale.data_ptr(); TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32"); TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize); CHECK_DEVICE(tile_scheduler_metadata); CHECK_CONTIGUOUS(tile_scheduler_metadata); params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); params.num_sm_parts = tile_scheduler_metadata.size(0); TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32"); CHECK_DEVICE(num_splits); CHECK_CONTIGUOUS(num_splits); params.num_splits_ptr = num_splits.data_ptr(); at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat)); at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat)); params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); params.oaccum_ptr = out_accum.data_ptr(); auto stream = at::cuda::getCurrentCUDAStream().stream(); TORCH_CHECK(head_size == 576); if (print_param) { fprintf(stderr, "[flashmla] [mha_fwd_kvcache_quantization_mla] q_dtype = %s input size q (%d %d %d %d) block_table (%d %d) kcache (%d %d %d %d) is_causal = %d softmax_scale = %.4f kv_cache_dtype = %s\n", (q_dtype == torch::kBFloat16?"bf16":"fp16"), batch_size, seqlen_q_ori, 
num_heads_ori, head_size, batch_size, max_num_blocks_per_seq, num_blocks, page_block_size, num_heads_k, head_size_k, is_causal, softmax_scale, kv_cache_dtype.c_str()); } if (q_dtype == torch::kBFloat16) { run_mha_fwd_splitkv_mla(params, kv_cache_dtype, stream); } #ifndef FLASH_MLA_DISABLE_FP16 else if (q_dtype == torch::kHalf) { run_mha_fwd_splitkv_mla(params, kv_cache_dtype, stream); } #endif else { TORCH_CHECK(false, "Unsupported tensor dtype for query"); } out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3) .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v}); softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3) .reshape({batch_size, num_heads_ori, seqlen_q_ori}); return {out, softmax_lse}; } // static inline int int64_stride_to_int(int64_t orig_stride) { // if (orig_stride > std::numeric_limits::max()) { // TORCH_CHECK(false, "[Sparse TopK Attention] Stride exceeds int32 limit: ", orig_stride); // } // return static_cast(orig_stride); // } struct DecodingAttnImplMeta { int num_sm_parts; int fixed_overhead_num_blocks; int k_block_size; }; DecodingAttnImplMeta get_attn_impl_meta( int sm_count, int num_q_tokens_per_head_k, int h_k, std::optional h_q_, bool is_fp8_kvcache, bool is_sparse_attn ) { // if (arch.is_sm90()) { if (is_sparse_attn) { if (is_fp8_kvcache) { TORCH_CHECK(h_q_.has_value()); int h_q = h_q_.value(); TORCH_CHECK(h_q % h_k == 0); int s_q = num_q_tokens_per_head_k * h_k / h_q; // FP8 + Sparse MLA return { std::max((sm_count * 2) / h_k / (cutlass::ceil_div(h_q/h_k, 16) * s_q), 1), 5, 64 }; } else { // Sparse BF16 MLA TORCH_CHECK(false, "Sparse BF16 MLA is not supported on SM90"); } } else { if (is_fp8_kvcache) { // Dense FP8 MLA TORCH_CHECK(false, "Dense FP8 MLA is not supported on SM90"); } else { int h_q = h_q_.has_value() && h_q_.value() >= 64 ? 
64 : 16; // Dense BF16 MLA return { std::max(sm_count / h_k / cutlass::ceil_div(num_q_tokens_per_head_k, h_q), 1), 5, 64 }; } } } } std::vector get_mla_decoding_metadata_dense_fp8( at::Tensor &seqlens_k, const int num_heads_per_head_k, const int num_heads_k) { // This should match the logic in the MLA kernel. int block_size_m = 16; static constexpr int block_size_n = 64; if (num_heads_per_head_k > 32) { block_size_m = 64; } else if (num_heads_per_head_k > 16) { block_size_m = 32; } static constexpr int fixed_overhead_num_blocks = 5; CHECK_DEVICE(seqlens_k); TORCH_CHECK(seqlens_k.is_contiguous()); TORCH_CHECK(seqlens_k.dtype() == torch::kInt32); int batch_size = seqlens_k.size(0); int *seqlens_k_ptr = seqlens_k.data_ptr(); auto options = seqlens_k.options(); auto dprops = at::cuda::getCurrentDeviceProperties(); int sm_count = dprops->multiProcessorCount*(block_size_m == 16 ? 2 : 1); int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m); if (print_param) { fprintf(stderr, "[flashmla] [get_mla_decoding_metadata_dense_fp8] block_size_m=%d sm_count=%d num_sm_parts=%d\n", block_size_m, sm_count, num_sm_parts); } auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options); auto num_splits = torch::empty({batch_size + 1}, options); int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); int *num_splits_ptr = num_splits.data_ptr(); at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()}; auto stream = at::cuda::getCurrentCUDAStream().stream(); Mla_metadata_params params = {}; params.seqlens_k_ptr = seqlens_k_ptr; params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr; params.num_splits_ptr = num_splits_ptr; params.batch_size = batch_size; params.block_size_n = block_size_n; params.fixed_overhead_num_blocks = fixed_overhead_num_blocks; params.num_sm_parts = num_sm_parts; get_mla_metadata_func(params, stream); return {tile_scheduler_metadata, num_splits}; } 
std::vector mha_fwd_kvcache_mla_nope_pe( at::Tensor &q_nope, // batch_size x seqlen_q x num_heads x head_size at::Tensor &q_pe, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size std::optional &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v const int head_size_v, const at::Tensor &seqlens_k, // batch_size const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq const float softmax_scale, bool is_causal, const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize const at::Tensor &num_splits // batch_size + 1 ) { // auto dprops = at::cuda::getCurrentDeviceProperties(); // bool is_sm90 = dprops->major == 9 && dprops->minor == 0; // TORCH_CHECK(is_sm90); at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache; auto q_dtype = q_nope.dtype(); TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); CHECK_DEVICE(q_nope); CHECK_DEVICE(q_pe); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); TORCH_CHECK(q_nope.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(q_pe.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); CHECK_DEVICE(block_table); TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); const auto sizes = q_nope.sizes(); const int batch_size = sizes[0]; const int seqlen_q_ori = sizes[1]; const int num_heads_ori = sizes[2]; const int head_size_nope = sizes[3]; const int head_size_pe = q_pe.size(3); const int head_size = head_size_nope + head_size_pe; TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); 
TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32"); const int max_num_blocks_per_seq = block_table.size(1); const int num_blocks = kcache.size(0); const int page_block_size = kcache.size(1); const int num_heads_k = kcache.size(2); TORCH_CHECK(batch_size > 0, "batch size must be postive"); TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); if (seqlen_q_ori == 1) { is_causal = false; } TORCH_CHECK(seqlen_q_ori == 1, "mha_fwd_kvcache_mla_nope_pe only support seqlen_q_ori=1"); const int ngroups = num_heads_ori / num_heads_k; const int seqlen_q = seqlen_q_ori * ngroups; const int num_heads = num_heads_k; q_nope = q_nope.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size_nope}).transpose(2, 3) .reshape({batch_size, seqlen_q, num_heads, head_size_nope}); q_pe = q_pe.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size_pe}).transpose(2, 3) .reshape({batch_size, seqlen_q, num_heads, head_size_pe}); int head_size_k = head_size; CHECK_SHAPE(q_nope, batch_size, seqlen_q, num_heads, head_size_nope); CHECK_SHAPE(q_pe, batch_size, seqlen_q, num_heads, head_size_pe); CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k); if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); } CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); CHECK_DEVICE(seqlens_k); CHECK_CONTIGUOUS(seqlens_k); CHECK_SHAPE(seqlens_k, batch_size); // at::cuda::CUDAGuard device_guard{(char)q.get_device()}; auto opts = q_nope.options(); at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts); at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); Flash_fwd_mla_params params = {}; // Set the sizes. 
params.b = batch_size; params.seqlen_q = seqlen_q; params.cu_seqlens_k = seqlens_k.data_ptr(); params.h = num_heads; params.h_h_k_ratio = num_heads / num_heads_k; params.ngroups = ngroups; params.is_causal = is_causal; params.d = head_size; params.d_v = head_size_v; params.scale_softmax = softmax_scale; params.scale_softmax_log2 = float(softmax_scale * M_LOG2E); // Set the pointers and strides. params.q_nope_ptr = q_nope.data_ptr(); params.q_pe_ptr = q_pe.data_ptr(); params.k_ptr = kcache.data_ptr(); params.v_ptr = vcache.data_ptr(); params.o_ptr = out.data_ptr(); params.softmax_lse_ptr = softmax_lse.data_ptr(); // All stride are in elements, not bytes. params.q_nope_batch_stride = q_nope.stride(0); params.q_pe_batch_stride = q_pe.stride(0); params.k_batch_stride = kcache.stride(0); params.v_batch_stride = vcache.stride(0); params.o_batch_stride = out.stride(0); params.q_nope_row_stride = q_nope.stride(-3); params.q_pe_row_stride = q_pe.stride(-3); params.k_row_stride = kcache.stride(-3); params.v_row_stride = vcache.stride(-3); params.o_row_stride = out.stride(-3); params.q_nope_head_stride = q_nope.stride(-2); params.q_pe_head_stride = q_pe.stride(-2); params.k_head_stride = kcache.stride(-2); params.v_head_stride = vcache.stride(-2); params.o_head_stride = out.stride(-2); params.block_table = block_table.data_ptr(); params.block_table_batch_stride = block_table.stride(0); params.page_block_size = page_block_size; TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32"); TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize); CHECK_DEVICE(tile_scheduler_metadata); CHECK_CONTIGUOUS(tile_scheduler_metadata); params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); params.num_sm_parts = tile_scheduler_metadata.size(0); TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32"); CHECK_DEVICE(num_splits); CHECK_CONTIGUOUS(num_splits); 
params.num_splits_ptr = num_splits.data_ptr(); at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat)); at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat)); params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); params.oaccum_ptr = out_accum.data_ptr(); auto stream = at::cuda::getCurrentCUDAStream().stream(); TORCH_CHECK(head_size == 576); if (print_param) { fprintf(stderr, "[flashmla] [mha_fwd_kvcache_mla_nope_pe] q_dtype = %s input size q (%d %d %d %d) block_table (%d %d) kcache (%d %d %d %d) is_causal = %d softmax_scale = %.4f\n", (q_dtype == torch::kBFloat16?"bf16":"fp16"), batch_size, seqlen_q_ori, num_heads_ori, head_size, batch_size, max_num_blocks_per_seq, num_blocks, page_block_size, num_heads_k, head_size_k, is_causal, softmax_scale); } if (q_dtype == torch::kBFloat16) { run_mha_fwd_splitkv_mla(params, "auto", stream, true); } #ifndef FLASH_MLA_DISABLE_FP16 else if (q_dtype == torch::kHalf) { run_mha_fwd_splitkv_mla(params, "auto", stream, true); } #endif else { TORCH_CHECK(false, "Unsupported tensor dtype for query"); } out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3) .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v}); softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3) .reshape({batch_size, num_heads_ori, seqlen_q_ori}); return {out, softmax_lse}; } std::vector mha_fwd_kvcache_quantization_q_nope_pe_mla( at::Tensor &q_nope, // batch_size x seqlen_q x num_heads x 512 at::Tensor &q_pe, // batch_size x seqlen_q x num_heads x 64 const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size std::optional &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v const int head_size_v, const at::Tensor &seqlens_k, // batch_size const at::Tensor &block_table, // batch_size x 
max_num_blocks_per_seq const float softmax_scale, bool is_causal, const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize const at::Tensor &num_splits, // batch_size + 1 const at::Tensor &k_scale, const std::string &kv_cache_dtype ) { // auto dprops = at::cuda::getCurrentDeviceProperties(); // bool is_sm90 = dprops->major == 9 && dprops->minor == 0; // TORCH_CHECK(is_sm90); at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache; auto q_dtype = q_nope.dtype(); if (kv_cache_dtype == "fp8_e5m2") { TORCH_CHECK(kcache.dtype() != q_dtype, "非量化情况下, query and key must have not the same dtype"); CHECK_DEVICE(k_scale); TORCH_CHECK(k_scale.dtype() == torch::kFloat32, "非量化情况下, query and key must have the same dtype"); } else { TORCH_CHECK(false, "Unsupported kv cache dtype"); } CHECK_DEVICE(q_nope); CHECK_DEVICE(q_pe); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); TORCH_CHECK(q_nope.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(q_pe.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); CHECK_DEVICE(block_table); TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); const auto sizes = q_nope.sizes(); const int batch_size = sizes[0]; const int seqlen_q_ori = sizes[1]; const int num_heads_ori = sizes[2]; const int head_size_nope = sizes[3]; const int head_size_pe = q_pe.size(3); const int head_size = head_size_nope + head_size_pe; TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32"); const int max_num_blocks_per_seq = block_table.size(1); const int num_blocks = kcache.size(0); const int 
page_block_size = kcache.size(1); const int num_heads_k = kcache.size(2); TORCH_CHECK(batch_size > 0, "batch size must be postive"); TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); if (seqlen_q_ori == 1) { is_causal = false; } TORCH_CHECK(seqlen_q_ori == 1, "mha_fwd_kvcache_quantization_q_nope_pe_mla only support seqlen_q_ori=1"); const int ngroups = num_heads_ori / num_heads_k; const int seqlen_q = seqlen_q_ori * ngroups; const int num_heads = num_heads_k; q_nope = q_nope.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size_nope}).transpose(2, 3) .reshape({batch_size, seqlen_q, num_heads, head_size_nope}); q_pe = q_pe.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size_pe}).transpose(2, 3) .reshape({batch_size, seqlen_q, num_heads, head_size_pe}); int head_size_k = head_size; CHECK_SHAPE(q_nope, batch_size, seqlen_q, num_heads, head_size_nope); CHECK_SHAPE(q_pe, batch_size, seqlen_q, num_heads, head_size_pe); CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k); if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); } CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); CHECK_DEVICE(seqlens_k); CHECK_CONTIGUOUS(seqlens_k); CHECK_SHAPE(seqlens_k, batch_size); // at::cuda::CUDAGuard device_guard{(char)q.get_device()}; auto opts = q_nope.options(); at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts); at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); Flash_fwd_mla_params params = {}; // Set the sizes. 
params.b = batch_size; params.seqlen_q = seqlen_q; params.cu_seqlens_k = seqlens_k.data_ptr(); params.h = num_heads; params.h_h_k_ratio = num_heads / num_heads_k; params.ngroups = ngroups; params.is_causal = is_causal; params.d = head_size; params.d_v = head_size_v; params.scale_softmax = softmax_scale; params.scale_softmax_log2 = float(softmax_scale * M_LOG2E); // Set the pointers and strides. params.q_nope_ptr = q_nope.data_ptr(); params.q_pe_ptr = q_pe.data_ptr(); params.k_ptr = kcache.data_ptr(); params.v_ptr = vcache.data_ptr(); params.o_ptr = out.data_ptr(); params.softmax_lse_ptr = softmax_lse.data_ptr(); // All stride are in elements, not bytes. params.q_nope_batch_stride = q_nope.stride(0); params.q_pe_batch_stride = q_pe.stride(0); params.k_batch_stride = kcache.stride(0); params.v_batch_stride = vcache.stride(0); params.o_batch_stride = out.stride(0); params.q_nope_row_stride = q_nope.stride(-3); params.q_pe_row_stride = q_pe.stride(-3); params.k_row_stride = kcache.stride(-3); params.v_row_stride = vcache.stride(-3); params.o_row_stride = out.stride(-3); params.q_nope_head_stride = q_nope.stride(-2); params.q_pe_head_stride = q_pe.stride(-2); params.k_head_stride = kcache.stride(-2); params.v_head_stride = vcache.stride(-2); params.o_head_stride = out.stride(-2); params.block_table = block_table.data_ptr(); params.block_table_batch_stride = block_table.stride(0); params.page_block_size = page_block_size; params.k_scale_ptr = k_scale.data_ptr(); TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32"); TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize); CHECK_DEVICE(tile_scheduler_metadata); CHECK_CONTIGUOUS(tile_scheduler_metadata); params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); params.num_sm_parts = tile_scheduler_metadata.size(0); TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32"); CHECK_DEVICE(num_splits); 
CHECK_CONTIGUOUS(num_splits); params.num_splits_ptr = num_splits.data_ptr(); at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat)); at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat)); params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); params.oaccum_ptr = out_accum.data_ptr(); auto stream = at::cuda::getCurrentCUDAStream().stream(); TORCH_CHECK(head_size == 576); if (print_param) { fprintf(stderr, "[flashmla] [mha_fwd_kvcache_quantization_q_nope_pe_mla] q_dtype = %s input size q (%d %d %d %d) block_table (%d %d) kcache (%d %d %d %d) is_causal = %d softmax_scale = %.4f kv_cache_dtype = %s\n", (q_dtype == torch::kBFloat16?"bf16":"fp16"), batch_size, seqlen_q_ori, num_heads_ori, head_size, batch_size, max_num_blocks_per_seq, num_blocks, page_block_size, num_heads_k, head_size_k, is_causal, softmax_scale, kv_cache_dtype.c_str()); } if (q_dtype == torch::kBFloat16) { run_mha_fwd_splitkv_mla(params, kv_cache_dtype, stream, true); } #ifndef FLASH_MLA_DISABLE_FP16 else if (q_dtype == torch::kHalf) { run_mha_fwd_splitkv_mla(params, kv_cache_dtype, stream, true); } #endif else { TORCH_CHECK(false, "Unsupported tensor dtype for query"); } out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3) .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v}); softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3) .reshape({batch_size, num_heads_ori, seqlen_q_ori}); return {out, softmax_lse}; } std::vector mha_fwd_kvcache_mla_fp8( at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size std::optional &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v const int head_size_v, const at::Tensor &seqlens_k, // batch_size const at::Tensor 
&block_table, // batch_size x max_num_blocks_per_seq const float softmax_scale, bool is_causal, const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize const at::Tensor &num_splits, // batch_size + 1 const std::optional &descale_q, // None or batch_size const std::optional &descale_k // None or batch_size ) { // auto dprops = at::cuda::getCurrentDeviceProperties(); // bool is_sm90 = dprops->major == 9 && dprops->minor == 0; // TORCH_CHECK(is_sm90); // static std::string FLASH_MLA_ROOT_DIR = execCommand("python -c 'import site; print(site.getsitepackages()[0])'"); // setenv("FLASH_MLA_ROOT_DIR", (FLASH_MLA_ROOT_DIR + "/flash_mla/asm/").c_str(), 1); // std::cout << FLASH_MLA_ROOT_DIR << "\n"; // exit(-1); at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache; auto q_dtype = q.dtype(); TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); CHECK_DEVICE(block_table); TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); if (descale_q.has_value()) CHECK_DEVICE(descale_q.value()); if (descale_k.has_value()) CHECK_DEVICE(descale_k.value()); const auto sizes = q.sizes(); const int batch_size = sizes[0]; const int seqlen_q_ori = sizes[1]; const int num_heads_ori = sizes[2]; const int head_size = sizes[3]; TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32"); const int max_num_blocks_per_seq = block_table.size(1); const int num_blocks = kcache.size(0); const int 
page_block_size = kcache.size(1); const int num_heads_k = kcache.size(2); TORCH_CHECK(batch_size > 0, "batch size must be postive"); TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); TORCH_CHECK(descale_q.has_value() && descale_k.has_value(), "descale is required when input dtype is fp8"); auto descale_q_ = descale_q.value(); auto descale_k_ = descale_k.value(); CHECK_DEVICE(descale_q_); CHECK_DEVICE(descale_k_); TORCH_CHECK(descale_q_.stride(-1) == 1); TORCH_CHECK(descale_k_.stride(-1) == 1); TORCH_CHECK(descale_q_.dtype() == torch::kFloat); TORCH_CHECK(descale_k_.dtype() == torch::kFloat); CHECK_SHAPE(descale_q_, 1); CHECK_SHAPE(descale_k_, 1); if (seqlen_q_ori == 1) { is_causal = false; } const int ngroups = num_heads_ori / num_heads_k; const int seqlen_q = seqlen_q_ori * ngroups; const int num_heads = num_heads_k; q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3) .reshape({batch_size, seqlen_q, num_heads, head_size}); int head_size_k = head_size; CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k); if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); } CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); CHECK_DEVICE(seqlens_k); CHECK_CONTIGUOUS(seqlens_k); CHECK_SHAPE(seqlens_k, batch_size); at::cuda::CUDAGuard device_guard{(char)q.get_device()}; auto opts = q.options(); at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(torch::kBFloat16));//1,16,1,512 at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));//1,1,16 Flash_fwd_mla_params params = {}; // Set the sizes. 
params.b = batch_size;//1 params.seqlen_q = seqlen_q;//16 params.cu_seqlens_k = seqlens_k.data_ptr(); params.h = num_heads;//1 params.h_h_k_ratio = num_heads / num_heads_k;//1 params.ngroups = ngroups;//16 params.is_causal = is_causal;//false params.d = head_size;//576 params.d_v = head_size_v;//512 params.scale_softmax = softmax_scale;//0.417 params.scale_softmax_log2 = float(softmax_scale * M_LOG2E); // Set the pointers and strides. params.q_ptr = q.data_ptr(); params.k_ptr = kcache.data_ptr(); params.v_ptr = vcache.data_ptr(); params.o_ptr = out.data_ptr(); params.softmax_lse_ptr = softmax_lse.data_ptr(); // All stride are in elements, not bytes. params.q_batch_stride = q.stride(0); params.k_batch_stride = kcache.stride(0); params.v_batch_stride = vcache.stride(0); params.o_batch_stride = out.stride(0); params.q_row_stride = q.stride(-3); params.k_row_stride = kcache.stride(-3); params.v_row_stride = vcache.stride(-3); params.o_row_stride = out.stride(-3); params.q_head_stride = q.stride(-2); params.k_head_stride = kcache.stride(-2); params.v_head_stride = vcache.stride(-2); params.o_head_stride = out.stride(-2); params.block_table = block_table.data_ptr(); params.block_table_batch_stride = block_table.stride(0); params.page_block_size = page_block_size;//64 params.descale_q_ptr = reinterpret_cast(descale_q.value().data_ptr()); params.descale_k_ptr = reinterpret_cast(descale_k.value().data_ptr()); TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32"); TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize); CHECK_DEVICE(tile_scheduler_metadata); CHECK_CONTIGUOUS(tile_scheduler_metadata); params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); params.num_sm_parts = tile_scheduler_metadata.size(0); TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32"); CHECK_DEVICE(num_splits); CHECK_CONTIGUOUS(num_splits); params.num_splits_ptr = 
num_splits.data_ptr(); at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat)); at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat)); params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); params.oaccum_ptr = out_accum.data_ptr(); auto stream = at::cuda::getCurrentCUDAStream().stream(); TORCH_CHECK(head_size == 576); if (print_param) { fprintf(stderr, "[flashmla] [mha_fwd_kvcache_mla_fp8] input size q (%d %d %d %d) block_table (%d %d) kcache (%d %d %d %d) is_causal = %d softmax_scale = %.4f \n", batch_size, seqlen_q_ori, num_heads_ori, head_size, batch_size, max_num_blocks_per_seq, num_blocks, page_block_size, num_heads_k, head_size_k, is_causal, softmax_scale); } if (q_dtype == torch::kFloat8_e4m3fn) { run_mha_fwd_splitkv_mla_fp8(params,stream,false); } else { TORCH_CHECK(false, "Unsupported tensor dtype for query"); } out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3) .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v}); softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3) .reshape({batch_size, num_heads_ori, seqlen_q_ori}); return {out, softmax_lse}; }

/**
 * Paged-KV-cache MLA forward with an FP8 (e4m3) KV cache.
 *
 * The query is supplied as two tensors, `q_nope` and `q_pe`. They are passed
 * to the kernel as separate pointers (params.q_nope_ptr / params.q_pe_ptr),
 * so callers do not concatenate them into one 576-wide query first.
 * NOTE(review): the `_with_cat` name suggests the kernel joins them
 * internally; confirm against run_mha_fwd_splitkv_mla_fp8.
 *
 * @param q_nope   batch_size x seqlen_q x num_heads x head_size_nope.
 *                 Supported dtypes are float8_e4m3fn or bfloat16 (see the
 *                 dispatch below). This is a non-const reference: it is
 *                 reassigned to a head-grouped reshape of itself.
 * @param q_pe     batch_size x seqlen_q x num_heads x head_size_pe. The
 *                 leading three dims must match q_nope. Also reassigned in
 *                 place. Its dtype is not validated here
 *                 (NOTE(review): confirm what the kernel expects).
 * @param kcache   num_blocks x page_block_size x num_heads_k x
 *                 (head_size_nope + head_size_pe). Must be float8_e4m3fn.
 * @param vcache_  Optional V cache, num_blocks x page_block_size x
 *                 num_heads_k x head_size_v. kcache is used when it is absent.
 * @param head_size_v  Value head dim. Must be a multiple of 32.
 * @param seqlens_k    int32, shape (batch_size), contiguous.
 * @param block_table  int32, batch_size x max_num_blocks_per_seq, with a
 *                     contiguous last dim.
 * @param softmax_scale  Stored as params.scale_softmax, and also as
 *                       scale_softmax_log2 (scaled by log2(e)).
 * @param is_causal  Forced to false when seqlen_q == 1.
 * @param tile_scheduler_metadata  int32, num_sm_parts x
 *                                 TileSchedulerMetaDataSize, contiguous.
 * @param num_splits  int32, contiguous. batch_size + 1 entries are expected
 *                    per the caller contract; the length is not checked here.
 * @param descale_q, descale_k  Required float32 tensors of shape (1).
 * @return {out, softmax_lse}:
 *         - out: bfloat16, batch_size x seqlen_q x num_heads x head_size_v.
 *         - softmax_lse: float32, batch_size x num_heads x seqlen_q.
 */
std::vector mha_fwd_kvcache_mla_fp8_with_cat(
        at::Tensor &q_nope,                          // batch_size x seqlen_q x num_heads x 512
        at::Tensor &q_pe,                            // batch_size x seqlen_q x num_heads x 64
        const at::Tensor &kcache,                    // num_blocks x page_block_size x num_heads_k x head_size
        std::optional &vcache_,                      // num_blocks x page_block_size x num_heads_k x head_size_v
        const int head_size_v,
        const at::Tensor &seqlens_k,                 // batch_size
        const at::Tensor &block_table,               // batch_size x max_num_blocks_per_seq
        const float softmax_scale,
        bool is_causal,
        const at::Tensor &tile_scheduler_metadata,   // num_sm_parts x TileSchedulerMetaDataSize
        const at::Tensor &num_splits,                // batch_size + 1
        const std::optional &descale_q,              // required: float32, shape (1)
        const std::optional &descale_k               // required: float32, shape (1)
) {
    at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
    auto q_dtype = q_nope.dtype();

    CHECK_DEVICE(q_nope);
    CHECK_DEVICE(q_pe);
    CHECK_DEVICE(kcache);
    CHECK_DEVICE(vcache);
    // Validate rank before q_nope.sizes()[3] is read below: IntArrayRef::operator[]
    // performs no bounds checking, so a lower-rank tensor would read out of range.
    TORCH_CHECK(q_nope.dim() == 4, "q_nope must be a 4-D tensor (batch_size, seqlen_q, num_heads, head_size_nope)");
    TORCH_CHECK(q_pe.dim() == 4, "q_pe must be a 4-D tensor (batch_size, seqlen_q, num_heads, head_size_pe)");
    TORCH_CHECK(q_nope.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(q_pe.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");

    CHECK_DEVICE(block_table);
    TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
    TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");

    const auto sizes = q_nope.sizes();
    const int batch_size = sizes[0];
    const int seqlen_q_ori = sizes[1];
    const int num_heads_ori = sizes[2];
    const int head_size_nope = sizes[3];
    const int head_size_pe = q_pe.size(3);
    const int head_size = head_size_nope + head_size_pe;
    // q_pe is later view()-ed using q_nope's leading dims. view() only checks the
    // element count, so a q_pe with permuted batch/seq/head dims would otherwise
    // be reinterpreted silently and pass every later shape check.
    TORCH_CHECK(q_pe.size(0) == batch_size && q_pe.size(1) == seqlen_q_ori && q_pe.size(2) == num_heads_ori,
                "q_pe must match q_nope in batch_size, seqlen_q and num_heads");
    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
    TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32");
    // Only the 576-wide combined query (512 + 64 per the parameter comments) is
    // supported. Reject other widths before any device memory is allocated.
    TORCH_CHECK(head_size == 576, "head_size_nope + head_size_pe must be 576");

    const int max_num_blocks_per_seq = block_table.size(1);
    const int num_blocks = kcache.size(0);
    const int page_block_size = kcache.size(1);
    const int num_heads_k = kcache.size(2);
    TORCH_CHECK(batch_size > 0, "batch size must be postive");
    // Guard the modulo/division below against a zero-sized head dim in kcache.
    TORCH_CHECK(num_heads_k > 0, "kcache must have at least one key/value head");
    TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
    // NOTE: a former `num_heads_ori == 16` restriction is intentionally not enforced.

    // Both descale tensors are mandatory for this FP8 entry point.
    TORCH_CHECK(descale_q.has_value() && descale_k.has_value(), "descale is required when input dtype is fp8");
    const auto &descale_q_ = descale_q.value();
    const auto &descale_k_ = descale_k.value();
    CHECK_DEVICE(descale_q_);
    CHECK_DEVICE(descale_k_);
    TORCH_CHECK(descale_q_.stride(-1) == 1);
    TORCH_CHECK(descale_k_.stride(-1) == 1);
    TORCH_CHECK(descale_q_.dtype() == torch::kFloat);
    TORCH_CHECK(descale_k_.dtype() == torch::kFloat);
    CHECK_SHAPE(descale_q_, 1);
    CHECK_SHAPE(descale_k_, 1);

    if (seqlen_q_ori == 1) { is_causal = false; }

    // Fold the `ngroups` query heads that share one KV head into the sequence
    // dimension: (b, s, h_k * g, d) -> view (b, s, h_k, g, d) -> transpose
    // (b, s, g, h_k, d) -> reshape (b, s * g, h_k, d).
    const int ngroups = num_heads_ori / num_heads_k;
    const int seqlen_q = seqlen_q_ori * ngroups;
    const int num_heads = num_heads_k;
    q_nope = q_nope.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size_nope}).transpose(2, 3)
                 .reshape({batch_size, seqlen_q, num_heads, head_size_nope});
    q_pe = q_pe.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size_pe}).transpose(2, 3)
               .reshape({batch_size, seqlen_q, num_heads, head_size_pe});

    const int head_size_k = head_size;
    CHECK_SHAPE(q_nope, batch_size, seqlen_q, num_heads, head_size_nope);
    CHECK_SHAPE(q_pe, batch_size, seqlen_q, num_heads, head_size_pe);
    CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
    if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); }
    CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);

    TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
    CHECK_DEVICE(seqlens_k);
    CHECK_CONTIGUOUS(seqlens_k);
    CHECK_SHAPE(seqlens_k, batch_size);

    // Make q_nope's device current for the rest of this call. getCurrentCUDAStream()
    // (no argument) and the kernel launch otherwise target whatever device the
    // caller had current, which may differ from the inputs' device on multi-GPU hosts.
    at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};

    auto opts = q_nope.options();
    at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(torch::kBFloat16));
    at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));

    Flash_fwd_mla_params params = {};
    // Set the sizes.
    params.b = batch_size;
    params.seqlen_q = seqlen_q;
    params.cu_seqlens_k = seqlens_k.data_ptr();
    params.h = num_heads;
    params.h_h_k_ratio = num_heads / num_heads_k;
    params.ngroups = ngroups;
    params.is_causal = is_causal;
    params.d = head_size;
    params.d_v = head_size_v;
    params.scale_softmax = softmax_scale;
    params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);

    // Set the pointers and strides.
    params.q_nope_ptr = q_nope.data_ptr();
    params.q_pe_ptr = q_pe.data_ptr();
    params.k_ptr = kcache.data_ptr();
    params.v_ptr = vcache.data_ptr();
    params.o_ptr = out.data_ptr();
    params.softmax_lse_ptr = softmax_lse.data_ptr();
    // All strides are in elements, not bytes.
    params.q_nope_batch_stride = q_nope.stride(0);
    params.q_pe_batch_stride = q_pe.stride(0);
    params.k_batch_stride = kcache.stride(0);
    params.v_batch_stride = vcache.stride(0);
    params.o_batch_stride = out.stride(0);
    params.q_nope_row_stride = q_nope.stride(-3);
    params.q_pe_row_stride = q_pe.stride(-3);
    params.k_row_stride = kcache.stride(-3);
    params.v_row_stride = vcache.stride(-3);
    params.o_row_stride = out.stride(-3);
    params.q_nope_head_stride = q_nope.stride(-2);
    params.q_pe_head_stride = q_pe.stride(-2);
    params.k_head_stride = kcache.stride(-2);
    params.v_head_stride = vcache.stride(-2);
    params.o_head_stride = out.stride(-2);

    params.block_table = block_table.data_ptr();
    params.block_table_batch_stride = block_table.stride(0);
    params.page_block_size = page_block_size;

    params.descale_q_ptr = reinterpret_cast(descale_q.value().data_ptr());
    params.descale_k_ptr = reinterpret_cast(descale_k.value().data_ptr());

    TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
    TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
    CHECK_DEVICE(tile_scheduler_metadata);
    CHECK_CONTIGUOUS(tile_scheduler_metadata);
    params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr();
    params.num_sm_parts = tile_scheduler_metadata.size(0);
    TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
    CHECK_DEVICE(num_splits);
    CHECK_CONTIGUOUS(num_splits);
    params.num_splits_ptr = num_splits.data_ptr();

    // Float32 scratch buffers with batch_size + num_sm_parts leading rows; presumably
    // they hold per-split partial results combined by the kernel (TODO confirm).
    at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat));
    at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat));
    params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
    params.oaccum_ptr = out_accum.data_ptr();

    auto stream = at::cuda::getCurrentCUDAStream().stream();

    if (print_param) {
        fprintf(stderr, "[flashmla] [mha_fwd_kvcache_mla_fp8_with_cat] input size q (%d %d %d %d) block_table (%d %d) kcache (%d %d %d %d) is_causal = %d softmax_scale = %.4f \n",
                batch_size, seqlen_q_ori, num_heads_ori, head_size,
                batch_size, max_num_blocks_per_seq,
                num_blocks, page_block_size, num_heads_k, head_size_k,
                is_causal, softmax_scale);
    }

    // Supported combinations: fp8-e4m3 KV cache with an fp8-e4m3 or bf16 query.
    // NOTE(review): both branches issue the same call here. Confirm that the bf16-query
    // path selects the intended kernel instantiation.
    if (q_dtype == torch::kFloat8_e4m3fn && kcache.dtype() == torch::kFloat8_e4m3fn) {
        run_mha_fwd_splitkv_mla_fp8(params,stream,true);
    } else if (q_dtype == torch::kBFloat16 && kcache.dtype() == torch::kFloat8_e4m3fn) {
        run_mha_fwd_splitkv_mla_fp8(params,stream,true);
    } else {
        TORCH_CHECK(false, "Unsupported tensor dtype for query");
    }

    // Undo the head folding so outputs are indexed by the caller's original query heads.
    out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
             .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
    softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3)
                     .reshape({batch_size, num_heads_ori, seqlen_q_ori});

    return {out, softmax_lse};
}

// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // m.doc() = "FlashMLA"; // m.def("get_mla_metadata", &get_mla_metadata); // m.def("get_mla_decoding_metadata_dense_fp8", &get_mla_decoding_metadata_dense_fp8); // m.def("fwd_kvcache_mla", &mha_fwd_kvcache_mla); // 
m.def("fwd_kvcache_quantization_mla", &mha_fwd_kvcache_quantization_mla); // m.def("sparse_prefill_fwd", &sparse_prefill_fwd); // m.def("fwd_kvcache_quantization_q_nope_pe_mla", &mha_fwd_kvcache_quantization_q_nope_pe_mla); // m.def("fwd_kvcache_mla_nope_pe", &mha_fwd_kvcache_mla_nope_pe); // m.def("fwd_kvcache_mla_fp8", &mha_fwd_kvcache_mla_fp8); // m.def("fwd_kvcache_mla_fp8_with_cat", &mha_fwd_kvcache_mla_fp8_with_cat); // }