Commit 2c35de66 authored by shenzhe's avatar shenzhe Committed by zhanghj2
Browse files

Add DSA BF16 sparse decode support

parent a1eef562
......@@ -6,6 +6,7 @@
#include "params.h"
#include "gfx93/decode/sparse_fp8/splitkv_mla.h"
#include "gfx93/decode/sparse_bf16_dsa/fwd.h"
#include "gfx9/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "gfx9/decode/combine/combine.h"
......@@ -123,6 +124,14 @@ sparse_attn_decode_interface(
bool have_extra_topk_length = extra_topk_length.has_value();
bool have_attn_sink = attn_sink.has_value();
if (kv.dtype() == torch::kBFloat16) {
return gfx93::decode::sparse_bf16_dsa::run(
q, kv, indices, topk_length, attn_sink,
tile_scheduler_metadata, num_splits,
extra_kv, extra_indices, extra_topk_length,
d_v, sm_scale);
}
int extra_num_blocks = 0, extra_page_block_size = 0, extra_topk = 0;
if (have_extra_kcache) {
extra_num_blocks = extra_kv->size(0);
......
#include "fwd.h"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>
#include <algorithm>
#include <cstring>
#include <limits>
#include <optional>
#include <tuple>
#include "kerutils/supplemental/torch_tensors.h"
#include "gfx93/prefill/sparse/dsa_mls/dispatch.h"
namespace gfx93::decode::sparse_bf16_dsa {
static constexpr float LOG_2_E = 1.44269504f;
/// Snapshot of the properties of the device that is current at construction time.
///
/// `num_sms` is the device's multiprocessor count and `arch_name` is the raw
/// `gcnArchName` string reported by the runtime.
struct LocalArch {
    int num_sms;
    std::string arch_name;

    /// Reads the properties of whichever device is current when this runs.
    LocalArch() {
        const auto* device_props = at::cuda::getCurrentDeviceProperties();
        num_sms = device_props->multiProcessorCount;
        arch_name = device_props->gcnArchName;
    }

    /// True when the architecture name, with everything from the first ':'
    /// onward dropped, is exactly "gfx936" or "gfx938".
    bool is_gfx93x() const {
        const std::string::size_type colon_pos = arch_name.find(':');
        // substr(0, npos) yields the whole string when there is no ':'.
        const std::string target = arch_name.substr(0, colon_pos);
        if (target == "gfx936") {
            return true;
        }
        return target == "gfx938";
    }
};
/// Narrows a 64-bit tensor stride to the `int` fields used by the kernel params.
///
/// Raises (via TORCH_CHECK) when the stride does not fit into `int`. Both ends of
/// the range are checked so that an out-of-range negative value cannot silently
/// wrap in the cast below.
static int int64_stride_to_int(int64_t stride) {
    TORCH_CHECK(
        stride >= std::numeric_limits<int>::min() && stride <= std::numeric_limits<int>::max(),
        "DSA BF16 sparse decode stride exceeds int32 limit: ", stride);
    return static_cast<int>(stride);
}
/// Picks the split count used when the caller does not supply `num_splits`.
///
/// Any positive `extra_topk` yields 2. Otherwise the result depends only on
/// `topk`: 1024 -> 16, 512 -> 8, and every other value -> 1 (no split).
static int default_num_splits(int topk, int extra_topk) {
    const bool uses_extra_cache = extra_topk > 0;
    if (uses_extra_cache) {
        return 2;
    }
    switch (topk) {
        case 1024:
            return 16;
        case 512:
            return 8;
        default:
            return 1;
    }
}
/// Validates that the optional "extra" inputs are supplied in a consistent combination.
///
/// - With `extra_kv` present, `extra_indices` must be present as well.
/// - With `extra_kv` absent, neither `extra_indices` nor `extra_topk_length`
///   may be present.
/// Violations raise through TORCH_CHECK.
static void check_optional_extra(
    const std::optional<at::Tensor>& extra_kv,
    const std::optional<at::Tensor>& extra_indices,
    const std::optional<at::Tensor>& extra_topk_length) {
    const bool indices_given = extra_indices.has_value();
    if (extra_kv.has_value()) {
        TORCH_CHECK(indices_given, "extra_indices_in_kvcache must be provided when extra_k_cache is provided");
        return;
    }
    // No extra cache: every other extra input has to be absent too.
    TORCH_CHECK(!indices_given, "extra_indices_in_kvcache must not be provided when extra_k_cache is not provided");
    const bool length_given = extra_topk_length.has_value();
    TORCH_CHECK(!length_given, "extra_topk_length must not be provided when extra_k_cache is not provided");
}
/// Host entry point for BF16 DSA sparse decode on gfx936/gfx938.
///
/// Validates the inputs, fills a `Flash_fwd_mla_params_dsa` and launches the
/// DSA dispatch for d_qk == 512 or 576.
///
/// Shapes enforced by the checks below:
///   q                  [b, s_q, h_q, d_qk]            bf16
///   kv                 [num_blocks, page_block_size, h_kv, d_qk]  bf16, h_kv == 1
///   indices            [b, s_q, topk]                 int32
///   topk_length        [b]                            int32 (optional)
///   attn_sink          [h_q]                          float32 (optional)
///   extra_indices      [b, s_q, extra_topk]           int32 (optional)
///   extra_topk_length  [b]                            int32 (optional)
///
/// @param tile_scheduler_metadata  Not consumed here; checked and returned unchanged.
/// @param num_splits  Optional single-element int32 tensor. When absent, a default
///                    is chosen by default_num_splits() and written back into this
///                    argument. Must be 1 or a power of two up to 64.
/// @param d_v         Must be 512.
/// @param sm_scale    Softmax scale; also passed pre-multiplied by log2(e).
/// @return {out [b, s_q, h_q, d_v] (q's dtype), lse [b, h_q, s_q] (float32),
///          tile_scheduler_metadata, num_splits}.
std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>>
run(
    const at::Tensor& q,
    const at::Tensor& kv,
    const at::Tensor& indices,
    const std::optional<at::Tensor>& topk_length,
    const std::optional<at::Tensor>& attn_sink,
    std::optional<at::Tensor>& tile_scheduler_metadata,
    std::optional<at::Tensor>& num_splits,
    const std::optional<at::Tensor>& extra_kv,
    const std::optional<at::Tensor>& extra_indices,
    const std::optional<at::Tensor>& extra_topk_length,
    int d_v,
    float sm_scale) {
    // Make q's device current *before* querying device properties: LocalArch
    // reads the properties of the current device, which is not necessarily the
    // device q lives on in a multi-GPU process.
    KU_CHECK_DEVICE(q);
    c10::cuda::CUDAGuard device_guard{q.device()};
    LocalArch arch;
    TORCH_CHECK(arch.is_gfx93x(), "DSA BF16 sparse decode is only supported on gfx936/gfx938");

    // ---- Rank checks (needed before any size(i) access) ----
    KU_CHECK_NDIM(q, 4);
    KU_CHECK_NDIM(kv, 4);
    KU_CHECK_NDIM(indices, 3);
    if (extra_kv.has_value()) {
        KU_CHECK_NDIM(extra_kv, 4);
    }
    if (extra_indices.has_value()) {
        KU_CHECK_NDIM(extra_indices, 3);
    }

    // ---- Problem dimensions ----
    const int b = q.size(0);
    const int s_q = q.size(1);
    const int h_q = q.size(2);
    const int d_qk = q.size(3);
    const int page_block_size = kv.size(1);
    const int h_kv = kv.size(2);
    const int topk = indices.size(2);
    // The extra path is only active when both tensors are present and non-empty.
    const bool has_extra = extra_kv.has_value() && extra_indices.has_value() &&
                           extra_kv->numel() > 0 && extra_indices->numel() > 0 &&
                           extra_indices->size(2) > 0;
    const int extra_topk = has_extra ? extra_indices->size(2) : 0;

    // ---- Supported-configuration checks ----
    TORCH_CHECK(b > 0 && s_q > 0 && h_q > 0, "Invalid q shape for DSA BF16 sparse decode");
    TORCH_CHECK(h_kv == 1, "DSA BF16 sparse decode only supports h_kv == 1");
    TORCH_CHECK(h_q == 64 || h_q == 128, "DSA BF16 sparse decode only supports h_q == 64 or 128");
    TORCH_CHECK(d_qk == 512 || d_qk == 576, "DSA BF16 sparse decode only supports d_qk == 512 or 576");
    TORCH_CHECK(d_v == 512, "DSA BF16 sparse decode only supports d_v == 512");
    TORCH_CHECK(topk > 0, "topk must be positive");
    if (has_extra) {
        TORCH_CHECK(topk <= 256, "DSA BF16 sparse decode with extra_kv supports topk <= 256");
        TORCH_CHECK(extra_topk <= 1024, "DSA BF16 sparse decode supports extra_topk <= 1024");
        TORCH_CHECK(extra_kv->size(1) > 0, "extra page_block_size must be positive");
        TORCH_CHECK(extra_kv->size(2) == h_kv, "extra_kv h_kv must match kv h_kv");
        TORCH_CHECK(extra_kv->size(3) == d_qk, "extra_kv d_qk must match q d_qk");
    } else {
        TORCH_CHECK(topk <= 1024, "DSA BF16 sparse decode supports topk <= 1024");
    }
    check_optional_extra(extra_kv, extra_indices, extra_topk_length);

    // ---- Device / dtype / contiguity / shape checks (q's device was checked above) ----
    KU_CHECK_DEVICE(kv);
    KU_CHECK_DEVICE(indices);
    KU_CHECK_DEVICE(topk_length);
    KU_CHECK_DEVICE(attn_sink);
    KU_CHECK_DEVICE(tile_scheduler_metadata);
    KU_CHECK_DEVICE(num_splits);
    KU_CHECK_DEVICE(extra_kv);
    KU_CHECK_DEVICE(extra_indices);
    KU_CHECK_DEVICE(extra_topk_length);

    KU_CHECK_DTYPE(q, torch::kBFloat16);
    KU_CHECK_DTYPE(kv, torch::kBFloat16);
    KU_CHECK_DTYPE(indices, torch::kInt32);
    KU_CHECK_DTYPE(topk_length, torch::kInt32);
    KU_CHECK_DTYPE(attn_sink, torch::kFloat32);
    KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32);
    KU_CHECK_DTYPE(num_splits, torch::kInt32);
    KU_CHECK_DTYPE(extra_kv, torch::kBFloat16);
    KU_CHECK_DTYPE(extra_indices, torch::kInt32);
    KU_CHECK_DTYPE(extra_topk_length, torch::kInt32);

    KU_CHECK_LAST_DIM_CONTIGUOUS(q);
    KU_CHECK_LAST_DIM_CONTIGUOUS(kv);
    KU_CHECK_LAST_DIM_CONTIGUOUS(indices);
    KU_CHECK_CONTIGUOUS(topk_length);
    KU_CHECK_CONTIGUOUS(attn_sink);
    KU_CHECK_LAST_DIM_CONTIGUOUS(extra_kv);
    KU_CHECK_LAST_DIM_CONTIGUOUS(extra_indices);
    KU_CHECK_CONTIGUOUS(extra_topk_length);

    KU_CHECK_SHAPE(q, b, s_q, h_q, d_qk);
    KU_CHECK_SHAPE(kv, kv.size(0), page_block_size, h_kv, d_qk);
    KU_CHECK_SHAPE(indices, b, s_q, topk);
    KU_CHECK_SHAPE(topk_length, b);
    KU_CHECK_SHAPE(attn_sink, h_q);
    if (has_extra) {
        KU_CHECK_SHAPE(extra_indices, b, s_q, extra_topk);
        KU_CHECK_SHAPE(extra_topk_length, b);
    }

    // The params struct carries four index strides (batch/row/head/topk), so a
    // singleton head axis is inserted into the 3-D index tensors.
    at::Tensor indices_for_dsa = indices.unsqueeze(2);
    at::Tensor extra_indices_for_dsa;
    if (has_extra) {
        extra_indices_for_dsa = extra_indices->unsqueeze(2);
    }

    // ---- Output and scratch allocations ----
    auto opts = q.options();
    at::Tensor out = torch::empty({b, s_q, h_q, d_v}, opts);
    at::Tensor lse = torch::empty({b, h_q, s_q}, opts.dtype(at::kFloat));
    at::Tensor scores_memory = torch::empty({2, b, h_kv, s_q * h_q}, opts.dtype(at::kFloat));
    at::Tensor scores_max = scores_memory.select(0, 0);
    at::Tensor scores_sum = scores_memory.select(0, 1);

    // ---- Resolve the split count ----
    int requested_num_splits = 0;
    if (num_splits.has_value()) {
        // Device and dtype of a caller-supplied tensor were validated above.
        KU_CHECK_CONTIGUOUS(num_splits);
        TORCH_CHECK(num_splits->numel() == 1, "DSA BF16 sparse decode expects num_splits to be a scalar tensor");
        requested_num_splits = num_splits->item<int>();
    } else {
        // Use the host-side default directly instead of writing it to a device
        // tensor and reading it straight back with item().
        requested_num_splits = default_num_splits(topk, extra_topk);
        num_splits = torch::empty({1}, opts.dtype(torch::kInt32));
        num_splits->fill_(requested_num_splits);
    }
    TORCH_CHECK(requested_num_splits >= 1 && requested_num_splits <= 64, "DSA BF16 sparse decode requires 1 <= num_splits <= 64");
    // The split-KV reduce launcher only instantiates 2/4/8/16/32/64 splits and
    // merely prints a message for anything else, which would leave `out`
    // unreduced. Reject such values up front.
    TORCH_CHECK(
        (requested_num_splits & (requested_num_splits - 1)) == 0,
        "DSA BF16 sparse decode requires num_splits to be 1 or a power of two (2, 4, 8, 16, 32, 64), got ",
        requested_num_splits);

    // ---- Kernel parameters ----
    Flash_fwd_mla_params_dsa params;
    std::memset(&params, 0, sizeof(params));
    params.layout = 1;
    params.b = b;
    params.h = h_kv;
    params.h_k = h_kv;
    params.h_h_k_ratio = 1;
    params.mtp = 1;
    params.ngroups = h_q / h_kv;
    params.topk = topk;
    params.extra_topk = has_extra ? extra_topk : 0;
    params.d = d_qk;
    params.d_v = d_v;
    params.scale_softmax = sm_scale;
    params.scale_softmax_log2 = sm_scale * LOG_2_E;
    params.topk_length = ku::get_optional_tensor_ptr<int>(topk_length);
    params.extra_topk_length = ku::get_optional_tensor_ptr<int>(extra_topk_length);
    params.attn_sink = ku::get_optional_tensor_ptr<float>(attn_sink);

    // K and V point at the same cache tensor (and likewise for the extra cache).
    params.q_ptr = q.data_ptr();
    params.k_ptr = kv.data_ptr();
    params.v_ptr = kv.data_ptr();
    params.extra_k_ptr = has_extra ? extra_kv->data_ptr() : nullptr;
    params.extra_v_ptr = has_extra ? extra_kv->data_ptr() : nullptr;
    params.o_ptr = out.data_ptr();
    params.sparse_indices = reinterpret_cast<int*>(indices_for_dsa.data_ptr());
    params.extra_sparse_indices = has_extra ? reinterpret_cast<int*>(extra_indices_for_dsa.data_ptr()) : nullptr;
    params.softmax_lse_ptr = lse.data_ptr<float>();
    params.scores_max_ptr = scores_max.data_ptr<float>();
    params.scores_sum_ptr = scores_sum.data_ptr<float>();
    params.page_block_size = page_block_size;
    params.extra_page_block_size = has_extra ? extra_kv->size(1) : 0;
    params.is_causal = false;

    params.q_batch_stride = int64_stride_to_int(q.stride(0));
    params.q_token_stride = int64_stride_to_int(q.stride(1));
    // NOTE(review): q_row_stride and q_head_stride both take q.stride(2) —
    // presumably intentional because heads act as rows here; confirm.
    params.q_row_stride = int64_stride_to_int(q.stride(2));
    params.q_head_stride = int64_stride_to_int(q.stride(2));
    params.k_batch_stride = int64_stride_to_int(kv.stride(0));
    params.k_row_stride = int64_stride_to_int(kv.stride(1));
    params.k_head_stride = int64_stride_to_int(kv.stride(2));
    params.v_batch_stride = params.k_batch_stride;
    params.v_row_stride = params.k_row_stride;
    params.v_head_stride = params.k_head_stride;
    params.extra_k_batch_stride = has_extra ? int64_stride_to_int(extra_kv->stride(0)) : 0;
    params.extra_k_row_stride = has_extra ? int64_stride_to_int(extra_kv->stride(1)) : 0;
    params.extra_v_batch_stride = params.extra_k_batch_stride;
    params.extra_v_row_stride = params.extra_k_row_stride;
    params.sparse_indices_batch_stride = int64_stride_to_int(indices_for_dsa.stride(0));
    params.sparse_indices_row_stride = int64_stride_to_int(indices_for_dsa.stride(1));
    params.sparse_indices_head_stride = int64_stride_to_int(indices_for_dsa.stride(2));
    params.sparse_indices_topk_stride = int64_stride_to_int(indices_for_dsa.stride(3));
    params.extra_sparse_indices_batch_stride = has_extra ? int64_stride_to_int(extra_indices_for_dsa.stride(0)) : 0;
    params.extra_sparse_indices_row_stride = has_extra ? int64_stride_to_int(extra_indices_for_dsa.stride(1)) : 0;
    params.extra_sparse_indices_head_stride = has_extra ? int64_stride_to_int(extra_indices_for_dsa.stride(2)) : 0;
    params.extra_sparse_indices_topk_stride = has_extra ? int64_stride_to_int(extra_indices_for_dsa.stride(3)) : 0;
    params.o_batch_stride = int64_stride_to_int(out.stride(0));
    params.o_row_stride = int64_stride_to_int(out.stride(1));
    params.o_head_stride = int64_stride_to_int(out.stride(2));

    params.seqlen_q = s_q * params.ngroups;
    params.seqlen_k = kv.size(0) * kv.size(1);
    params.max_seqlen = s_q;
    params.is_bf16 = true;
    params.is_e4m3 = false;
    params.is_int8 = false;
    params.cu_count = arch.num_sms;
    params.seqlenq_ngroups_swapped = true;
    params.is_seqlens_k_cumulative = false;
    params.splitkv_use_fp32_as_accum = false;
    params.num_splits = requested_num_splits;

    // Per-split span over the combined (main + extra) index list. When splitting,
    // the span is at least 64 and rounded up to a multiple of 64.
    params.partition_size = topk + params.extra_topk;
    if (params.num_splits > 1) {
        params.partition_size = std::max(64, (params.partition_size + params.num_splits - 1) / params.num_splits);
        params.partition_size = ((params.partition_size + 63) / 64) * 64;
    }

    // ---- Split-KV accumulators ----
    at::Tensor out_accum;
    at::Tensor lse_accum;
    if (params.num_splits > 1) {
        lse_accum = torch::empty({params.num_splits, b, h_kv, params.seqlen_q}, opts.dtype(at::kFloat));
        out_accum = torch::empty({params.num_splits, b, s_q, h_q, d_v}, opts);
        // NOTE(review): from here on softmax_lse_ptr refers to lse_accum and no
        // pointer into `lse` is handed to any kernel by this function, so the
        // returned `lse` appears to stay uninitialized in split-KV mode —
        // confirm whether the reduce step is expected to produce the final LSE.
        params.softmax_lse_ptr = lse_accum.data_ptr<float>();
        params.oaccum_ptr = out_accum.data_ptr();
    }

    // ---- Launch ----
    hipStream_t stream = reinterpret_cast<hipStream_t>(at::cuda::getCurrentCUDAStream().stream());
    if (d_qk == 512) {
        gfx93::fwd::dsa_mls::run_dsa_prefill_nopage_64_dispatch<BFloat16, 512, 512>(params, stream);
    } else {
        gfx93::fwd::dsa_mls::run_dsa_prefill_nopage_64_dispatch<BFloat16, 576, 512>(params, stream);
    }
    return {out, lse, tile_scheduler_metadata, num_splits};
}
} // namespace gfx93::decode::sparse_bf16_dsa
#pragma once
#include <ATen/core/Tensor.h>
#include <optional>
#include <tuple>
namespace gfx93::decode::sparse_bf16_dsa {
// Entry point for BF16 DSA sparse decode.
//
// Returns a tuple of (output tensor, log-sum-exp tensor, tile_scheduler_metadata,
// num_splits). `tile_scheduler_metadata` and `num_splits` are taken by non-const
// reference, so the callee is able to assign to them.
// NOTE(review): tensor shapes, dtypes and the supported-architecture gate are
// enforced in the implementation file — consult it for the exact contract.
std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>>
run(
const at::Tensor& q,
const at::Tensor& kv,
const at::Tensor& indices,
const std::optional<at::Tensor>& topk_length,
const std::optional<at::Tensor>& attn_sink,
std::optional<at::Tensor>& tile_scheduler_metadata,
std::optional<at::Tensor>& num_splits,
const std::optional<at::Tensor>& extra_kv,
const std::optional<at::Tensor>& extra_indices,
const std::optional<at::Tensor>& extra_topk_length,
int d_v,
float sm_scale);
} // namespace gfx93::decode::sparse_bf16_dsa
......@@ -9,8 +9,68 @@
#include "legacy/include/static_switch.h"
#include "legacy/src/flash_fwd_b16_mla.h"
#include "legacy/src/flash_fwd_reduce.h"
namespace gfx93::fwd::dsa_mls {
/// Launches the split-KV reduce kernel for the DSA MLA path.
///
/// Does nothing when `params.num_splits <= 1`. Only split counts of
/// 2, 4, 8, 16, 32 and 64 have a kernel instantiation; any other count prints a
/// diagnostic and returns WITHOUT launching, so callers must pre-validate the
/// split count if they rely on the reduced output.
///
/// @tparam Kernel_traits  Supplies Element, SplitkvAccumType and kHeadDimV (must be 512).
/// @tparam Tail           Forwarded unchanged to the reduce kernel template.
/// @param params          Forward-pass params; the fields copied below are read.
/// @param stream          Stream the reduce kernel is enqueued on.
template<typename Kernel_traits, const bool Tail, typename Params>
void run_dsa_mla_splitkv_reduce(Params& params, hipStream_t stream) {
    static_assert(Kernel_traits::kHeadDimV == 512,
                  "run_dsa_mla_splitkv_reduce only supports hdimv == 512");
    using Element = typename Kernel_traits::Element;
    using SplitkvAccumType = typename Kernel_traits::SplitkvAccumType;

    // A single split has nothing to combine.
    if (params.num_splits <= 1) {
        return;
    }

    constexpr int MAX_NUM_SPLITS = 64;
    if (params.num_splits > MAX_NUM_SPLITS) {
        printf("\x1b[31mnum_splits %d is larger than limit %d, and thus won't execute the kernel\033[0m\n",
               params.num_splits, MAX_NUM_SPLITS);
        return;
    }

    // Value-initialize so that any member not assigned below reaches the kernel
    // as zero instead of an indeterminate value.
    Flash_fwd_mla_reduce_params reduce_params{};
    reduce_params.softmax_lse_ptr = params.softmax_lse_ptr;
    reduce_params.oaccum_ptr = params.oaccum_ptr;
    reduce_params.o_ptr = params.o_ptr;
    reduce_params.cu_seqlens_k = params.cu_seqlens_k;
    reduce_params.num_splits = params.num_splits;
    reduce_params.partition_size = params.partition_size;
    reduce_params.h = params.h;
    reduce_params.ngroups = params.ngroups;
    reduce_params.seqlen_q = params.seqlen_q;
    reduce_params.layout = params.layout;
    reduce_params.topk_length = params.topk_length;
    reduce_params.attn_sink = params.attn_sink;
    reduce_params.extra_topk_length = params.extra_topk_length;
    reduce_params.topk = params.topk;
    reduce_params.extra_topk = params.extra_topk;

    dim3 block(256);
    dim3 grid(params.b * params.h * params.seqlen_q, 4);

    // The split count is a template argument of the kernel, so dispatch on it.
    switch (params.num_splits) {
        case 2:
            ::flash_mla_splitkv_reduce_kernel<SplitkvAccumType, Element, 2, true, Tail, Kernel_traits::kHeadDimV>
                <<<grid, block, 0, stream>>>(reduce_params);
            break;
        case 4:
            ::flash_mla_splitkv_reduce_kernel<SplitkvAccumType, Element, 4, true, Tail, Kernel_traits::kHeadDimV>
                <<<grid, block, 0, stream>>>(reduce_params);
            break;
        case 8:
            ::flash_mla_splitkv_reduce_kernel<SplitkvAccumType, Element, 8, true, Tail, Kernel_traits::kHeadDimV>
                <<<grid, block, 0, stream>>>(reduce_params);
            break;
        case 16:
            ::flash_mla_splitkv_reduce_kernel<SplitkvAccumType, Element, 16, true, Tail, Kernel_traits::kHeadDimV>
                <<<grid, block, 0, stream>>>(reduce_params);
            break;
        case 32:
            ::flash_mla_splitkv_reduce_kernel<SplitkvAccumType, Element, 32, true, Tail, Kernel_traits::kHeadDimV>
                <<<grid, block, 0, stream>>>(reduce_params);
            break;
        case 64:
            ::flash_mla_splitkv_reduce_kernel<SplitkvAccumType, Element, 64, true, Tail, Kernel_traits::kHeadDimV>
                <<<grid, block, 0, stream>>>(reduce_params);
            break;
        default:
            printf("\x1b[31mnum_splits %d is not supported yet, and thus won't execute the kernel\033[0m\n",
                   params.num_splits);
            break;
    }
}
template<typename T, int Headdim, int HeaddimV>
void run_dsa_prefill_nopage_64_dispatch(Flash_fwd_mla_params_dsa& params, hipStream_t stream) {
constexpr int kBlockM = 64;
......@@ -34,21 +94,49 @@ void run_dsa_prefill_nopage_64_dispatch(Flash_fwd_mla_params_dsa& params, hipStr
constexpr bool Is_dropout = false;
constexpr bool IsEvenMNConst = false;
BOOL_SWITCH(params.mtp > 1, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if (params.topk == 2048) {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64<
Kernel_traits, true, Is_dropout, false, Is_causal,
IsEvenMNConst, true, false, Is_MTP, 0, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
} else {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64_topk1024<
Kernel_traits, true, Is_dropout, false, Is_causal,
IsEvenMNConst, true, false, Is_MTP, 0, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
}
constexpr int REUSE_KV = 1;
const bool has_extra = params.extra_sparse_indices != nullptr && params.extra_topk > 0;
if (params.num_splits == 1) {
BOOL_SWITCH(params.mtp > 1, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(has_extra, Has_extra, [&] {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64<
Kernel_traits, true, Is_dropout, false, Is_causal,
IsEvenMNConst, true, false, Is_MTP, 0, Has_extra, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
});
});
});
} else if (params.num_splits != 0) {
dimGrid.y = params.num_splits;
BOOL_SWITCH(params.mtp > 1, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(has_extra, Has_extra, [&] {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64_splitkv<
Kernel_traits, true, Is_dropout, false, Is_causal,
IsEvenMNConst, true, false, Is_MTP, 0, Has_extra, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
});
});
});
run_dsa_mla_splitkv_reduce<Kernel_traits, false>(params, stream);
} else {
BOOL_SWITCH(params.mtp > 1, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if (params.topk == 2048) {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64<
Kernel_traits, true, Is_dropout, false, Is_causal,
IsEvenMNConst, true, false, Is_MTP, 0, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
} else {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64_topk1024<
Kernel_traits, true, Is_dropout, false, Is_causal,
IsEvenMNConst, true, false, Is_MTP, 0, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
}
});
});
});
}
}
} // namespace gfx93::fwd::dsa_mls
......@@ -450,6 +450,7 @@ struct Flash_fwd_mla_reduce_params {
int num_splits;
int partition_size;
int h;
int ngroups;
int seqlen_q;
int layout;
float* attn_sink;
......
......@@ -2503,25 +2503,9 @@ __forceinline__ __device__ void flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_n
/**************************************************************************************************************************************/
constexpr bool Is_Interleave = true;
int lane_id = threadIdx.x & 63;
#pragma unroll
for (int mi = 0; mi < WARP_M / 16; ++mi) {
if (real_topk == 0) {
lse[mi].f32[0] = __builtin_inff();
}
}
if (params.softmax_lse_ptr != nullptr) {
prefill_mla_epilogue_store_lse<WARP_M, Is_even_MN, false/*SplitD*/, Is_Interleave, ElementAccum, vec_Accum<ElementAccum>, 1/* M_MMAC_COUNT */>(lse, params.softmax_lse_ptr, row_offset_lse, warp_id, lane_id, 0, actual_seqlen_q - m_block * kBlockM);
}
if (params.scores_max_ptr != nullptr) {
float* scores_max_ptr = params.scores_max_ptr + row_offset_lse;
const bool write_scores_max = (lane_id >> 4) == 0;
if (write_scores_max) {
const int row = warp_id * WARP_M + (lane_id & 15);
if (row < actual_seqlen_q - m_block * kBlockM) {
scores_max_ptr[row] = scores_max[0].f32[0] * params.scale_softmax;
}
}
}
/**************************************************************************************************************************************/
Element* o_ptr = reinterpret_cast<Element *>(params.o_ptr) + row_offset_o;
prefill_mla_epilogue_store_output<kHeadDimV, kBlockM, kBlockK, WARP_M, Is_even_MN, Is_Interleave, false/*TcpSwizzle*/, Element, ElementAccum, 1/* M_MMAC_COUNT */>(o_ptr, acc_o, m_block, warp_id, lane_id, seqlen_o_stride, actual_seqlen_q);
......@@ -2539,7 +2523,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64_impl<Kernel_traits, Is_training, Is_dropout, Is_prefix, Is_causal, Is_even_MN, Is_even_K, Return_softmax, Is_MTP, Layout, false, Params>(params);
}
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_prefix, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, bool Is_MTP, int Layout, typename Params>
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_prefix, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, bool Is_MTP, int Layout, bool Has_extra, typename Params>
__global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64(const Params params) {
#if defined(__gfx938__)
using Element = typename Kernel_traits::Element;
......@@ -2576,7 +2560,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
const int warp_offset_in_seq_q = m_block * kBlockM + warp_id * WARP_M;
const int warp_seqq_limit = Is_even_MN ? 0: actual_seqlen_q - m_block * kBlockM;
const bool has_extra = params.extra_sparse_indices != nullptr && params.extra_topk > 0;
constexpr bool has_extra = Has_extra;
// Allocate LDS: Q and P share the same region, K and V share the same region.
// extern __shared__ Element smem[];
// int* index_lds = (int *)&(smem);
......@@ -2588,7 +2572,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
Element* k_lds = q_lds; // 16KB
Element* v_lds = q_lds;
int* index_lds = (int *)(q_lds + 8 * 1024);
int* extra_index_lds = (int *)(index_lds + 1024);
int* extra_index_lds = (int *)(index_lds + 256);
// int* sIndices = (int *)(q_lds + 8192);
// Compute the first and last block of the current task along the seqlenKV direction
......@@ -2601,7 +2585,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
int extra_topk_length = params.extra_topk_length ? params.extra_topk_length[bidb] : params.extra_topk;
int main_num_blocks = ceil_div(main_topk_length, kBlockN);
int extra_num_blocks = has_extra ? ceil_div(extra_topk_length, kBlockN) : 0;
int extra_num_blocks = Has_extra ? ceil_div(extra_topk_length, kBlockN) : 0;
// Compute the data strides
int seqlen_q_stride = params.q_head_stride;
......@@ -2619,12 +2603,13 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
+ bidb * params.sparse_indices_batch_stride
+ query_idx * params.sparse_indices_row_stride
+ (q_head_start / params.ngroups) * params.sparse_indices_head_stride;
int* extra_index_ptr = has_extra
? params.extra_sparse_indices
int* extra_index_ptr = nullptr;
if constexpr (Has_extra) {
extra_index_ptr = params.extra_sparse_indices
+ bidb * params.extra_sparse_indices_batch_stride
+ query_idx * params.extra_sparse_indices_row_stride
+ (q_head_start / params.ngroups) * params.extra_sparse_indices_head_stride
: nullptr;
+ (q_head_start / params.ngroups) * params.extra_sparse_indices_head_stride;
}
// const int block_table_idx = 0;
// const int block_table_offset = 0;
......@@ -2645,43 +2630,50 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
qv_ptr = prepare_for_matrix_load<kHeadDim>(reinterpret_cast<Element*>(params.qv_ptr) + row_offset_qv);
auto k_ptr = prepare_for_matrix_load<kHeadDim>(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k);
auto k_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k);
auto extra_k_ptr_buffer = has_extra ? tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.extra_k_ptr)) : k_ptr_buffer;
auto extra_k_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(Has_extra ? params.extra_k_ptr : params.k_ptr));
auto v_ptr = prepare_for_matrix_load<kHeadDimV>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v);
auto v_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v);
auto extra_v_ptr_buffer = has_extra ? tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.extra_v_ptr)) : v_ptr_buffer;
auto extra_v_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(Has_extra ? params.extra_v_ptr : params.v_ptr));
auto index_ptr_buffer = tcp_cache_swizzle_func<0, int>(reinterpret_cast<int*>(index_ptr));
int tid = threadIdx.x % 64;
if (has_extra) {
if constexpr (Has_extra) {
auto extra_index_ptr_buffer = tcp_cache_swizzle_func<0, int>(extra_index_ptr);
int lds_offset = __builtin_amdgcn_readfirstlane(warp_id * 64);
int g_offset_v = tid;
int g_offset_s = warp_id * 64;
inline_buffer_load_dword_lds(extra_index_lds, extra_index_ptr_buffer, lds_offset, g_offset_s, g_offset_v);
int lds_offset = __builtin_amdgcn_readfirstlane(warp_id * 4 * 64);
int g_offset_v = tid * 4;
int g_offset_s = warp_id * 256;
inline_buffer_load_dwordx4_lds(extra_index_lds, extra_index_ptr_buffer, lds_offset, g_offset_s, g_offset_v);
}
auto k_faker = reinterpret_cast<Element*>(params.k_ptr) + row_offset_k;
auto v_faker = reinterpret_cast<Element*>(params.v_ptr) + row_offset_v;
// Steps that apply the causal mask are handled separately from steps without a causal mask
constexpr int n_masking_steps = 1;
int lds_offset = __builtin_amdgcn_readfirstlane(warp_id * 4 * 64);
int g_offset_v = tid * 4;
int g_offset_s = warp_id * 256;
inline_buffer_load_dwordx4_lds(index_lds, index_ptr_buffer, lds_offset, g_offset_s, g_offset_v);
int lds_offset = __builtin_amdgcn_readfirstlane(Has_extra ? warp_id * 64 : warp_id * 4 * 64);
int g_offset_v = Has_extra ? tid : tid * 4;
int g_offset_s = Has_extra ? warp_id * 64 : warp_id * 256;
if constexpr (Has_extra) {
inline_buffer_load_dword_lds(index_lds, index_ptr_buffer, lds_offset, g_offset_s, g_offset_v);
} else {
inline_buffer_load_dwordx4_lds(index_lds, index_ptr_buffer, lds_offset, g_offset_s, g_offset_v);
}
flash::wait_buffer_data_arrived<true>(0);
flash::wait_all_warp_arrived();
#pragma unroll
for (int i = 0; i < 4; ++i) {
const int local_index = warp_id * 256 + tid * 4 + i;
const int local_index = Has_extra ? warp_id * 64 + tid : warp_id * 256 + tid * 4 + i;
if (local_index >= main_topk_length) {
index_lds[local_index] = main_topk_length > 0 ? index_lds[main_topk_length - 1] : -1;
}
}
if (has_extra) {
const int local_index = warp_id * 64 + tid;
if (local_index >= extra_topk_length) {
extra_index_lds[local_index] = extra_topk_length > 0 ? extra_index_lds[extra_topk_length - 1] : -1;
if constexpr (Has_extra) {
#pragma unroll
for (int i = 0; i < 4; ++i) {
const int local_index = warp_id * 256 + tid * 4 + i;
if (local_index >= extra_topk_length) {
extra_index_lds[local_index] = extra_topk_length > 0 ? extra_index_lds[extra_topk_length - 1] : -1;
}
}
}
flash::wait_all_warp_arrived();
......@@ -2862,7 +2854,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
#endif
}
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_prefix, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, bool Is_MTP, int Layout, typename Params>
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_prefix, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, bool Is_MTP, int Layout, bool Has_extra, typename Params>
__global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64_splitkv(const Params params) {
#if defined(__gfx938__)
using Element = typename Kernel_traits::Element;
......@@ -2899,7 +2891,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
const int warp_offset_in_seq_q = m_block * kBlockM + warp_id * WARP_M;
const int warp_seqq_limit = Is_even_MN ? 0: actual_seqlen_q - m_block * kBlockM;
const bool has_extra = params.extra_sparse_indices != nullptr && params.extra_topk > 0;
constexpr bool has_extra = Has_extra;
// Allocate LDS: Q and P share the same region, K and V share the same region.
// extern __shared__ Element smem[];
// int* index_lds = (int *)&(smem);
......@@ -2911,7 +2903,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
Element* k_lds = q_lds; // 16KB
Element* v_lds = q_lds;
int* index_lds = (int *)(q_lds + 8 * 1024);
int* extra_index_lds = (int *)(index_lds + 1024);
int* extra_index_lds = (int *)(index_lds + 256);
// int* sIndices = (int *)(q_lds + 8192);
// Compute the first and last block of the current task along the seqlenKV direction
int split_id = blockIdx.y;
......@@ -2925,7 +2917,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
int extra_topk_length = params.extra_topk_length ? params.extra_topk_length[bidb] : params.extra_topk;
int main_num_blocks = ceil_div(main_topk_length, kBlockN);
int extra_num_blocks = params.extra_sparse_indices ? ceil_div(extra_topk_length, kBlockN) : 0;
int extra_num_blocks = Has_extra ? ceil_div(extra_topk_length, kBlockN) : 0;
int total_num_blocks = main_num_blocks + extra_num_blocks;
int blocks_per_split = ceil_div(params.partition_size, kBlockN);
......@@ -2949,12 +2941,13 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
+ bidb * params.sparse_indices_batch_stride
+ query_idx * params.sparse_indices_row_stride
+ (q_head_start / params.ngroups) * params.sparse_indices_head_stride;
int* extra_index_ptr = has_extra
? params.extra_sparse_indices
int* extra_index_ptr = nullptr;
if constexpr (Has_extra) {
extra_index_ptr = params.extra_sparse_indices
+ bidb * params.extra_sparse_indices_batch_stride
+ query_idx * params.extra_sparse_indices_row_stride
+ (q_head_start / params.ngroups) * params.extra_sparse_indices_head_stride
: nullptr;
+ (q_head_start / params.ngroups) * params.extra_sparse_indices_head_stride;
}
// const int block_table_idx = 0;
// const int block_table_offset = 0;
......@@ -2963,6 +2956,11 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
row_offset_v = 0;
row_offset_o = bidb * int64_t(params.o_batch_stride) + query_idx * int64_t(params.o_row_stride) + q_head_start * int64_t(params.o_head_stride);
row_offset_lse = bidb * params.seqlen_q + m_block * kBlockM;
const int64_t split_block_base = int64_t(bidb) * params.h * params.seqlen_q + m_block * kBlockM;
const int64_t split_accum_row_offset = int64_t(split_id) * params.b * params.h * params.seqlen_q * kHeadDimV
+ split_block_base * kHeadDimV;
const int64_t split_lse_row_offset = int64_t(split_id) * params.b * params.h * params.seqlen_q
+ split_block_base;
// row_offset_k = bidb * int64_t(params.k_batch_stride) + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
// row_offset_v = bidb * int64_t(params.v_batch_stride) + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
......@@ -2975,42 +2973,49 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
qv_ptr = prepare_for_matrix_load<kHeadDim>(reinterpret_cast<Element*>(params.qv_ptr) + row_offset_qv);
auto k_ptr = prepare_for_matrix_load<kHeadDim>(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k);
auto k_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k);
auto extra_k_ptr_buffer = has_extra ? tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.extra_k_ptr)) : k_ptr_buffer;
auto extra_k_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(Has_extra ? params.extra_k_ptr : params.k_ptr));
auto v_ptr = prepare_for_matrix_load<kHeadDimV>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v);
auto v_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v);
auto extra_v_ptr_buffer = has_extra ? tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.extra_v_ptr)) : v_ptr_buffer;
auto extra_v_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(Has_extra ? params.extra_v_ptr : params.v_ptr));
auto index_ptr_buffer = tcp_cache_swizzle_func<0, int>(reinterpret_cast<int*>(index_ptr));
int tid = threadIdx.x % 64;
if (has_extra) {
if constexpr (Has_extra) {
auto extra_index_ptr_buffer = tcp_cache_swizzle_func<0, int>(extra_index_ptr);
int lds_offset = __builtin_amdgcn_readfirstlane(warp_id * 64);
int g_offset_v = tid;
int g_offset_s = warp_id * 64;
inline_buffer_load_dword_lds(extra_index_lds, extra_index_ptr_buffer, lds_offset, g_offset_s, g_offset_v);
int lds_offset = __builtin_amdgcn_readfirstlane(warp_id * 4 * 64);
int g_offset_v = tid * 4;
int g_offset_s = warp_id * 256;
inline_buffer_load_dwordx4_lds(extra_index_lds, extra_index_ptr_buffer, lds_offset, g_offset_s, g_offset_v);
}
auto k_faker = reinterpret_cast<Element*>(params.k_ptr) + row_offset_k;
auto v_faker = reinterpret_cast<Element*>(params.v_ptr) + row_offset_v;
// Steps that apply the causal mask are computed separately from steps that do not apply it
constexpr int n_masking_steps = 1;
int lds_offset = __builtin_amdgcn_readfirstlane(warp_id * 4 * 64);
int g_offset_v = tid * 4;
int g_offset_s = warp_id * 256;
inline_buffer_load_dwordx4_lds(index_lds, index_ptr_buffer, lds_offset, g_offset_s, g_offset_v);
int lds_offset = __builtin_amdgcn_readfirstlane(Has_extra ? warp_id * 64 : warp_id * 4 * 64);
int g_offset_v = Has_extra ? tid : tid * 4;
int g_offset_s = Has_extra ? warp_id * 64 : warp_id * 256;
if constexpr (Has_extra) {
inline_buffer_load_dword_lds(index_lds, index_ptr_buffer, lds_offset, g_offset_s, g_offset_v);
} else {
inline_buffer_load_dwordx4_lds(index_lds, index_ptr_buffer, lds_offset, g_offset_s, g_offset_v);
}
flash::wait_buffer_data_arrived<true>(0);
flash::wait_all_warp_arrived();
#pragma unroll
for (int i = 0; i < 4; ++i) {
const int local_index = warp_id * 256 + tid * 4 + i;
const int local_index = Has_extra ? warp_id * 64 + tid : warp_id * 256 + tid * 4 + i;
if (local_index >= main_topk_length) {
index_lds[local_index] = main_topk_length > 0 ? index_lds[main_topk_length - 1] : -1;
}
}
if (has_extra) {
const int local_index = warp_id * 64 + tid;
if (local_index >= extra_topk_length) {
extra_index_lds[local_index] = extra_topk_length > 0 ? extra_index_lds[extra_topk_length - 1] : -1;
if constexpr (Has_extra) {
#pragma unroll
for (int i = 0; i < 4; ++i) {
const int local_index = warp_id * 256 + tid * 4 + i;
if (local_index >= extra_topk_length) {
extra_index_lds[local_index] = extra_topk_length > 0 ? extra_index_lds[extra_topk_length - 1] : -1;
}
}
}
flash::wait_all_warp_arrived();
......@@ -3070,7 +3075,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
auto PV_GEMM_FUNC_IN_MASK = &pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_999<false, kHeadDim, kHeadDimV, kBlockM, 256, kBlockN, WARP_M, 256, STAGES, Element, ElementAccum, Is_even_MN>;
for (int logical_block = n_block_min; logical_block < n_block_max; ++logical_block) {
bool is_extra = logical_block >= main_num_blocks;
bool is_extra = Has_extra && logical_block >= main_num_blocks;
int rel_block = is_extra ? logical_block - main_num_blocks : logical_block;
int cur_topk_length = is_extra ? extra_topk_length : main_topk_length;
int* cur_index_lds = is_extra ? extra_index_lds : index_lds;
......@@ -3191,11 +3196,11 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
constexpr bool Is_Interleave = true;
int lane_id = threadIdx.x & 63;
if (params.softmax_lse_ptr != nullptr) {
prefill_mla_epilogue_store_lse<WARP_M, Is_even_MN, false/*SplitD*/, Is_Interleave, ElementAccum, vec_Accum<ElementAccum>, 1/* M_MMAC_COUNT */>(lse, params.softmax_lse_ptr, row_offset_lse, warp_id, lane_id, 0, actual_block_m);
prefill_mla_epilogue_store_lse<WARP_M, Is_even_MN, false/*SplitD*/, Is_Interleave, ElementAccum, vec_Accum<ElementAccum>, 1/* M_MMAC_COUNT */>(lse, params.softmax_lse_ptr, split_lse_row_offset, warp_id, lane_id, 0, actual_block_m);
}
/**************************************************************************************************************************************/
Element* o_ptr = reinterpret_cast<Element *>(params.oaccum_ptr) + row_offset_o;
prefill_mla_epilogue_store_output<kHeadDimV, kBlockM, kBlockK, WARP_M, Is_even_MN, Is_Interleave, false/*TcpSwizzle*/, Element, ElementAccum, 1/* M_MMAC_COUNT */>(o_ptr, acc_o, 0, warp_id, lane_id, seqlen_o_stride, actual_block_m);
Element* o_ptr = reinterpret_cast<Element *>(params.oaccum_ptr) + split_accum_row_offset;
prefill_mla_epilogue_store_output<kHeadDimV, kBlockM, kBlockK, WARP_M, Is_even_MN, Is_Interleave, false/*TcpSwizzle*/, Element, ElementAccum, 1/* M_MMAC_COUNT */>(o_ptr, acc_o, 0, warp_id, lane_id, kHeadDimV, actual_block_m);
}
#endif
}
......
......@@ -37,6 +37,7 @@ void run_mla_splitkv_reduce(Params &params, hipStream_t stream) {
reduce_params.num_splits = params.num_splits;
reduce_params.partition_size = params.partition_size;
reduce_params.h = params.h;
reduce_params.ngroups = params.ngroups;
reduce_params.seqlen_q = params.seqlen_q;
reduce_params.layout = params.layout;
reduce_params.topk_length = params.topk_length;
......@@ -556,11 +557,14 @@ void run_mla_fwd_dispatch_dsa_prefill_nopage_64(Flash_fwd_mla_params_dsa &params
dimGrid.z = params.b;
constexpr bool IsEvenMNConst = false;
const bool has_extra = params.extra_sparse_indices != nullptr && params.extra_topk > 0;
if(params.num_splits == 1){
BOOL_SWITCH(params.mtp > 1/* is_mtp */, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64<Kernel_traits, true/*Is_training*/, Is_dropout, false/* Is_prefix | flashmla */,Is_causal, IsEvenMNConst, /*Is_even_K*/true, /*Return_softmax*/false, Is_MTP, 0, Flash_fwd_mla_params_dsa>
BOOL_SWITCH(has_extra, Has_extra, [&] {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64<Kernel_traits, true/*Is_training*/, Is_dropout, false/* Is_prefix | flashmla */,Is_causal, IsEvenMNConst, /*Is_even_K*/true, /*Return_softmax*/false, Is_MTP, 0, Has_extra, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
});
});
});
}
......@@ -570,8 +574,10 @@ void run_mla_fwd_dispatch_dsa_prefill_nopage_64(Flash_fwd_mla_params_dsa &params
dimGrid.z = params.b;
BOOL_SWITCH(params.mtp > 1/* is_mtp */, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64_splitkv<Kernel_traits, true/*Is_training*/, Is_dropout, false/* Is_prefix | flashmla */,Is_causal, IsEvenMNConst, /*Is_even_K*/true, /*Return_softmax*/false, Is_MTP, 0, Flash_fwd_mla_params_dsa>
BOOL_SWITCH(has_extra, Has_extra, [&] {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64_splitkv<Kernel_traits, true/*Is_training*/, Is_dropout, false/* Is_prefix | flashmla */,Is_causal, IsEvenMNConst, /*Is_even_K*/true, /*Return_softmax*/false, Is_MTP, 0, Has_extra, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
});
});
});
run_mla_splitkv_reduce<Kernel_traits, false/*Tail*/>(params, stream);
......
#pragma once
#include "numeric_types.h"
#include "splitkv.h"
#include "intrinsic.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -29,7 +30,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
// compute partition_size when fix num_splits
int partition_size = params.partition_size > MLA_MAX_SPLITS ? splitkv_get_partitionsize_of_fix_numsplits(actual_seqlen_k, params.num_splits): params.partition_size;
const int true_num_splits = Tail ? max(1, floor_div(actual_seqlen_k, partition_size)): ceil_div(actual_seqlen_k, partition_size);
const int true_num_splits = Tail ? max(1, flash::floor_div(actual_seqlen_k, partition_size)): flash::ceil_div(actual_seqlen_k, partition_size);
// const int true_num_splits = num_splits;
bool exceed_split = (tx >= true_num_splits); // process boundary
......@@ -40,7 +41,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
float s_max_tmp = s_max_load_ori;
#pragma unroll
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
s_max_tmp = max(s_max_tmp, __shfl_xor_tmp(s_max_tmp, step));
s_max_tmp = max(s_max_tmp, flash::__shfl_xor_tmp(s_max_tmp, step));
}
// compute rescale coefficient for max (numerator)
float s_max_ratio = __expf(s_max_load_ori - s_max_tmp);
......@@ -50,7 +51,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
float s_sum_tmp = s_sum_load_ori * s_max_ratio;
#pragma unroll
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
s_sum_tmp = s_sum_tmp + __shfl_xor_tmp(s_sum_tmp, step);
s_sum_tmp = s_sum_tmp + flash::__shfl_xor_tmp(s_sum_tmp, step);
}
// max-rescale coefficient x sum-rescale coefficient
......@@ -81,18 +82,18 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
// read ultimate scale value for current split
vec2_Element<accumType> load = *(vec2_Element<accumType>*)(oaccum_ptr + tx_offset + t);
// half -> float32, reduce precision loss
float a_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load[0]): 0.f;
float b_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load[1]): 0.f;
float a_f32 = within_splits? flash::splitkv_upcast_to_f32<accumType>(load[0]): 0.f;
float b_f32 = within_splits? flash::splitkv_upcast_to_f32<accumType>(load[1]): 0.f;
// do rescale and sum
tx_accum[t] = __llvm_fma_f32(a_f32, s_scale, tx_accum[t]);
tx_accum[t + 1] = __llvm_fma_f32(b_f32, s_scale, tx_accum[t + 1]);
tx_accum[t] = flash::__llvm_fma_f32(a_f32, s_scale, tx_accum[t]);
tx_accum[t + 1] = flash::__llvm_fma_f32(b_f32, s_scale, tx_accum[t + 1]);
} else if constexpr (kHeadDim == 64) {
// read ultimate scale value for current split
accumType load = *(accumType*)(oaccum_ptr + tx_offset + t);
// half -> float32, reduce precision loss
float load_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load): 0.f;
float load_f32 = within_splits? flash::splitkv_upcast_to_f32<accumType>(load): 0.f;
// do rescale and sum
tx_accum[t] = __llvm_fma_f32(load_f32, s_scale, tx_accum[t]);
tx_accum[t] = flash::__llvm_fma_f32(load_f32, s_scale, tx_accum[t]);
}
}
// switch to next split
......@@ -103,14 +104,14 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
if constexpr (kHeadDim % 128 == 0) {
vec2_Element<reduceType> accum_result;
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
accum_result = DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]);
accum_result = flash::DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]);
#else
accum_result[0] = DownCast<float, reduceType, true>(tx_accum[t]);
accum_result[1] = DownCast<float, reduceType, true>(tx_accum[t + 1]); // here, v_cvt_pkrtz can be used
accum_result[0] = flash::DownCast<float, reduceType, true>(tx_accum[t]);
accum_result[1] = flash::DownCast<float, reduceType, true>(tx_accum[t + 1]); // here, v_cvt_pkrtz can be used
#endif
*(vec2_Element<reduceType>*)(output_ptr + t) = accum_result;
} else if constexpr (kHeadDim == 64) {
reduceType accum_result = DownCast<float, reduceType, false>(tx_accum[t]);
reduceType accum_result = flash::DownCast<float, reduceType, false>(tx_accum[t]);
output_ptr[t] = accum_result;
}
}
......@@ -146,7 +147,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
}
// compute partition_size when fix num_splits
int partition_size = params.partition_size > MLA_MAX_SPLITS ? splitkv_get_partitionsize_of_fix_numsplits(actual_seqlen_k, params.num_splits): params.partition_size;
const int true_num_splits = Tail ? max(1, floor_div(actual_seqlen_k, partition_size)): ceil_div(actual_seqlen_k, partition_size);
const int true_num_splits = Tail ? max(1, flash::floor_div(actual_seqlen_k, partition_size)): flash::ceil_div(actual_seqlen_k, partition_size);
// const int true_num_splits = num_splits;
bool exceed_split = (tx >= true_num_splits); // process boundary
......@@ -157,7 +158,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
float s_max_tmp = s_max_load_ori;
#pragma unroll
for (int step = 64 >> 1; step > 0; step = (step >> 1)) {
s_max_tmp = max(s_max_tmp, __shfl_xor_tmp(s_max_tmp, step));
s_max_tmp = max(s_max_tmp, flash::__shfl_xor_tmp(s_max_tmp, step));
}
// for multiple waves, store the reduced max value to lds individually, and recompute max across multiple waves
int wave_id = (tx >> 6);
......@@ -181,7 +182,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
float s_sum_tmp = s_sum_load_ori * s_max_ratio;
#pragma unroll
for (int step = 64 >> 1; step > 0; step = (step >> 1)) {
s_sum_tmp = s_sum_tmp + __shfl_xor_tmp(s_sum_tmp, step);
s_sum_tmp = s_sum_tmp + flash::__shfl_xor_tmp(s_sum_tmp, step);
}
// for multiple wave, store the reduced sum value to lds individually, and recompute sum across multiple waves
lds[LDS_ACCUM + wave_id] = s_sum_tmp;
......@@ -230,18 +231,18 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
// read 2 halfs from current split of this threads
vec2_Element<accumType> load = *(vec2_Element<accumType>*)(oaccum_ptr + tx_offset + t);
// half -> float32, reduce precision loss
float a_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load[0]): 0.f;
float b_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load[1]): 0.f;
float a_f32 = within_splits? flash::splitkv_upcast_to_f32<accumType>(load[0]): 0.f;
float b_f32 = within_splits? flash::splitkv_upcast_to_f32<accumType>(load[1]): 0.f;
// do rescale and sum
tx_accum[t] = __llvm_fma_f32(a_f32, s_scale, tx_accum[t]);
tx_accum[t + 1] = __llvm_fma_f32(b_f32, s_scale, tx_accum[t + 1]);
tx_accum[t] = flash::__llvm_fma_f32(a_f32, s_scale, tx_accum[t]);
tx_accum[t + 1] = flash::__llvm_fma_f32(b_f32, s_scale, tx_accum[t + 1]);
} else if constexpr (kHeadDim == 64) {
// read 1 half from current split of this threads
accumType load = *(accumType*)(oaccum_ptr + tx_offset + t);
// half -> float32, reduce precision loss
float load_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load): 0.f;
float load_f32 = within_splits? flash::splitkv_upcast_to_f32<accumType>(load): 0.f;
// do rescale and sum
tx_accum[t] = __llvm_fma_f32(load_f32, s_scale, tx_accum[t]);
tx_accum[t] = flash::__llvm_fma_f32(load_f32, s_scale, tx_accum[t]);
}
}
// switch to next split
......@@ -290,15 +291,15 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
tx_accum[t + 1] = lds[tx * tx_float_count + t + 1];
vec2_Element<reduceType> accum_result;
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
accum_result = DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]);
accum_result = flash::DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]);
#else
accum_result[0] = DownCast<float, reduceType, true>(tx_accum[t]);
accum_result[1] = DownCast<float, reduceType, true>(tx_accum[t + 1]); // here, v_cvt_pkrtz can be used
accum_result[0] = flash::DownCast<float, reduceType, true>(tx_accum[t]);
accum_result[1] = flash::DownCast<float, reduceType, true>(tx_accum[t + 1]); // here, v_cvt_pkrtz can be used
#endif
*(vec2_Element<reduceType>*)(output_ptr + t) = accum_result;
} else if constexpr (kHeadDim == 64) {
tx_accum[t] = lds[tx * tx_float_count + t];
reduceType accum_result = DownCast<float, reduceType, false>(tx_accum[t]);
reduceType accum_result = flash::DownCast<float, reduceType, false>(tx_accum[t]);
*(reduceType*)(output_ptr + t) = accum_result;
}
}
......@@ -341,7 +342,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel
// compute partition_size when fix num_splits
int partition_size = splitkv_get_partitionsize_of_fix_numsplits(actual_seqlen_k, params.num_splits);
const int true_num_splits = Tail ? max(1, floor_div(actual_seqlen_k, partition_size)): ceil_div(actual_seqlen_k, partition_size);
const int true_num_splits = Tail ? max(1, flash::floor_div(actual_seqlen_k, partition_size)): flash::ceil_div(actual_seqlen_k, partition_size);
// const int true_num_splits = num_splits;
bool exceed_split = (tx >= true_num_splits); // process boundary
......@@ -399,13 +400,13 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel
float lse_max_local = lse_local;
#pragma unroll
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
lse_max_local = max(lse_max_local, __shfl_xor_tmp(lse_max_local, step));
lse_max_local = max(lse_max_local, flash::__shfl_xor_tmp(lse_max_local, step));
}
// reduce sum lse
float lse_local_logsum = __expf(lse_local - lse_max_local);
#pragma unroll
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
lse_local_logsum = lse_local_logsum + __shfl_xor_tmp(lse_local_logsum, step);
lse_local_logsum = lse_local_logsum + flash::__shfl_xor_tmp(lse_local_logsum, step);
}
lse_local_logsum = __logf(lse_local_logsum) + lse_max_local;
......@@ -427,16 +428,16 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel
for (int t = 0; t < tx_float_count; t += 2) {
if constexpr (kHeadDim % 128 == 0) {
// half -> float32, reduce precision loss
float a_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load_vec[i][t >> 1][0]): 0.f;
float b_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load_vec[i][t >> 1][1]): 0.f;
float a_f32 = within_splits? flash::splitkv_upcast_to_f32<accumType>(load_vec[i][t >> 1][0]): 0.f;
float b_f32 = within_splits? flash::splitkv_upcast_to_f32<accumType>(load_vec[i][t >> 1][1]): 0.f;
// do rescale and sum
tx_accum[t] = __llvm_fma_f32(a_f32, s_scale, tx_accum[t]);
tx_accum[t + 1] = __llvm_fma_f32(b_f32, s_scale, tx_accum[t + 1]);
tx_accum[t] = flash::__llvm_fma_f32(a_f32, s_scale, tx_accum[t]);
tx_accum[t + 1] = flash::__llvm_fma_f32(b_f32, s_scale, tx_accum[t + 1]);
} else if constexpr (kHeadDim == 64) {
// half -> float32, reduce precision loss
float load_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load[i][t >> 1]): 0.f;
float load_f32 = within_splits? flash::splitkv_upcast_to_f32<accumType>(load[i][t >> 1]): 0.f;
// do rescale and sum
tx_accum[t] = __llvm_fma_f32(load_f32, s_scale, tx_accum[t]);
tx_accum[t] = flash::__llvm_fma_f32(load_f32, s_scale, tx_accum[t]);
}
}
}
......@@ -446,14 +447,14 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel
if constexpr (kHeadDim % 128 == 0) {
vec2_Element<reduceType> accum_result;
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
accum_result = DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]);
accum_result = flash::DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]);
#else
accum_result[0] = DownCast<float, reduceType, true>(tx_accum[t]);
accum_result[1] = DownCast<float, reduceType, true>(tx_accum[t + 1]); // here, v_cvt_pkrtz can be used
accum_result[0] = flash::DownCast<float, reduceType, true>(tx_accum[t]);
accum_result[1] = flash::DownCast<float, reduceType, true>(tx_accum[t + 1]); // here, v_cvt_pkrtz can be used
#endif
*(vec2_Element<reduceType>*)(output_ptr + t) = accum_result;
} else if constexpr (kHeadDim == 64) {
reduceType accum_result = DownCast<float, reduceType, false>(tx_accum[t]);
reduceType accum_result = flash::DownCast<float, reduceType, false>(tx_accum[t]);
output_ptr[t] = accum_result;
}
}
......@@ -501,15 +502,15 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
// int main_len = params.topk_length ? params.topk_length[row] : params.topk;
// int extra_len = params.extra_topk_length ? params.extra_topk_length[row] : params.extra_topk;
// int actual_seqlen_k = ceil_div(main_len, 64) * 64 + ceil_div(extra_len, 64) * 64;
// int actual_seqlen_k = flash::ceil_div(main_len, 64) * 64 + flash::ceil_div(extra_len, 64) * 64;
int row = block_x / 64;
int main_len = params.topk_length ? params.topk_length[row] : params.topk;
int extra_len = params.extra_topk_length ? params.extra_topk_length[row] : params.extra_topk;
int topk_length_row = h == 1 ? bidb : block_x / 64;
int main_len = params.topk_length ? params.topk_length[topk_length_row] : params.topk;
int extra_len = params.extra_topk_length ? params.extra_topk_length[topk_length_row] : params.extra_topk;
int total_blocks = ceil_div(main_len, 64) + ceil_div(extra_len, 64);
int blocks_per_split = ceil_div(params.partition_size, 64);
int true_num_splits = ceil_div(total_blocks, blocks_per_split);
int total_blocks = flash::ceil_div(main_len, 64) + flash::ceil_div(extra_len, 64);
int blocks_per_split = flash::ceil_div(params.partition_size, 64);
int true_num_splits = flash::ceil_div(total_blocks, blocks_per_split);
// for flashmla, 512 elements are engaged to 4 blocks
// within each block, num_splits / WARM_NUM load transactions are engaged to each wave
......@@ -539,7 +540,7 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
// compute partition_size when fix num_splits
// int partition_size = params_partition_size > MLA_MAX_SPLITS ? splitkv_get_partitionsize_of_fix_numsplits(actual_seqlen_k, params.num_splits): params_partition_size;
// const int true_num_splits = Tail ? max(1, floor_div(actual_seqlen_k, partition_size)): ceil_div(actual_seqlen_k, partition_size);
// const int true_num_splits = Tail ? max(1, flash::floor_div(actual_seqlen_k, partition_size)): flash::ceil_div(actual_seqlen_k, partition_size);
// const int true_num_splits = num_splits;
bool exceed_split = (tx >= true_num_splits); // process boundary
......@@ -552,20 +553,20 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
float lse_max_local = lse_local;
#pragma unroll
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
lse_max_local = max(lse_max_local, __shfl_xor_tmp(lse_max_local, step));
lse_max_local = max(lse_max_local, flash::__shfl_xor_tmp(lse_max_local, step));
}
// reduce sum lse
float lse_local_logsum = __expf(lse_local - lse_max_local);
#pragma unroll
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
lse_local_logsum = lse_local_logsum + __shfl_xor_tmp(lse_local_logsum, step);
lse_local_logsum = lse_local_logsum + flash::__shfl_xor_tmp(lse_local_logsum, step);
}
lse_local_logsum = __logf(lse_local_logsum) + lse_max_local;
float attn_sink_o_scale = 1.0f;
if (params.attn_sink != nullptr) {
// The current reduce kernel's block_x is flattened in (b, h, s) order, so bidh is the head id.
float rAttn_sink = params.attn_sink[block_x % 64];
int attn_sink_idx = h == 1 ? in_batch_offset % params.ngroups : block_x % 64;
float rAttn_sink = params.attn_sink[attn_sink_idx];
if (rAttn_sink == INFINITY) {
attn_sink_o_scale = 0.0f;
......@@ -588,11 +589,11 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
#pragma unroll
for (int t = 0; t < tx_float_count; t += 2) {
// half -> float32, reduce precision loss
float a_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load[i >> 2][t >> 1][0]): 0.f;
float b_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load[i >> 2][t >> 1][1]): 0.f;
float a_f32 = within_splits? flash::splitkv_upcast_to_f32<accumType>(load[i >> 2][t >> 1][0]): 0.f;
float b_f32 = within_splits? flash::splitkv_upcast_to_f32<accumType>(load[i >> 2][t >> 1][1]): 0.f;
// do rescale and sum
tx_accum[t] = __llvm_fma_f32(a_f32, s_scale, tx_accum[t]);
tx_accum[t + 1] = __llvm_fma_f32(b_f32, s_scale, tx_accum[t + 1]);
tx_accum[t] = flash::__llvm_fma_f32(a_f32, s_scale, tx_accum[t]);
tx_accum[t + 1] = flash::__llvm_fma_f32(b_f32, s_scale, tx_accum[t + 1]);
}
}
// reduce across 4 waves
......@@ -616,13 +617,13 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
// cvt
vec2_Element<reduceType> accum_result;
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
accum_result = DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]);
accum_result = flash::DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]);
#else
accum_result[0] = DownCast<float, reduceType, true>(tx_accum[t]);
accum_result[1] = DownCast<float, reduceType, true>(tx_accum[t + 1]);
accum_result[0] = flash::DownCast<float, reduceType, true>(tx_accum[t]);
accum_result[1] = flash::DownCast<float, reduceType, true>(tx_accum[t + 1]);
#endif
// storation
*(vec2_Element<reduceType>*)(output_ptr + t) = accum_result;
}
}
}
\ No newline at end of file
}
......@@ -151,7 +151,10 @@ def flash_mla_with_kvcache(
if topk is not None:
# Sparse attention
assert not causal, "causal must be False when sparse attention is enabled"
assert is_fp8_kvcache, "is_fp8_kvcache must be True when sparse attention is enabled"
if not is_fp8_kvcache:
assert k_cache.dtype == torch.bfloat16, "BF16 sparse attention requires k_cache dtype to be torch.bfloat16 when is_fp8_kvcache is False"
if extra_k_cache is not None:
assert extra_k_cache.dtype == torch.bfloat16, "BF16 sparse attention requires extra_k_cache dtype to be torch.bfloat16 when is_fp8_kvcache is False"
out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd(
q, k_cache, indices_in_kvcache, topk_length, attn_sink,
sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
......@@ -640,4 +643,4 @@ def flash_mla_with_kvcache_quantization_q_nope_pe(
# )
# sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
# sched_meta.num_splits = new_num_splits
# return (out, lse)
\ No newline at end of file
# return (out, lse)
......@@ -93,6 +93,7 @@ ext_modules.append(
"csrc/gfx93/decode/sparse_fp8/instantiations/v32_persistent_h16.cu",
"csrc/gfx93/decode/sparse_fp8/instantiations/v32_persistent_h64.cu",
"csrc/gfx93/decode/sparse_fp8/instantiations/v32_persistent_h128.cu",
"csrc/gfx93/decode/sparse_bf16_dsa/fwd.cu",
# # gfx93 sparse prefill
"csrc/gfx93/prefill/sparse/fwd.cu",
......
......@@ -14,6 +14,9 @@ class TestTarget(enum.Enum):
FWD = 0
DECODE = 1
def is_decode_bf16_kvcache() -> bool:
    """Tell whether the FLASH_MLA_DECODE_BF16 environment variable is switched on.

    The variable's value is compared case-insensitively against the accepted
    spellings "1", "true", "yes", "y" and "bf16". An unset variable is treated
    as the empty string and therefore yields False.
    """
    accepted_spellings = ("1", "true", "yes", "y", "bf16")
    raw_flag = os.environ.get("FLASH_MLA_DECODE_BF16", "")
    return raw_flag.lower() in accepted_spellings
@dataclasses.dataclass
class ExtraTestParamForDecode:
b: int
......@@ -42,6 +45,21 @@ class TestParam:
have_topk_length: bool = False
decode: Optional[ExtraTestParamForDecode] = None
def is_bf16_decode_supported_param(t: TestParam) -> bool:
    """Filter predicate: True when `t` passes every check below.

    A parameter set is rejected when any of the following holds:
      * it carries no decode section (`t.decode is None`);
      * `t.is_all_indices_invalid` or `t.decode.have_zero_seqlen_k` is set;
      * `t.h_kv` is not 1 or `t.d_v` is not 512;
      * `t.h_q` is not 64 or 128;
      * `t.d_qk` is not 512 or 576.

    The remaining limit depends on whether an extra top-k is configured:
    without one, `t.topk` may be at most 1024; with one, `t.topk` may be at
    most 256 and the extra top-k at most 1024.
    """
    decode = t.decode
    if decode is None:
        return False

    # Same checks, in the same short-circuit order, as a chain of early returns.
    passes_static_checks = (
        not t.is_all_indices_invalid
        and not decode.have_zero_seqlen_k
        and t.h_kv == 1
        and t.d_v == 512
        and t.h_q in (64, 128)
        and t.d_qk in (512, 576)
    )
    if not passes_static_checks:
        return False

    extra_topk = decode.extra_topk
    if extra_topk is None:
        return t.topk <= 1024
    return t.topk <= 256 and extra_topk <= 1024
@dataclasses.dataclass
class RawTestParamForDecode:
"""
......@@ -289,14 +307,16 @@ def generate_testcase_for_decode(t: TestParam) -> TestcaseForDecode:
return KVScope(t, cache_seqlens, block_table, blocked_k, abs_indices, indices_in_kvcache, topk_length)
kv_scope0 = generate_one_k_scope(t.s_kv, t.decode.block_size, t.topk, t.decode.is_varlen, t.decode.have_zero_seqlen_k, t.is_all_indices_invalid, t.have_topk_length)
kv_scope0.quant_and_dequant_()
if not is_decode_bf16_kvcache():
kv_scope0.quant_and_dequant_()
if t.decode.extra_topk is not None:
if t.decode.extra_s_k is None:
t.decode.extra_s_k = t.decode.extra_topk*2
if t.decode.extra_block_size is None:
t.decode.extra_block_size = t.decode.block_size
kv_scope1 = generate_one_k_scope(t.decode.extra_s_k, t.decode.extra_block_size, t.decode.extra_topk, t.decode.is_varlen, t.decode.have_zero_seqlen_k, t.is_all_indices_invalid, t.decode.have_extra_topk_length)
kv_scope1.quant_and_dequant_()
if not is_decode_bf16_kvcache():
kv_scope1.quant_and_dequant_()
else:
assert t.decode.extra_block_size is None and t.decode.extra_s_k is None and not t.decode.have_extra_topk_length
kv_scope1 = None
......@@ -318,16 +338,17 @@ def run_flash_mla_sparse_fwd(p: TestParam, t: Testcase, return_p_sum: bool):
def run_flash_mla_decode(p: TestParam, t: TestcaseForDecode, tile_scheduler_metadata, num_splits):
assert p.decode is not None
is_fp8_kvcache = not is_decode_bf16_kvcache()
return flash_mla.flash_mla_with_kvcache(
t.q,
t.kv_scope.get_kvcache_for_flash_mla(),
t.kv_scope.get_kvcache_for_flash_mla() if is_fp8_kvcache else t.kv_scope.blocked_k,
None, None, p.d_v,
tile_scheduler_metadata, num_splits,
t.sm_scale, False, True,
t.sm_scale, False, is_fp8_kvcache,
t.kv_scope.indices_in_kvcache,
t.attn_sink,
t.extra_kv_scope.get_kvcache_for_flash_mla() if t.extra_kv_scope is not None else None,
(t.extra_kv_scope.get_kvcache_for_flash_mla() if is_fp8_kvcache else t.extra_kv_scope.blocked_k) if t.extra_kv_scope is not None else None,
t.extra_kv_scope.indices_in_kvcache if t.extra_kv_scope is not None else None,
t.kv_scope.topk_length,
t.extra_kv_scope.topk_length if t.extra_kv_scope is not None and t.extra_kv_scope.topk_length is not None else None
......
......@@ -172,8 +172,12 @@ def test_flash_mla(p: TestParam) -> Result:
else:
result = kk.bench_kineto(run_decode, p.num_runs)
splitkv_kernel_name = "flash_fwd_splitkv_mla_fp8_sparse_kernel"
combine_kernel_name = "flash_fwd_mla_combine_kernel"
if lib.is_decode_bf16_kvcache():
splitkv_kernel_name = "flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64_splitkv"
combine_kernel_name = "flash_mla_splitkv_reduce_kernel"
else:
splitkv_kernel_name = "flash_fwd_splitkv_mla_fp8_sparse_kernel"
combine_kernel_name = "flash_fwd_mla_combine_kernel"
# Get individual kernel time usages
kernel_time_usages_us: Dict[str, Optional[float]] = {}
......@@ -226,8 +230,11 @@ def test_flash_mla(p: TestParam) -> Result:
out_ref, lse_ref = ref.ref_sparse_attn_decode(p, t)
is_out_correct = kk.check_is_allclose("out", out_ans, out_ref, abs_tol=1e-3, rel_tol=2.01/128, cos_diff_tol=5e-6)
is_lse_correct = kk.check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536)
is_correct &= is_out_correct and is_lse_correct
if lib.is_decode_bf16_kvcache():
is_correct &= is_out_correct
else:
is_lse_correct = kk.check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536)
is_correct &= is_out_correct and is_lse_correct
performance_result.is_correct = is_correct
return performance_result
......@@ -250,6 +257,22 @@ def main():
raw_testcases = gen_testcase()
testcases = [t.to_test_param() for t in raw_testcases]
if lib.is_decode_bf16_kvcache():
bf16_testcases = []
seen_bf16_cases = set()
for t in testcases:
if not lib.is_bf16_decode_supported_param(t):
continue
if t.num_runs > 0 and t.decode.b > 16:
t = dataclasses.replace(t, decode=dataclasses.replace(t.decode, b=16))
key = dataclasses.asdict(t)
key["decode"] = tuple(key["decode"].items()) if key["decode"] is not None else None
key = tuple(key.items())
if key in seen_bf16_cases:
continue
seen_bf16_cases.add(key)
bf16_testcases.append(t)
testcases = bf16_testcases
print(f"{kk.colors['CYAN_BG']}{len(testcases)} testcases to run{kk.colors['CLEAR']}")
......
......@@ -10,6 +10,26 @@ import ref
_counter = kk.Counter()
def is_dsa_mls_prefill_case(p: TestParam) -> bool:
    """Classify `p` as one of the parameter combinations singled out below.

    Preconditions (any failure returns False): `d_v == 512`, `d_qk` in
    {512, 576}, `h_kv == 1`, `h_q` in {64, 128}, `topk` either at most 1024 or
    exactly 2048, and, when `topk == 2048`, neither `have_attn_sink` nor
    `have_topk_length` set.

    Given those, True is returned only for:
      * `d_qk == 512` with (`h_q`, `topk`) equal to (64, 512) or (128, 1024);
      * `d_qk == 576` with `h_q == 64`, `topk == 2048` and `s_kv >= 32768`.
    """
    dims_and_heads_ok = (
        p.d_v == 512
        and p.d_qk in (512, 576)
        and p.h_kv == 1
        and p.h_q in (64, 128)
    )
    if not dims_and_heads_ok:
        return False

    topk = p.topk
    if not (topk <= 1024 or topk == 2048):
        return False
    if topk == 2048 and (p.have_attn_sink or p.have_topk_length):
        return False

    # d_qk is known to be 512 or 576 here; each value has its own accepted set.
    if p.d_qk == 512:
        return (p.h_q, topk) in ((64, 512), (128, 1024))
    return p.h_q == 64 and topk == 2048 and p.s_kv >= 32768
@torch.inference_mode()
def run_test(p: TestParam) -> bool:
if p.seed == -1:
......@@ -31,7 +51,12 @@ def run_test(p: TestParam) -> bool:
if p.num_runs > 0:
flops_and_mem_vol = lib.count_flop_and_mem_vol(p, t)
prefill_ans_time = kk.bench_kineto(run_prefill, num_tests=p.num_runs).get_kernel_time("sparse_attn_fwd")
bench_result = kk.bench_kineto(run_prefill, num_tests=p.num_runs)
kernel_names = bench_result.get_kernel_names()
prefill_kernel_name = "sparse_attn_fwd"
if not any(prefill_kernel_name in name for name in kernel_names):
prefill_kernel_name = "flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64"
prefill_ans_time = bench_result.get_kernel_time(prefill_kernel_name)
prefill_flops = flops_and_mem_vol.fwd_flop/prefill_ans_time/1e12
prefill_mem_bw = flops_and_mem_vol.fwd_mem_vol/prefill_ans_time/1e12
print(f"Prefill: {prefill_ans_time*1e6:4.0f} us, {prefill_flops:6.1f} TFlops, {prefill_mem_bw:4.2f} TBps")
......@@ -44,8 +69,9 @@ def run_test(p: TestParam) -> bool:
is_correct = True
is_correct &= kk.check_is_allclose("out", prefill_ans_out.float(), ref_out_fp32, abs_tol=8e-4, rel_tol=3.01/128, cos_diff_tol=7e-6)
is_correct &= kk.check_is_allclose("max_logits", prefill_ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01/65536)
is_correct &= kk.check_is_allclose("lse", prefill_ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01/65536)
if not is_dsa_mls_prefill_case(p):
is_correct &= kk.check_is_allclose("max_logits", prefill_ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01/65536)
is_correct &= kk.check_is_allclose("lse", prefill_ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01/65536)
return is_correct
else:
......@@ -187,4 +213,3 @@ if __name__ == '__main__':
sys.exit(1)
else:
print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment