Commit 2c35de66 authored by shenzhe's avatar shenzhe Committed by zhanghj2
Browse files

Add DSA BF16 sparse decode support

parent a1eef562
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "params.h" #include "params.h"
#include "gfx93/decode/sparse_fp8/splitkv_mla.h" #include "gfx93/decode/sparse_fp8/splitkv_mla.h"
#include "gfx93/decode/sparse_bf16_dsa/fwd.h"
#include "gfx9/decode/get_decoding_sched_meta/get_decoding_sched_meta.h" #include "gfx9/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "gfx9/decode/combine/combine.h" #include "gfx9/decode/combine/combine.h"
...@@ -123,6 +124,14 @@ sparse_attn_decode_interface( ...@@ -123,6 +124,14 @@ sparse_attn_decode_interface(
bool have_extra_topk_length = extra_topk_length.has_value(); bool have_extra_topk_length = extra_topk_length.has_value();
bool have_attn_sink = attn_sink.has_value(); bool have_attn_sink = attn_sink.has_value();
if (kv.dtype() == torch::kBFloat16) {
return gfx93::decode::sparse_bf16_dsa::run(
q, kv, indices, topk_length, attn_sink,
tile_scheduler_metadata, num_splits,
extra_kv, extra_indices, extra_topk_length,
d_v, sm_scale);
}
int extra_num_blocks = 0, extra_page_block_size = 0, extra_topk = 0; int extra_num_blocks = 0, extra_page_block_size = 0, extra_topk = 0;
if (have_extra_kcache) { if (have_extra_kcache) {
extra_num_blocks = extra_kv->size(0); extra_num_blocks = extra_kv->size(0);
......
#include "fwd.h"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>
#include <algorithm>
#include <cstring>
#include <limits>
#include <optional>
#include <tuple>
#include "kerutils/supplemental/torch_tensors.h"
#include "gfx93/prefill/sparse/dsa_mls/dispatch.h"
namespace gfx93::decode::sparse_bf16_dsa {
static constexpr float LOG_2_E = 1.44269504f;
struct LocalArch {
int num_sms;
std::string arch_name;
LocalArch() {
auto* props = at::cuda::getCurrentDeviceProperties();
num_sms = props->multiProcessorCount;
arch_name = props->gcnArchName;
}
bool is_gfx93x() const {
const auto base = arch_name.substr(0, arch_name.find(':'));
return base == "gfx936" || base == "gfx938";
}
};
static int int64_stride_to_int(int64_t stride) {
TORCH_CHECK(stride <= std::numeric_limits<int>::max(), "DSA BF16 sparse decode stride exceeds int32 limit: ", stride);
return static_cast<int>(stride);
}
static int default_num_splits(int topk, int extra_topk) {
if (extra_topk > 0) {
return 2;
}
if (topk == 1024) return 16;
if (topk == 512) return 8;
return 1;
}
static void check_optional_extra(
const std::optional<at::Tensor>& extra_kv,
const std::optional<at::Tensor>& extra_indices,
const std::optional<at::Tensor>& extra_topk_length) {
if (extra_kv.has_value()) {
TORCH_CHECK(extra_indices.has_value(), "extra_indices_in_kvcache must be provided when extra_k_cache is provided");
} else {
TORCH_CHECK(!extra_indices.has_value(), "extra_indices_in_kvcache must not be provided when extra_k_cache is not provided");
TORCH_CHECK(!extra_topk_length.has_value(), "extra_topk_length must not be provided when extra_k_cache is not provided");
}
}
std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>>
run(
const at::Tensor& q,
const at::Tensor& kv,
const at::Tensor& indices,
const std::optional<at::Tensor>& topk_length,
const std::optional<at::Tensor>& attn_sink,
std::optional<at::Tensor>& tile_scheduler_metadata,
std::optional<at::Tensor>& num_splits,
const std::optional<at::Tensor>& extra_kv,
const std::optional<at::Tensor>& extra_indices,
const std::optional<at::Tensor>& extra_topk_length,
int d_v,
float sm_scale) {
LocalArch arch;
TORCH_CHECK(arch.is_gfx93x(), "DSA BF16 sparse decode is only supported on gfx936/gfx938");
KU_CHECK_NDIM(q, 4);
KU_CHECK_NDIM(kv, 4);
KU_CHECK_NDIM(indices, 3);
if (extra_kv.has_value()) KU_CHECK_NDIM(extra_kv, 4);
if (extra_indices.has_value()) KU_CHECK_NDIM(extra_indices, 3);
const int b = q.size(0);
const int s_q = q.size(1);
const int h_q = q.size(2);
const int d_qk = q.size(3);
const int page_block_size = kv.size(1);
const int h_kv = kv.size(2);
const int topk = indices.size(2);
const bool has_extra = extra_kv.has_value() && extra_indices.has_value() &&
extra_kv->numel() > 0 && extra_indices->numel() > 0 &&
extra_indices->size(2) > 0;
const int extra_topk = has_extra ? extra_indices->size(2) : 0;
TORCH_CHECK(b > 0 && s_q > 0 && h_q > 0, "Invalid q shape for DSA BF16 sparse decode");
TORCH_CHECK(h_kv == 1, "DSA BF16 sparse decode only supports h_kv == 1");
TORCH_CHECK(h_q == 64 || h_q == 128, "DSA BF16 sparse decode only supports h_q == 64 or 128");
TORCH_CHECK(d_qk == 512 || d_qk == 576, "DSA BF16 sparse decode only supports d_qk == 512 or 576");
TORCH_CHECK(d_v == 512, "DSA BF16 sparse decode only supports d_v == 512");
TORCH_CHECK(topk > 0, "topk must be positive");
if (has_extra) {
TORCH_CHECK(topk <= 256, "DSA BF16 sparse decode with extra_kv supports topk <= 256");
TORCH_CHECK(extra_topk <= 1024, "DSA BF16 sparse decode supports extra_topk <= 1024");
TORCH_CHECK(extra_kv->size(1) > 0, "extra page_block_size must be positive");
TORCH_CHECK(extra_kv->size(2) == h_kv, "extra_kv h_kv must match kv h_kv");
TORCH_CHECK(extra_kv->size(3) == d_qk, "extra_kv d_qk must match q d_qk");
} else {
TORCH_CHECK(topk <= 1024, "DSA BF16 sparse decode supports topk <= 1024");
}
check_optional_extra(extra_kv, extra_indices, extra_topk_length);
KU_CHECK_DEVICE(q);
KU_CHECK_DEVICE(kv);
KU_CHECK_DEVICE(indices);
KU_CHECK_DEVICE(topk_length);
KU_CHECK_DEVICE(attn_sink);
KU_CHECK_DEVICE(tile_scheduler_metadata);
KU_CHECK_DEVICE(num_splits);
KU_CHECK_DEVICE(extra_kv);
KU_CHECK_DEVICE(extra_indices);
KU_CHECK_DEVICE(extra_topk_length);
KU_CHECK_DTYPE(q, torch::kBFloat16);
KU_CHECK_DTYPE(kv, torch::kBFloat16);
KU_CHECK_DTYPE(indices, torch::kInt32);
KU_CHECK_DTYPE(topk_length, torch::kInt32);
KU_CHECK_DTYPE(attn_sink, torch::kFloat32);
KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32);
KU_CHECK_DTYPE(num_splits, torch::kInt32);
KU_CHECK_DTYPE(extra_kv, torch::kBFloat16);
KU_CHECK_DTYPE(extra_indices, torch::kInt32);
KU_CHECK_DTYPE(extra_topk_length, torch::kInt32);
KU_CHECK_LAST_DIM_CONTIGUOUS(q);
KU_CHECK_LAST_DIM_CONTIGUOUS(kv);
KU_CHECK_LAST_DIM_CONTIGUOUS(indices);
KU_CHECK_CONTIGUOUS(topk_length);
KU_CHECK_CONTIGUOUS(attn_sink);
KU_CHECK_LAST_DIM_CONTIGUOUS(extra_kv);
KU_CHECK_LAST_DIM_CONTIGUOUS(extra_indices);
KU_CHECK_CONTIGUOUS(extra_topk_length);
KU_CHECK_SHAPE(q, b, s_q, h_q, d_qk);
KU_CHECK_SHAPE(kv, kv.size(0), page_block_size, h_kv, d_qk);
KU_CHECK_SHAPE(indices, b, s_q, topk);
KU_CHECK_SHAPE(topk_length, b);
KU_CHECK_SHAPE(attn_sink, h_q);
if (has_extra) {
KU_CHECK_SHAPE(extra_indices, b, s_q, extra_topk);
KU_CHECK_SHAPE(extra_topk_length, b);
}
at::Tensor indices_for_dsa = indices.unsqueeze(2);
at::Tensor extra_indices_for_dsa;
if (has_extra) {
extra_indices_for_dsa = extra_indices->unsqueeze(2);
}
c10::cuda::CUDAGuard device_guard{q.device()};
auto opts = q.options();
at::Tensor out = torch::empty({b, s_q, h_q, d_v}, opts);
at::Tensor lse = torch::empty({b, h_q, s_q}, opts.dtype(at::kFloat));
at::Tensor scores_memory = torch::empty({2, b, h_kv, s_q * h_q}, opts.dtype(at::kFloat));
at::Tensor scores_max = scores_memory.select(0, 0);
at::Tensor scores_sum = scores_memory.select(0, 1);
if (!num_splits.has_value()) {
const int split = default_num_splits(topk, extra_topk);
num_splits = torch::empty({1}, opts.dtype(torch::kInt32));
num_splits->fill_(split);
}
KU_CHECK_DTYPE(num_splits, torch::kInt32);
KU_CHECK_DEVICE(num_splits);
KU_CHECK_CONTIGUOUS(num_splits);
TORCH_CHECK(num_splits->numel() == 1, "DSA BF16 sparse decode expects num_splits to be a scalar tensor");
const int requested_num_splits = num_splits->item<int>();
TORCH_CHECK(requested_num_splits >= 1 && requested_num_splits <= 64, "DSA BF16 sparse decode requires 1 <= num_splits <= 64");
Flash_fwd_mla_params_dsa params;
std::memset(&params, 0, sizeof(params));
params.layout = 1;
params.b = b;
params.h = h_kv;
params.h_k = h_kv;
params.h_h_k_ratio = 1;
params.mtp = 1;
params.ngroups = h_q / h_kv;
params.topk = topk;
params.extra_topk = has_extra ? extra_topk : 0;
params.d = d_qk;
params.d_v = d_v;
params.scale_softmax = sm_scale;
params.scale_softmax_log2 = sm_scale * LOG_2_E;
params.topk_length = ku::get_optional_tensor_ptr<int>(topk_length);
params.extra_topk_length = ku::get_optional_tensor_ptr<int>(extra_topk_length);
params.attn_sink = ku::get_optional_tensor_ptr<float>(attn_sink);
params.q_ptr = q.data_ptr();
params.k_ptr = kv.data_ptr();
params.v_ptr = kv.data_ptr();
params.extra_k_ptr = has_extra ? extra_kv->data_ptr() : nullptr;
params.extra_v_ptr = has_extra ? extra_kv->data_ptr() : nullptr;
params.o_ptr = out.data_ptr();
params.sparse_indices = reinterpret_cast<int*>(indices_for_dsa.data_ptr());
params.extra_sparse_indices = has_extra ? reinterpret_cast<int*>(extra_indices_for_dsa.data_ptr()) : nullptr;
params.softmax_lse_ptr = lse.data_ptr<float>();
params.scores_max_ptr = scores_max.data_ptr<float>();
params.scores_sum_ptr = scores_sum.data_ptr<float>();
params.page_block_size = page_block_size;
params.extra_page_block_size = has_extra ? extra_kv->size(1) : 0;
params.is_causal = false;
params.q_batch_stride = int64_stride_to_int(q.stride(0));
params.q_token_stride = int64_stride_to_int(q.stride(1));
params.q_row_stride = int64_stride_to_int(q.stride(2));
params.q_head_stride = int64_stride_to_int(q.stride(2));
params.k_batch_stride = int64_stride_to_int(kv.stride(0));
params.k_row_stride = int64_stride_to_int(kv.stride(1));
params.k_head_stride = int64_stride_to_int(kv.stride(2));
params.v_batch_stride = params.k_batch_stride;
params.v_row_stride = params.k_row_stride;
params.v_head_stride = params.k_head_stride;
params.extra_k_batch_stride = has_extra ? int64_stride_to_int(extra_kv->stride(0)) : 0;
params.extra_k_row_stride = has_extra ? int64_stride_to_int(extra_kv->stride(1)) : 0;
params.extra_v_batch_stride = params.extra_k_batch_stride;
params.extra_v_row_stride = params.extra_k_row_stride;
params.sparse_indices_batch_stride = int64_stride_to_int(indices_for_dsa.stride(0));
params.sparse_indices_row_stride = int64_stride_to_int(indices_for_dsa.stride(1));
params.sparse_indices_head_stride = int64_stride_to_int(indices_for_dsa.stride(2));
params.sparse_indices_topk_stride = int64_stride_to_int(indices_for_dsa.stride(3));
params.extra_sparse_indices_batch_stride = has_extra ? int64_stride_to_int(extra_indices_for_dsa.stride(0)) : 0;
params.extra_sparse_indices_row_stride = has_extra ? int64_stride_to_int(extra_indices_for_dsa.stride(1)) : 0;
params.extra_sparse_indices_head_stride = has_extra ? int64_stride_to_int(extra_indices_for_dsa.stride(2)) : 0;
params.extra_sparse_indices_topk_stride = has_extra ? int64_stride_to_int(extra_indices_for_dsa.stride(3)) : 0;
params.o_batch_stride = int64_stride_to_int(out.stride(0));
params.o_row_stride = int64_stride_to_int(out.stride(1));
params.o_head_stride = int64_stride_to_int(out.stride(2));
params.seqlen_q = s_q * params.ngroups;
params.seqlen_k = kv.size(0) * kv.size(1);
params.max_seqlen = s_q;
params.is_bf16 = true;
params.is_e4m3 = false;
params.is_int8 = false;
params.cu_count = arch.num_sms;
params.seqlenq_ngroups_swapped = true;
params.is_seqlens_k_cumulative = false;
params.splitkv_use_fp32_as_accum = false;
params.num_splits = requested_num_splits;
params.partition_size = topk + params.extra_topk;
if (params.num_splits > 1) {
params.partition_size = std::max(64, (params.partition_size + params.num_splits - 1) / params.num_splits);
params.partition_size = ((params.partition_size + 63) / 64) * 64;
}
at::Tensor out_accum;
at::Tensor lse_accum;
if (params.num_splits > 1) {
lse_accum = torch::empty({params.num_splits, b, h_kv, params.seqlen_q}, opts.dtype(at::kFloat));
out_accum = torch::empty({params.num_splits, b, s_q, h_q, d_v}, opts);
params.softmax_lse_ptr = lse_accum.data_ptr<float>();
params.oaccum_ptr = out_accum.data_ptr();
}
hipStream_t stream = reinterpret_cast<hipStream_t>(at::cuda::getCurrentCUDAStream().stream());
if (d_qk == 512) {
gfx93::fwd::dsa_mls::run_dsa_prefill_nopage_64_dispatch<BFloat16, 512, 512>(params, stream);
} else {
gfx93::fwd::dsa_mls::run_dsa_prefill_nopage_64_dispatch<BFloat16, 576, 512>(params, stream);
}
return {out, lse, tile_scheduler_metadata, num_splits};
}
} // namespace gfx93::decode::sparse_bf16_dsa
#pragma once
#include <ATen/core/Tensor.h>
#include <optional>
#include <tuple>
namespace gfx93::decode::sparse_bf16_dsa {
std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>>
run(
const at::Tensor& q,
const at::Tensor& kv,
const at::Tensor& indices,
const std::optional<at::Tensor>& topk_length,
const std::optional<at::Tensor>& attn_sink,
std::optional<at::Tensor>& tile_scheduler_metadata,
std::optional<at::Tensor>& num_splits,
const std::optional<at::Tensor>& extra_kv,
const std::optional<at::Tensor>& extra_indices,
const std::optional<at::Tensor>& extra_topk_length,
int d_v,
float sm_scale);
} // namespace gfx93::decode::sparse_bf16_dsa
...@@ -9,8 +9,68 @@ ...@@ -9,8 +9,68 @@
#include "legacy/include/static_switch.h" #include "legacy/include/static_switch.h"
#include "legacy/src/flash_fwd_b16_mla.h" #include "legacy/src/flash_fwd_b16_mla.h"
#include "legacy/src/flash_fwd_reduce.h"
namespace gfx93::fwd::dsa_mls { namespace gfx93::fwd::dsa_mls {
template<typename Kernel_traits, const bool Tail, typename Params>
void run_dsa_mla_splitkv_reduce(Params& params, hipStream_t stream) {
static_assert(Kernel_traits::kHeadDimV == 512,
"run_dsa_mla_splitkv_reduce only supports hdimv == 512");
using Element = typename Kernel_traits::Element;
using SplitkvAccumType = typename Kernel_traits::SplitkvAccumType;
Flash_fwd_mla_reduce_params reduce_params;
reduce_params.softmax_lse_ptr = params.softmax_lse_ptr;
reduce_params.oaccum_ptr = params.oaccum_ptr;
reduce_params.o_ptr = params.o_ptr;
reduce_params.cu_seqlens_k = params.cu_seqlens_k;
reduce_params.num_splits = params.num_splits;
reduce_params.partition_size = params.partition_size;
reduce_params.h = params.h;
reduce_params.ngroups = params.ngroups;
reduce_params.seqlen_q = params.seqlen_q;
reduce_params.layout = params.layout;
reduce_params.topk_length = params.topk_length;
reduce_params.attn_sink = params.attn_sink;
reduce_params.extra_topk_length = params.extra_topk_length;
reduce_params.topk = params.topk;
reduce_params.extra_topk = params.extra_topk;
if (params.num_splits > 1) {
dim3 block(256);
dim3 grid(params.b * params.h * params.seqlen_q, 4);
constexpr int MAX_NUM_SPLITS = 64;
if (params.num_splits > MAX_NUM_SPLITS) {
printf("\x1b[31mnum_splits %d is larger than limit %d, and thus won't execute the kernel\033[0m\n",
params.num_splits, MAX_NUM_SPLITS);
return;
}
if (params.num_splits == 2) {
::flash_mla_splitkv_reduce_kernel<SplitkvAccumType, Element, 2, true, Tail, Kernel_traits::kHeadDimV>
<<<grid, block, 0, stream>>>(reduce_params);
} else if (params.num_splits == 4) {
::flash_mla_splitkv_reduce_kernel<SplitkvAccumType, Element, 4, true, Tail, Kernel_traits::kHeadDimV>
<<<grid, block, 0, stream>>>(reduce_params);
} else if (params.num_splits == 8) {
::flash_mla_splitkv_reduce_kernel<SplitkvAccumType, Element, 8, true, Tail, Kernel_traits::kHeadDimV>
<<<grid, block, 0, stream>>>(reduce_params);
} else if (params.num_splits == 16) {
::flash_mla_splitkv_reduce_kernel<SplitkvAccumType, Element, 16, true, Tail, Kernel_traits::kHeadDimV>
<<<grid, block, 0, stream>>>(reduce_params);
} else if (params.num_splits == 32) {
::flash_mla_splitkv_reduce_kernel<SplitkvAccumType, Element, 32, true, Tail, Kernel_traits::kHeadDimV>
<<<grid, block, 0, stream>>>(reduce_params);
} else if (params.num_splits == 64) {
::flash_mla_splitkv_reduce_kernel<SplitkvAccumType, Element, 64, true, Tail, Kernel_traits::kHeadDimV>
<<<grid, block, 0, stream>>>(reduce_params);
} else {
printf("\x1b[31mnum_splits %d is not supported yet, and thus won't execute the kernel\033[0m\n",
params.num_splits);
}
}
}
template<typename T, int Headdim, int HeaddimV> template<typename T, int Headdim, int HeaddimV>
void run_dsa_prefill_nopage_64_dispatch(Flash_fwd_mla_params_dsa& params, hipStream_t stream) { void run_dsa_prefill_nopage_64_dispatch(Flash_fwd_mla_params_dsa& params, hipStream_t stream) {
constexpr int kBlockM = 64; constexpr int kBlockM = 64;
...@@ -34,6 +94,33 @@ void run_dsa_prefill_nopage_64_dispatch(Flash_fwd_mla_params_dsa& params, hipStr ...@@ -34,6 +94,33 @@ void run_dsa_prefill_nopage_64_dispatch(Flash_fwd_mla_params_dsa& params, hipStr
constexpr bool Is_dropout = false; constexpr bool Is_dropout = false;
constexpr bool IsEvenMNConst = false; constexpr bool IsEvenMNConst = false;
constexpr int REUSE_KV = 1;
const bool has_extra = params.extra_sparse_indices != nullptr && params.extra_topk > 0;
if (params.num_splits == 1) {
BOOL_SWITCH(params.mtp > 1, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(has_extra, Has_extra, [&] {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64<
Kernel_traits, true, Is_dropout, false, Is_causal,
IsEvenMNConst, true, false, Is_MTP, 0, Has_extra, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
});
});
});
} else if (params.num_splits != 0) {
dimGrid.y = params.num_splits;
BOOL_SWITCH(params.mtp > 1, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(has_extra, Has_extra, [&] {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64_splitkv<
Kernel_traits, true, Is_dropout, false, Is_causal,
IsEvenMNConst, true, false, Is_MTP, 0, Has_extra, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
});
});
});
run_dsa_mla_splitkv_reduce<Kernel_traits, false>(params, stream);
} else {
BOOL_SWITCH(params.mtp > 1, Is_MTP, [&] { BOOL_SWITCH(params.mtp > 1, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if (params.topk == 2048) { if (params.topk == 2048) {
...@@ -49,6 +136,7 @@ void run_dsa_prefill_nopage_64_dispatch(Flash_fwd_mla_params_dsa& params, hipStr ...@@ -49,6 +136,7 @@ void run_dsa_prefill_nopage_64_dispatch(Flash_fwd_mla_params_dsa& params, hipStr
} }
}); });
}); });
}
} }
} // namespace gfx93::fwd::dsa_mls } // namespace gfx93::fwd::dsa_mls
...@@ -450,6 +450,7 @@ struct Flash_fwd_mla_reduce_params { ...@@ -450,6 +450,7 @@ struct Flash_fwd_mla_reduce_params {
int num_splits; int num_splits;
int partition_size; int partition_size;
int h; int h;
int ngroups;
int seqlen_q; int seqlen_q;
int layout; int layout;
float* attn_sink; float* attn_sink;
......
...@@ -2503,25 +2503,9 @@ __forceinline__ __device__ void flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_n ...@@ -2503,25 +2503,9 @@ __forceinline__ __device__ void flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_n
/**************************************************************************************************************************************/ /**************************************************************************************************************************************/
constexpr bool Is_Interleave = true; constexpr bool Is_Interleave = true;
int lane_id = threadIdx.x & 63; int lane_id = threadIdx.x & 63;
#pragma unroll
for (int mi = 0; mi < WARP_M / 16; ++mi) {
if (real_topk == 0) {
lse[mi].f32[0] = __builtin_inff();
}
}
if (params.softmax_lse_ptr != nullptr) { if (params.softmax_lse_ptr != nullptr) {
prefill_mla_epilogue_store_lse<WARP_M, Is_even_MN, false/*SplitD*/, Is_Interleave, ElementAccum, vec_Accum<ElementAccum>, 1/* M_MMAC_COUNT */>(lse, params.softmax_lse_ptr, row_offset_lse, warp_id, lane_id, 0, actual_seqlen_q - m_block * kBlockM); prefill_mla_epilogue_store_lse<WARP_M, Is_even_MN, false/*SplitD*/, Is_Interleave, ElementAccum, vec_Accum<ElementAccum>, 1/* M_MMAC_COUNT */>(lse, params.softmax_lse_ptr, row_offset_lse, warp_id, lane_id, 0, actual_seqlen_q - m_block * kBlockM);
} }
if (params.scores_max_ptr != nullptr) {
float* scores_max_ptr = params.scores_max_ptr + row_offset_lse;
const bool write_scores_max = (lane_id >> 4) == 0;
if (write_scores_max) {
const int row = warp_id * WARP_M + (lane_id & 15);
if (row < actual_seqlen_q - m_block * kBlockM) {
scores_max_ptr[row] = scores_max[0].f32[0] * params.scale_softmax;
}
}
}
/**************************************************************************************************************************************/ /**************************************************************************************************************************************/
Element* o_ptr = reinterpret_cast<Element *>(params.o_ptr) + row_offset_o; Element* o_ptr = reinterpret_cast<Element *>(params.o_ptr) + row_offset_o;
prefill_mla_epilogue_store_output<kHeadDimV, kBlockM, kBlockK, WARP_M, Is_even_MN, Is_Interleave, false/*TcpSwizzle*/, Element, ElementAccum, 1/* M_MMAC_COUNT */>(o_ptr, acc_o, m_block, warp_id, lane_id, seqlen_o_stride, actual_seqlen_q); prefill_mla_epilogue_store_output<kHeadDimV, kBlockM, kBlockK, WARP_M, Is_even_MN, Is_Interleave, false/*TcpSwizzle*/, Element, ElementAccum, 1/* M_MMAC_COUNT */>(o_ptr, acc_o, m_block, warp_id, lane_id, seqlen_o_stride, actual_seqlen_q);
...@@ -2539,7 +2523,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -2539,7 +2523,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64_impl<Kernel_traits, Is_training, Is_dropout, Is_prefix, Is_causal, Is_even_MN, Is_even_K, Return_softmax, Is_MTP, Layout, false, Params>(params); flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64_impl<Kernel_traits, Is_training, Is_dropout, Is_prefix, Is_causal, Is_even_MN, Is_even_K, Return_softmax, Is_MTP, Layout, false, Params>(params);
} }
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_prefix, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, bool Is_MTP, int Layout, typename Params> template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_prefix, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, bool Is_MTP, int Layout, bool Has_extra, typename Params>
__global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64(const Params params) { __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64(const Params params) {
#if defined(__gfx938__) #if defined(__gfx938__)
using Element = typename Kernel_traits::Element; using Element = typename Kernel_traits::Element;
...@@ -2576,7 +2560,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -2576,7 +2560,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
const int warp_offset_in_seq_q = m_block * kBlockM + warp_id * WARP_M; const int warp_offset_in_seq_q = m_block * kBlockM + warp_id * WARP_M;
const int warp_seqq_limit = Is_even_MN ? 0: actual_seqlen_q - m_block * kBlockM; const int warp_seqq_limit = Is_even_MN ? 0: actual_seqlen_q - m_block * kBlockM;
const bool has_extra = params.extra_sparse_indices != nullptr && params.extra_topk > 0; constexpr bool has_extra = Has_extra;
// 分配 lds Q/P same place, K/V same place; // 分配 lds Q/P same place, K/V same place;
// extern __shared__ Element smem[]; // extern __shared__ Element smem[];
// int* index_lds = (int *)&(smem); // int* index_lds = (int *)&(smem);
...@@ -2588,7 +2572,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -2588,7 +2572,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
Element* k_lds = q_lds; // 16KB Element* k_lds = q_lds; // 16KB
Element* v_lds = q_lds; Element* v_lds = q_lds;
int* index_lds = (int *)(q_lds + 8 * 1024); int* index_lds = (int *)(q_lds + 8 * 1024);
int* extra_index_lds = (int *)(index_lds + 1024); int* extra_index_lds = (int *)(index_lds + 256);
// int* sIndices = (int *)(q_lds + 8192); // int* sIndices = (int *)(q_lds + 8192);
// 计算当前任务沿着 seqlenKV 方向的 block 起始位置和终止位置 // 计算当前任务沿着 seqlenKV 方向的 block 起始位置和终止位置
...@@ -2601,7 +2585,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -2601,7 +2585,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
int extra_topk_length = params.extra_topk_length ? params.extra_topk_length[bidb] : params.extra_topk; int extra_topk_length = params.extra_topk_length ? params.extra_topk_length[bidb] : params.extra_topk;
int main_num_blocks = ceil_div(main_topk_length, kBlockN); int main_num_blocks = ceil_div(main_topk_length, kBlockN);
int extra_num_blocks = has_extra ? ceil_div(extra_topk_length, kBlockN) : 0; int extra_num_blocks = Has_extra ? ceil_div(extra_topk_length, kBlockN) : 0;
// 计算数据跨度 // 计算数据跨度
int seqlen_q_stride = params.q_head_stride; int seqlen_q_stride = params.q_head_stride;
...@@ -2619,12 +2603,13 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -2619,12 +2603,13 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
+ bidb * params.sparse_indices_batch_stride + bidb * params.sparse_indices_batch_stride
+ query_idx * params.sparse_indices_row_stride + query_idx * params.sparse_indices_row_stride
+ (q_head_start / params.ngroups) * params.sparse_indices_head_stride; + (q_head_start / params.ngroups) * params.sparse_indices_head_stride;
int* extra_index_ptr = has_extra int* extra_index_ptr = nullptr;
? params.extra_sparse_indices if constexpr (Has_extra) {
extra_index_ptr = params.extra_sparse_indices
+ bidb * params.extra_sparse_indices_batch_stride + bidb * params.extra_sparse_indices_batch_stride
+ query_idx * params.extra_sparse_indices_row_stride + query_idx * params.extra_sparse_indices_row_stride
+ (q_head_start / params.ngroups) * params.extra_sparse_indices_head_stride + (q_head_start / params.ngroups) * params.extra_sparse_indices_head_stride;
: nullptr; }
// const int block_table_idx = 0; // const int block_table_idx = 0;
// const int block_table_offset = 0; // const int block_table_offset = 0;
...@@ -2645,45 +2630,52 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -2645,45 +2630,52 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
qv_ptr = prepare_for_matrix_load<kHeadDim>(reinterpret_cast<Element*>(params.qv_ptr) + row_offset_qv); qv_ptr = prepare_for_matrix_load<kHeadDim>(reinterpret_cast<Element*>(params.qv_ptr) + row_offset_qv);
auto k_ptr = prepare_for_matrix_load<kHeadDim>(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k); auto k_ptr = prepare_for_matrix_load<kHeadDim>(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k);
auto k_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k); auto k_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k);
auto extra_k_ptr_buffer = has_extra ? tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.extra_k_ptr)) : k_ptr_buffer; auto extra_k_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(Has_extra ? params.extra_k_ptr : params.k_ptr));
auto v_ptr = prepare_for_matrix_load<kHeadDimV>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v); auto v_ptr = prepare_for_matrix_load<kHeadDimV>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v);
auto v_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v); auto v_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v);
auto extra_v_ptr_buffer = has_extra ? tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.extra_v_ptr)) : v_ptr_buffer; auto extra_v_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(Has_extra ? params.extra_v_ptr : params.v_ptr));
auto index_ptr_buffer = tcp_cache_swizzle_func<0, int>(reinterpret_cast<int*>(index_ptr)); auto index_ptr_buffer = tcp_cache_swizzle_func<0, int>(reinterpret_cast<int*>(index_ptr));
int tid = threadIdx.x % 64; int tid = threadIdx.x % 64;
if (has_extra) { if constexpr (Has_extra) {
auto extra_index_ptr_buffer = tcp_cache_swizzle_func<0, int>(extra_index_ptr); auto extra_index_ptr_buffer = tcp_cache_swizzle_func<0, int>(extra_index_ptr);
int lds_offset = __builtin_amdgcn_readfirstlane(warp_id * 64); int lds_offset = __builtin_amdgcn_readfirstlane(warp_id * 4 * 64);
int g_offset_v = tid; int g_offset_v = tid * 4;
int g_offset_s = warp_id * 64; int g_offset_s = warp_id * 256;
inline_buffer_load_dword_lds(extra_index_lds, extra_index_ptr_buffer, lds_offset, g_offset_s, g_offset_v); inline_buffer_load_dwordx4_lds(extra_index_lds, extra_index_ptr_buffer, lds_offset, g_offset_s, g_offset_v);
} }
auto k_faker = reinterpret_cast<Element*>(params.k_ptr) + row_offset_k; auto k_faker = reinterpret_cast<Element*>(params.k_ptr) + row_offset_k;
auto v_faker = reinterpret_cast<Element*>(params.v_ptr) + row_offset_v; auto v_faker = reinterpret_cast<Element*>(params.v_ptr) + row_offset_v;
// apply causal mask 的步骤和 no causal mask 的步骤分开算 // apply causal mask 的步骤和 no causal mask 的步骤分开算
constexpr int n_masking_steps = 1; constexpr int n_masking_steps = 1;
int lds_offset = __builtin_amdgcn_readfirstlane(warp_id * 4 * 64); int lds_offset = __builtin_amdgcn_readfirstlane(Has_extra ? warp_id * 64 : warp_id * 4 * 64);
int g_offset_v = tid * 4; int g_offset_v = Has_extra ? tid : tid * 4;
int g_offset_s = warp_id * 256; int g_offset_s = Has_extra ? warp_id * 64 : warp_id * 256;
if constexpr (Has_extra) {
inline_buffer_load_dword_lds(index_lds, index_ptr_buffer, lds_offset, g_offset_s, g_offset_v);
} else {
inline_buffer_load_dwordx4_lds(index_lds, index_ptr_buffer, lds_offset, g_offset_s, g_offset_v); inline_buffer_load_dwordx4_lds(index_lds, index_ptr_buffer, lds_offset, g_offset_s, g_offset_v);
}
flash::wait_buffer_data_arrived<true>(0); flash::wait_buffer_data_arrived<true>(0);
flash::wait_all_warp_arrived(); flash::wait_all_warp_arrived();
#pragma unroll #pragma unroll
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
const int local_index = warp_id * 256 + tid * 4 + i; const int local_index = Has_extra ? warp_id * 64 + tid : warp_id * 256 + tid * 4 + i;
if (local_index >= main_topk_length) { if (local_index >= main_topk_length) {
index_lds[local_index] = main_topk_length > 0 ? index_lds[main_topk_length - 1] : -1; index_lds[local_index] = main_topk_length > 0 ? index_lds[main_topk_length - 1] : -1;
} }
} }
if (has_extra) { if constexpr (Has_extra) {
const int local_index = warp_id * 64 + tid; #pragma unroll
for (int i = 0; i < 4; ++i) {
const int local_index = warp_id * 256 + tid * 4 + i;
if (local_index >= extra_topk_length) { if (local_index >= extra_topk_length) {
extra_index_lds[local_index] = extra_topk_length > 0 ? extra_index_lds[extra_topk_length - 1] : -1; extra_index_lds[local_index] = extra_topk_length > 0 ? extra_index_lds[extra_topk_length - 1] : -1;
} }
} }
}
flash::wait_all_warp_arrived(); flash::wait_all_warp_arrived();
// 是否做 prefetch K, PV 结束后, prefetch K 有风险 // 是否做 prefetch K, PV 结束后, prefetch K 有风险
...@@ -2862,7 +2854,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -2862,7 +2854,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
#endif #endif
} }
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_prefix, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, bool Is_MTP, int Layout, typename Params> template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_prefix, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, bool Is_MTP, int Layout, bool Has_extra, typename Params>
__global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64_splitkv(const Params params) { __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64_splitkv(const Params params) {
#if defined(__gfx938__) #if defined(__gfx938__)
using Element = typename Kernel_traits::Element; using Element = typename Kernel_traits::Element;
...@@ -2899,7 +2891,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -2899,7 +2891,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
const int warp_offset_in_seq_q = m_block * kBlockM + warp_id * WARP_M; const int warp_offset_in_seq_q = m_block * kBlockM + warp_id * WARP_M;
const int warp_seqq_limit = Is_even_MN ? 0: actual_seqlen_q - m_block * kBlockM; const int warp_seqq_limit = Is_even_MN ? 0: actual_seqlen_q - m_block * kBlockM;
const bool has_extra = params.extra_sparse_indices != nullptr && params.extra_topk > 0; constexpr bool has_extra = Has_extra;
// 分配 lds Q/P same place, K/V same place; // 分配 lds Q/P same place, K/V same place;
// extern __shared__ Element smem[]; // extern __shared__ Element smem[];
// int* index_lds = (int *)&(smem); // int* index_lds = (int *)&(smem);
...@@ -2911,7 +2903,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -2911,7 +2903,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
Element* k_lds = q_lds; // 16KB Element* k_lds = q_lds; // 16KB
Element* v_lds = q_lds; Element* v_lds = q_lds;
int* index_lds = (int *)(q_lds + 8 * 1024); int* index_lds = (int *)(q_lds + 8 * 1024);
int* extra_index_lds = (int *)(index_lds + 1024); int* extra_index_lds = (int *)(index_lds + 256);
// int* sIndices = (int *)(q_lds + 8192); // int* sIndices = (int *)(q_lds + 8192);
// 计算当前任务沿着 seqlenKV 方向的 block 起始位置和终止位置 // 计算当前任务沿着 seqlenKV 方向的 block 起始位置和终止位置
int split_id = blockIdx.y; int split_id = blockIdx.y;
...@@ -2925,7 +2917,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -2925,7 +2917,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
int extra_topk_length = params.extra_topk_length ? params.extra_topk_length[bidb] : params.extra_topk; int extra_topk_length = params.extra_topk_length ? params.extra_topk_length[bidb] : params.extra_topk;
int main_num_blocks = ceil_div(main_topk_length, kBlockN); int main_num_blocks = ceil_div(main_topk_length, kBlockN);
int extra_num_blocks = params.extra_sparse_indices ? ceil_div(extra_topk_length, kBlockN) : 0; int extra_num_blocks = Has_extra ? ceil_div(extra_topk_length, kBlockN) : 0;
int total_num_blocks = main_num_blocks + extra_num_blocks; int total_num_blocks = main_num_blocks + extra_num_blocks;
int blocks_per_split = ceil_div(params.partition_size, kBlockN); int blocks_per_split = ceil_div(params.partition_size, kBlockN);
...@@ -2949,12 +2941,13 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -2949,12 +2941,13 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
+ bidb * params.sparse_indices_batch_stride + bidb * params.sparse_indices_batch_stride
+ query_idx * params.sparse_indices_row_stride + query_idx * params.sparse_indices_row_stride
+ (q_head_start / params.ngroups) * params.sparse_indices_head_stride; + (q_head_start / params.ngroups) * params.sparse_indices_head_stride;
int* extra_index_ptr = has_extra int* extra_index_ptr = nullptr;
? params.extra_sparse_indices if constexpr (Has_extra) {
extra_index_ptr = params.extra_sparse_indices
+ bidb * params.extra_sparse_indices_batch_stride + bidb * params.extra_sparse_indices_batch_stride
+ query_idx * params.extra_sparse_indices_row_stride + query_idx * params.extra_sparse_indices_row_stride
+ (q_head_start / params.ngroups) * params.extra_sparse_indices_head_stride + (q_head_start / params.ngroups) * params.extra_sparse_indices_head_stride;
: nullptr; }
// const int block_table_idx = 0; // const int block_table_idx = 0;
// const int block_table_offset = 0; // const int block_table_offset = 0;
...@@ -2963,6 +2956,11 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -2963,6 +2956,11 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
row_offset_v = 0; row_offset_v = 0;
row_offset_o = bidb * int64_t(params.o_batch_stride) + query_idx * int64_t(params.o_row_stride) + q_head_start * int64_t(params.o_head_stride); row_offset_o = bidb * int64_t(params.o_batch_stride) + query_idx * int64_t(params.o_row_stride) + q_head_start * int64_t(params.o_head_stride);
row_offset_lse = bidb * params.seqlen_q + m_block * kBlockM; row_offset_lse = bidb * params.seqlen_q + m_block * kBlockM;
const int64_t split_block_base = int64_t(bidb) * params.h * params.seqlen_q + m_block * kBlockM;
const int64_t split_accum_row_offset = int64_t(split_id) * params.b * params.h * params.seqlen_q * kHeadDimV
+ split_block_base * kHeadDimV;
const int64_t split_lse_row_offset = int64_t(split_id) * params.b * params.h * params.seqlen_q
+ split_block_base;
// row_offset_k = bidb * int64_t(params.k_batch_stride) + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; // row_offset_k = bidb * int64_t(params.k_batch_stride) + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
// row_offset_v = bidb * int64_t(params.v_batch_stride) + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; // row_offset_v = bidb * int64_t(params.v_batch_stride) + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
...@@ -2975,44 +2973,51 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -2975,44 +2973,51 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
qv_ptr = prepare_for_matrix_load<kHeadDim>(reinterpret_cast<Element*>(params.qv_ptr) + row_offset_qv); qv_ptr = prepare_for_matrix_load<kHeadDim>(reinterpret_cast<Element*>(params.qv_ptr) + row_offset_qv);
auto k_ptr = prepare_for_matrix_load<kHeadDim>(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k); auto k_ptr = prepare_for_matrix_load<kHeadDim>(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k);
auto k_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k); auto k_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k);
auto extra_k_ptr_buffer = has_extra ? tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.extra_k_ptr)) : k_ptr_buffer; auto extra_k_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(Has_extra ? params.extra_k_ptr : params.k_ptr));
auto v_ptr = prepare_for_matrix_load<kHeadDimV>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v); auto v_ptr = prepare_for_matrix_load<kHeadDimV>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v);
auto v_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v); auto v_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v);
auto extra_v_ptr_buffer = has_extra ? tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.extra_v_ptr)) : v_ptr_buffer; auto extra_v_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(Has_extra ? params.extra_v_ptr : params.v_ptr));
auto index_ptr_buffer = tcp_cache_swizzle_func<0, int>(reinterpret_cast<int*>(index_ptr)); auto index_ptr_buffer = tcp_cache_swizzle_func<0, int>(reinterpret_cast<int*>(index_ptr));
int tid = threadIdx.x % 64; int tid = threadIdx.x % 64;
if (has_extra) { if constexpr (Has_extra) {
auto extra_index_ptr_buffer = tcp_cache_swizzle_func<0, int>(extra_index_ptr); auto extra_index_ptr_buffer = tcp_cache_swizzle_func<0, int>(extra_index_ptr);
int lds_offset = __builtin_amdgcn_readfirstlane(warp_id * 64); int lds_offset = __builtin_amdgcn_readfirstlane(warp_id * 4 * 64);
int g_offset_v = tid; int g_offset_v = tid * 4;
int g_offset_s = warp_id * 64; int g_offset_s = warp_id * 256;
inline_buffer_load_dword_lds(extra_index_lds, extra_index_ptr_buffer, lds_offset, g_offset_s, g_offset_v); inline_buffer_load_dwordx4_lds(extra_index_lds, extra_index_ptr_buffer, lds_offset, g_offset_s, g_offset_v);
} }
auto k_faker = reinterpret_cast<Element*>(params.k_ptr) + row_offset_k; auto k_faker = reinterpret_cast<Element*>(params.k_ptr) + row_offset_k;
auto v_faker = reinterpret_cast<Element*>(params.v_ptr) + row_offset_v; auto v_faker = reinterpret_cast<Element*>(params.v_ptr) + row_offset_v;
// apply causal mask 的步骤和 no causal mask 的步骤分开算 // apply causal mask 的步骤和 no causal mask 的步骤分开算
constexpr int n_masking_steps = 1; constexpr int n_masking_steps = 1;
int lds_offset = __builtin_amdgcn_readfirstlane(warp_id * 4 * 64); int lds_offset = __builtin_amdgcn_readfirstlane(Has_extra ? warp_id * 64 : warp_id * 4 * 64);
int g_offset_v = tid * 4; int g_offset_v = Has_extra ? tid : tid * 4;
int g_offset_s = warp_id * 256; int g_offset_s = Has_extra ? warp_id * 64 : warp_id * 256;
if constexpr (Has_extra) {
inline_buffer_load_dword_lds(index_lds, index_ptr_buffer, lds_offset, g_offset_s, g_offset_v);
} else {
inline_buffer_load_dwordx4_lds(index_lds, index_ptr_buffer, lds_offset, g_offset_s, g_offset_v); inline_buffer_load_dwordx4_lds(index_lds, index_ptr_buffer, lds_offset, g_offset_s, g_offset_v);
}
flash::wait_buffer_data_arrived<true>(0); flash::wait_buffer_data_arrived<true>(0);
flash::wait_all_warp_arrived(); flash::wait_all_warp_arrived();
#pragma unroll #pragma unroll
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
const int local_index = warp_id * 256 + tid * 4 + i; const int local_index = Has_extra ? warp_id * 64 + tid : warp_id * 256 + tid * 4 + i;
if (local_index >= main_topk_length) { if (local_index >= main_topk_length) {
index_lds[local_index] = main_topk_length > 0 ? index_lds[main_topk_length - 1] : -1; index_lds[local_index] = main_topk_length > 0 ? index_lds[main_topk_length - 1] : -1;
} }
} }
if (has_extra) { if constexpr (Has_extra) {
const int local_index = warp_id * 64 + tid; #pragma unroll
for (int i = 0; i < 4; ++i) {
const int local_index = warp_id * 256 + tid * 4 + i;
if (local_index >= extra_topk_length) { if (local_index >= extra_topk_length) {
extra_index_lds[local_index] = extra_topk_length > 0 ? extra_index_lds[extra_topk_length - 1] : -1; extra_index_lds[local_index] = extra_topk_length > 0 ? extra_index_lds[extra_topk_length - 1] : -1;
} }
} }
}
flash::wait_all_warp_arrived(); flash::wait_all_warp_arrived();
// 是否做 prefetch K, PV 结束后, prefetch K 有风险 // 是否做 prefetch K, PV 结束后, prefetch K 有风险
...@@ -3070,7 +3075,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -3070,7 +3075,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
auto PV_GEMM_FUNC_IN_MASK = &pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_999<false, kHeadDim, kHeadDimV, kBlockM, 256, kBlockN, WARP_M, 256, STAGES, Element, ElementAccum, Is_even_MN>; auto PV_GEMM_FUNC_IN_MASK = &pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_999<false, kHeadDim, kHeadDimV, kBlockM, 256, kBlockN, WARP_M, 256, STAGES, Element, ElementAccum, Is_even_MN>;
for (int logical_block = n_block_min; logical_block < n_block_max; ++logical_block) { for (int logical_block = n_block_min; logical_block < n_block_max; ++logical_block) {
bool is_extra = logical_block >= main_num_blocks; bool is_extra = Has_extra && logical_block >= main_num_blocks;
int rel_block = is_extra ? logical_block - main_num_blocks : logical_block; int rel_block = is_extra ? logical_block - main_num_blocks : logical_block;
int cur_topk_length = is_extra ? extra_topk_length : main_topk_length; int cur_topk_length = is_extra ? extra_topk_length : main_topk_length;
int* cur_index_lds = is_extra ? extra_index_lds : index_lds; int* cur_index_lds = is_extra ? extra_index_lds : index_lds;
...@@ -3191,11 +3196,11 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -3191,11 +3196,11 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
constexpr bool Is_Interleave = true; constexpr bool Is_Interleave = true;
int lane_id = threadIdx.x & 63; int lane_id = threadIdx.x & 63;
if (params.softmax_lse_ptr != nullptr) { if (params.softmax_lse_ptr != nullptr) {
prefill_mla_epilogue_store_lse<WARP_M, Is_even_MN, false/*SplitD*/, Is_Interleave, ElementAccum, vec_Accum<ElementAccum>, 1/* M_MMAC_COUNT */>(lse, params.softmax_lse_ptr, row_offset_lse, warp_id, lane_id, 0, actual_block_m); prefill_mla_epilogue_store_lse<WARP_M, Is_even_MN, false/*SplitD*/, Is_Interleave, ElementAccum, vec_Accum<ElementAccum>, 1/* M_MMAC_COUNT */>(lse, params.softmax_lse_ptr, split_lse_row_offset, warp_id, lane_id, 0, actual_block_m);
} }
/**************************************************************************************************************************************/ /**************************************************************************************************************************************/
Element* o_ptr = reinterpret_cast<Element *>(params.oaccum_ptr) + row_offset_o; Element* o_ptr = reinterpret_cast<Element *>(params.oaccum_ptr) + split_accum_row_offset;
prefill_mla_epilogue_store_output<kHeadDimV, kBlockM, kBlockK, WARP_M, Is_even_MN, Is_Interleave, false/*TcpSwizzle*/, Element, ElementAccum, 1/* M_MMAC_COUNT */>(o_ptr, acc_o, 0, warp_id, lane_id, seqlen_o_stride, actual_block_m); prefill_mla_epilogue_store_output<kHeadDimV, kBlockM, kBlockK, WARP_M, Is_even_MN, Is_Interleave, false/*TcpSwizzle*/, Element, ElementAccum, 1/* M_MMAC_COUNT */>(o_ptr, acc_o, 0, warp_id, lane_id, kHeadDimV, actual_block_m);
} }
#endif #endif
} }
......
...@@ -37,6 +37,7 @@ void run_mla_splitkv_reduce(Params &params, hipStream_t stream) { ...@@ -37,6 +37,7 @@ void run_mla_splitkv_reduce(Params &params, hipStream_t stream) {
reduce_params.num_splits = params.num_splits; reduce_params.num_splits = params.num_splits;
reduce_params.partition_size = params.partition_size; reduce_params.partition_size = params.partition_size;
reduce_params.h = params.h; reduce_params.h = params.h;
reduce_params.ngroups = params.ngroups;
reduce_params.seqlen_q = params.seqlen_q; reduce_params.seqlen_q = params.seqlen_q;
reduce_params.layout = params.layout; reduce_params.layout = params.layout;
reduce_params.topk_length = params.topk_length; reduce_params.topk_length = params.topk_length;
...@@ -556,13 +557,16 @@ void run_mla_fwd_dispatch_dsa_prefill_nopage_64(Flash_fwd_mla_params_dsa &params ...@@ -556,13 +557,16 @@ void run_mla_fwd_dispatch_dsa_prefill_nopage_64(Flash_fwd_mla_params_dsa &params
dimGrid.z = params.b; dimGrid.z = params.b;
constexpr bool IsEvenMNConst = false; constexpr bool IsEvenMNConst = false;
const bool has_extra = params.extra_sparse_indices != nullptr && params.extra_topk > 0;
if(params.num_splits == 1){ if(params.num_splits == 1){
BOOL_SWITCH(params.mtp > 1/* is_mtp */, Is_MTP, [&] { BOOL_SWITCH(params.mtp > 1/* is_mtp */, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64<Kernel_traits, true/*Is_training*/, Is_dropout, false/* Is_prefix | flashmla */,Is_causal, IsEvenMNConst, /*Is_even_K*/true, /*Return_softmax*/false, Is_MTP, 0, Flash_fwd_mla_params_dsa> BOOL_SWITCH(has_extra, Has_extra, [&] {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64<Kernel_traits, true/*Is_training*/, Is_dropout, false/* Is_prefix | flashmla */,Is_causal, IsEvenMNConst, /*Is_even_K*/true, /*Return_softmax*/false, Is_MTP, 0, Has_extra, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params); <<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
}); });
}); });
});
} }
else if(params.num_splits != 0){ else if(params.num_splits != 0){
dimGrid.x = (params.seqlen_q + 1 * kBlockM - 1) / (1 * kBlockM); dimGrid.x = (params.seqlen_q + 1 * kBlockM - 1) / (1 * kBlockM);
...@@ -570,10 +574,12 @@ void run_mla_fwd_dispatch_dsa_prefill_nopage_64(Flash_fwd_mla_params_dsa &params ...@@ -570,10 +574,12 @@ void run_mla_fwd_dispatch_dsa_prefill_nopage_64(Flash_fwd_mla_params_dsa &params
dimGrid.z = params.b; dimGrid.z = params.b;
BOOL_SWITCH(params.mtp > 1/* is_mtp */, Is_MTP, [&] { BOOL_SWITCH(params.mtp > 1/* is_mtp */, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64_splitkv<Kernel_traits, true/*Is_training*/, Is_dropout, false/* Is_prefix | flashmla */,Is_causal, IsEvenMNConst, /*Is_even_K*/true, /*Return_softmax*/false, Is_MTP, 0, Flash_fwd_mla_params_dsa> BOOL_SWITCH(has_extra, Has_extra, [&] {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64_splitkv<Kernel_traits, true/*Is_training*/, Is_dropout, false/* Is_prefix | flashmla */,Is_causal, IsEvenMNConst, /*Is_even_K*/true, /*Return_softmax*/false, Is_MTP, 0, Has_extra, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params); <<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
}); });
}); });
});
run_mla_splitkv_reduce<Kernel_traits, false/*Tail*/>(params, stream); run_mla_splitkv_reduce<Kernel_traits, false/*Tail*/>(params, stream);
} }
else{ else{
......
#pragma once #pragma once
#include "numeric_types.h" #include "numeric_types.h"
#include "splitkv.h" #include "splitkv.h"
#include "intrinsic.h"
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
...@@ -29,7 +30,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel( ...@@ -29,7 +30,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
// compute partition_size when fix num_splits // compute partition_size when fix num_splits
int partition_size = params.partition_size > MLA_MAX_SPLITS ? splitkv_get_partitionsize_of_fix_numsplits(actual_seqlen_k, params.num_splits): params.partition_size; int partition_size = params.partition_size > MLA_MAX_SPLITS ? splitkv_get_partitionsize_of_fix_numsplits(actual_seqlen_k, params.num_splits): params.partition_size;
const int true_num_splits = Tail ? max(1, floor_div(actual_seqlen_k, partition_size)): ceil_div(actual_seqlen_k, partition_size); const int true_num_splits = Tail ? max(1, flash::floor_div(actual_seqlen_k, partition_size)): flash::ceil_div(actual_seqlen_k, partition_size);
// const int true_num_splits = num_splits; // const int true_num_splits = num_splits;
bool exceed_split = (tx >= true_num_splits); // process boundary bool exceed_split = (tx >= true_num_splits); // process boundary
...@@ -40,7 +41,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel( ...@@ -40,7 +41,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
float s_max_tmp = s_max_load_ori; float s_max_tmp = s_max_load_ori;
#pragma unroll #pragma unroll
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) { for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
s_max_tmp = max(s_max_tmp, __shfl_xor_tmp(s_max_tmp, step)); s_max_tmp = max(s_max_tmp, flash::__shfl_xor_tmp(s_max_tmp, step));
} }
// compute rescale coefficient for max (numerator) // compute rescale coefficient for max (numerator)
float s_max_ratio = __expf(s_max_load_ori - s_max_tmp); float s_max_ratio = __expf(s_max_load_ori - s_max_tmp);
...@@ -50,7 +51,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel( ...@@ -50,7 +51,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
float s_sum_tmp = s_sum_load_ori * s_max_ratio; float s_sum_tmp = s_sum_load_ori * s_max_ratio;
#pragma unroll #pragma unroll
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) { for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
s_sum_tmp = s_sum_tmp + __shfl_xor_tmp(s_sum_tmp, step); s_sum_tmp = s_sum_tmp + flash::__shfl_xor_tmp(s_sum_tmp, step);
} }
// max-rescale coefficient x sum-rescale coefficient // max-rescale coefficient x sum-rescale coefficient
...@@ -81,18 +82,18 @@ __global__ void flash_fwd_splitkv_reduce_kernel( ...@@ -81,18 +82,18 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
// read ultimate scale value for current split // read ultimate scale value for current split
vec2_Element<accumType> load = *(vec2_Element<accumType>*)(oaccum_ptr + tx_offset + t); vec2_Element<accumType> load = *(vec2_Element<accumType>*)(oaccum_ptr + tx_offset + t);
// half -> float32, reduce precision loss // half -> float32, reduce precision loss
float a_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load[0]): 0.f; float a_f32 = within_splits? flash::splitkv_upcast_to_f32<accumType>(load[0]): 0.f;
float b_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load[1]): 0.f; float b_f32 = within_splits? flash::splitkv_upcast_to_f32<accumType>(load[1]): 0.f;
// do rescale and sum // do rescale and sum
tx_accum[t] = __llvm_fma_f32(a_f32, s_scale, tx_accum[t]); tx_accum[t] = flash::__llvm_fma_f32(a_f32, s_scale, tx_accum[t]);
tx_accum[t + 1] = __llvm_fma_f32(b_f32, s_scale, tx_accum[t + 1]); tx_accum[t + 1] = flash::__llvm_fma_f32(b_f32, s_scale, tx_accum[t + 1]);
} else if constexpr (kHeadDim == 64) { } else if constexpr (kHeadDim == 64) {
// read ultimate scale value for current split // read ultimate scale value for current split
accumType load = *(accumType*)(oaccum_ptr + tx_offset + t); accumType load = *(accumType*)(oaccum_ptr + tx_offset + t);
// half -> float32, reduce precision loss // half -> float32, reduce precision loss
float load_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load): 0.f; float load_f32 = within_splits? flash::splitkv_upcast_to_f32<accumType>(load): 0.f;
// do rescale and sum // do rescale and sum
tx_accum[t] = __llvm_fma_f32(load_f32, s_scale, tx_accum[t]); tx_accum[t] = flash::__llvm_fma_f32(load_f32, s_scale, tx_accum[t]);
} }
} }
// switch to next split // switch to next split
...@@ -103,14 +104,14 @@ __global__ void flash_fwd_splitkv_reduce_kernel( ...@@ -103,14 +104,14 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
if constexpr (kHeadDim % 128 == 0) { if constexpr (kHeadDim % 128 == 0) {
vec2_Element<reduceType> accum_result; vec2_Element<reduceType> accum_result;
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__) #if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
accum_result = DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]); accum_result = flash::DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]);
#else #else
accum_result[0] = DownCast<float, reduceType, true>(tx_accum[t]); accum_result[0] = flash::DownCast<float, reduceType, true>(tx_accum[t]);
accum_result[1] = DownCast<float, reduceType, true>(tx_accum[t + 1]); // here, v_cvt_pkrtz can be used accum_result[1] = flash::DownCast<float, reduceType, true>(tx_accum[t + 1]); // here, v_cvt_pkrtz can be used
#endif #endif
*(vec2_Element<reduceType>*)(output_ptr + t) = accum_result; *(vec2_Element<reduceType>*)(output_ptr + t) = accum_result;
} else if constexpr (kHeadDim == 64) { } else if constexpr (kHeadDim == 64) {
reduceType accum_result = DownCast<float, reduceType, false>(tx_accum[t]); reduceType accum_result = flash::DownCast<float, reduceType, false>(tx_accum[t]);
output_ptr[t] = accum_result; output_ptr[t] = accum_result;
} }
} }
...@@ -146,7 +147,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) { ...@@ -146,7 +147,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
} }
// compute partition_size when fix num_splits // compute partition_size when fix num_splits
int partition_size = params.partition_size > MLA_MAX_SPLITS ? splitkv_get_partitionsize_of_fix_numsplits(actual_seqlen_k, params.num_splits): params.partition_size; int partition_size = params.partition_size > MLA_MAX_SPLITS ? splitkv_get_partitionsize_of_fix_numsplits(actual_seqlen_k, params.num_splits): params.partition_size;
const int true_num_splits = Tail ? max(1, floor_div(actual_seqlen_k, partition_size)): ceil_div(actual_seqlen_k, partition_size); const int true_num_splits = Tail ? max(1, flash::floor_div(actual_seqlen_k, partition_size)): flash::ceil_div(actual_seqlen_k, partition_size);
// const int true_num_splits = num_splits; // const int true_num_splits = num_splits;
bool exceed_split = (tx >= true_num_splits); // process boundary bool exceed_split = (tx >= true_num_splits); // process boundary
...@@ -157,7 +158,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) { ...@@ -157,7 +158,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
float s_max_tmp = s_max_load_ori; float s_max_tmp = s_max_load_ori;
#pragma unroll #pragma unroll
for (int step = 64 >> 1; step > 0; step = (step >> 1)) { for (int step = 64 >> 1; step > 0; step = (step >> 1)) {
s_max_tmp = max(s_max_tmp, __shfl_xor_tmp(s_max_tmp, step)); s_max_tmp = max(s_max_tmp, flash::__shfl_xor_tmp(s_max_tmp, step));
} }
// for multiple waves, store the reduced max value to lds individually, and recompute max across multiple waves // for multiple waves, store the reduced max value to lds individually, and recompute max across multiple waves
int wave_id = (tx >> 6); int wave_id = (tx >> 6);
...@@ -181,7 +182,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) { ...@@ -181,7 +182,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
float s_sum_tmp = s_sum_load_ori * s_max_ratio; float s_sum_tmp = s_sum_load_ori * s_max_ratio;
#pragma unroll #pragma unroll
for (int step = 64 >> 1; step > 0; step = (step >> 1)) { for (int step = 64 >> 1; step > 0; step = (step >> 1)) {
s_sum_tmp = s_sum_tmp + __shfl_xor_tmp(s_sum_tmp, step); s_sum_tmp = s_sum_tmp + flash::__shfl_xor_tmp(s_sum_tmp, step);
} }
// for multiple wave, store the reduced sum value to lds individually, and recompute sum across multiple waves // for multiple wave, store the reduced sum value to lds individually, and recompute sum across multiple waves
lds[LDS_ACCUM + wave_id] = s_sum_tmp; lds[LDS_ACCUM + wave_id] = s_sum_tmp;
...@@ -230,18 +231,18 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) { ...@@ -230,18 +231,18 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
// read 2 halfs from current split of this threads // read 2 halfs from current split of this threads
vec2_Element<accumType> load = *(vec2_Element<accumType>*)(oaccum_ptr + tx_offset + t); vec2_Element<accumType> load = *(vec2_Element<accumType>*)(oaccum_ptr + tx_offset + t);
// half -> float32, reduce precision loss // half -> float32, reduce precision loss
float a_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load[0]): 0.f; float a_f32 = within_splits? flash::splitkv_upcast_to_f32<accumType>(load[0]): 0.f;
float b_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load[1]): 0.f; float b_f32 = within_splits? flash::splitkv_upcast_to_f32<accumType>(load[1]): 0.f;
// do rescale and sum // do rescale and sum
tx_accum[t] = __llvm_fma_f32(a_f32, s_scale, tx_accum[t]); tx_accum[t] = flash::__llvm_fma_f32(a_f32, s_scale, tx_accum[t]);
tx_accum[t + 1] = __llvm_fma_f32(b_f32, s_scale, tx_accum[t + 1]); tx_accum[t + 1] = flash::__llvm_fma_f32(b_f32, s_scale, tx_accum[t + 1]);
} else if constexpr (kHeadDim == 64) { } else if constexpr (kHeadDim == 64) {
// read 1 half from current split of this threads // read 1 half from current split of this threads
accumType load = *(accumType*)(oaccum_ptr + tx_offset + t); accumType load = *(accumType*)(oaccum_ptr + tx_offset + t);
// half -> float32, reduce precision loss // half -> float32, reduce precision loss
float load_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load): 0.f; float load_f32 = within_splits? flash::splitkv_upcast_to_f32<accumType>(load): 0.f;
// do rescale and sum // do rescale and sum
tx_accum[t] = __llvm_fma_f32(load_f32, s_scale, tx_accum[t]); tx_accum[t] = flash::__llvm_fma_f32(load_f32, s_scale, tx_accum[t]);
} }
} }
// switch to next split // switch to next split
...@@ -290,15 +291,15 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) { ...@@ -290,15 +291,15 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
tx_accum[t + 1] = lds[tx * tx_float_count + t + 1]; tx_accum[t + 1] = lds[tx * tx_float_count + t + 1];
vec2_Element<reduceType> accum_result; vec2_Element<reduceType> accum_result;
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__) #if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
accum_result = DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]); accum_result = flash::DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]);
#else #else
accum_result[0] = DownCast<float, reduceType, true>(tx_accum[t]); accum_result[0] = flash::DownCast<float, reduceType, true>(tx_accum[t]);
accum_result[1] = DownCast<float, reduceType, true>(tx_accum[t + 1]); // here, v_cvt_pkrtz can be used accum_result[1] = flash::DownCast<float, reduceType, true>(tx_accum[t + 1]); // here, v_cvt_pkrtz can be used
#endif #endif
*(vec2_Element<reduceType>*)(output_ptr + t) = accum_result; *(vec2_Element<reduceType>*)(output_ptr + t) = accum_result;
} else if constexpr (kHeadDim == 64) { } else if constexpr (kHeadDim == 64) {
tx_accum[t] = lds[tx * tx_float_count + t]; tx_accum[t] = lds[tx * tx_float_count + t];
reduceType accum_result = DownCast<float, reduceType, false>(tx_accum[t]); reduceType accum_result = flash::DownCast<float, reduceType, false>(tx_accum[t]);
*(reduceType*)(output_ptr + t) = accum_result; *(reduceType*)(output_ptr + t) = accum_result;
} }
} }
...@@ -341,7 +342,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel ...@@ -341,7 +342,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel
// compute partition_size when fix num_splits // compute partition_size when fix num_splits
int partition_size = splitkv_get_partitionsize_of_fix_numsplits(actual_seqlen_k, params.num_splits); int partition_size = splitkv_get_partitionsize_of_fix_numsplits(actual_seqlen_k, params.num_splits);
const int true_num_splits = Tail ? max(1, floor_div(actual_seqlen_k, partition_size)): ceil_div(actual_seqlen_k, partition_size); const int true_num_splits = Tail ? max(1, flash::floor_div(actual_seqlen_k, partition_size)): flash::ceil_div(actual_seqlen_k, partition_size);
// const int true_num_splits = num_splits; // const int true_num_splits = num_splits;
bool exceed_split = (tx >= true_num_splits); // process boundary bool exceed_split = (tx >= true_num_splits); // process boundary
...@@ -399,13 +400,13 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel ...@@ -399,13 +400,13 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel
float lse_max_local = lse_local; float lse_max_local = lse_local;
#pragma unroll #pragma unroll
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) { for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
lse_max_local = max(lse_max_local, __shfl_xor_tmp(lse_max_local, step)); lse_max_local = max(lse_max_local, flash::__shfl_xor_tmp(lse_max_local, step));
} }
// reduce sum lse // reduce sum lse
float lse_local_logsum = __expf(lse_local - lse_max_local); float lse_local_logsum = __expf(lse_local - lse_max_local);
#pragma unroll #pragma unroll
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) { for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
lse_local_logsum = lse_local_logsum + __shfl_xor_tmp(lse_local_logsum, step); lse_local_logsum = lse_local_logsum + flash::__shfl_xor_tmp(lse_local_logsum, step);
} }
lse_local_logsum = __logf(lse_local_logsum) + lse_max_local; lse_local_logsum = __logf(lse_local_logsum) + lse_max_local;
...@@ -427,16 +428,16 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel ...@@ -427,16 +428,16 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel
for (int t = 0; t < tx_float_count; t += 2) { for (int t = 0; t < tx_float_count; t += 2) {
if constexpr (kHeadDim % 128 == 0) { if constexpr (kHeadDim % 128 == 0) {
// half -> float32, reduce precision loss // half -> float32, reduce precision loss
float a_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load_vec[i][t >> 1][0]): 0.f; float a_f32 = within_splits? flash::splitkv_upcast_to_f32<accumType>(load_vec[i][t >> 1][0]): 0.f;
float b_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load_vec[i][t >> 1][1]): 0.f; float b_f32 = within_splits? flash::splitkv_upcast_to_f32<accumType>(load_vec[i][t >> 1][1]): 0.f;
// do rescale and sum // do rescale and sum
tx_accum[t] = __llvm_fma_f32(a_f32, s_scale, tx_accum[t]); tx_accum[t] = flash::__llvm_fma_f32(a_f32, s_scale, tx_accum[t]);
tx_accum[t + 1] = __llvm_fma_f32(b_f32, s_scale, tx_accum[t + 1]); tx_accum[t + 1] = flash::__llvm_fma_f32(b_f32, s_scale, tx_accum[t + 1]);
} else if constexpr (kHeadDim == 64) { } else if constexpr (kHeadDim == 64) {
// half -> float32, reduce precision loss // half -> float32, reduce precision loss
float load_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load[i][t >> 1]): 0.f; float load_f32 = within_splits? flash::splitkv_upcast_to_f32<accumType>(load[i][t >> 1]): 0.f;
// do rescale and sum // do rescale and sum
tx_accum[t] = __llvm_fma_f32(load_f32, s_scale, tx_accum[t]); tx_accum[t] = flash::__llvm_fma_f32(load_f32, s_scale, tx_accum[t]);
} }
} }
} }
...@@ -446,14 +447,14 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel ...@@ -446,14 +447,14 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel
if constexpr (kHeadDim % 128 == 0) { if constexpr (kHeadDim % 128 == 0) {
vec2_Element<reduceType> accum_result; vec2_Element<reduceType> accum_result;
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__) #if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
accum_result = DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]); accum_result = flash::DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]);
#else #else
accum_result[0] = DownCast<float, reduceType, true>(tx_accum[t]); accum_result[0] = flash::DownCast<float, reduceType, true>(tx_accum[t]);
accum_result[1] = DownCast<float, reduceType, true>(tx_accum[t + 1]); // here, v_cvt_pkrtz can be used accum_result[1] = flash::DownCast<float, reduceType, true>(tx_accum[t + 1]); // here, v_cvt_pkrtz can be used
#endif #endif
*(vec2_Element<reduceType>*)(output_ptr + t) = accum_result; *(vec2_Element<reduceType>*)(output_ptr + t) = accum_result;
} else if constexpr (kHeadDim == 64) { } else if constexpr (kHeadDim == 64) {
reduceType accum_result = DownCast<float, reduceType, false>(tx_accum[t]); reduceType accum_result = flash::DownCast<float, reduceType, false>(tx_accum[t]);
output_ptr[t] = accum_result; output_ptr[t] = accum_result;
} }
} }
...@@ -501,15 +502,15 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel( ...@@ -501,15 +502,15 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
// int main_len = params.topk_length ? params.topk_length[row] : params.topk; // int main_len = params.topk_length ? params.topk_length[row] : params.topk;
// int extra_len = params.extra_topk_length ? params.extra_topk_length[row] : params.extra_topk; // int extra_len = params.extra_topk_length ? params.extra_topk_length[row] : params.extra_topk;
// int actual_seqlen_k = ceil_div(main_len, 64) * 64 + ceil_div(extra_len, 64) * 64; // int actual_seqlen_k = flash::ceil_div(main_len, 64) * 64 + flash::ceil_div(extra_len, 64) * 64;
int row = block_x / 64; int topk_length_row = h == 1 ? bidb : block_x / 64;
int main_len = params.topk_length ? params.topk_length[row] : params.topk; int main_len = params.topk_length ? params.topk_length[topk_length_row] : params.topk;
int extra_len = params.extra_topk_length ? params.extra_topk_length[row] : params.extra_topk; int extra_len = params.extra_topk_length ? params.extra_topk_length[topk_length_row] : params.extra_topk;
int total_blocks = ceil_div(main_len, 64) + ceil_div(extra_len, 64); int total_blocks = flash::ceil_div(main_len, 64) + flash::ceil_div(extra_len, 64);
int blocks_per_split = ceil_div(params.partition_size, 64); int blocks_per_split = flash::ceil_div(params.partition_size, 64);
int true_num_splits = ceil_div(total_blocks, blocks_per_split); int true_num_splits = flash::ceil_div(total_blocks, blocks_per_split);
// for flashmla, 512 elements are engaged to 4 blocks // for flashmla, 512 elements are engaged to 4 blocks
// within each block, num_splits / WARM_NUM load transactions are engaged to each wave // within each block, num_splits / WARM_NUM load transactions are engaged to each wave
...@@ -539,7 +540,7 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel( ...@@ -539,7 +540,7 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
// compute partition_size when fix num_splits // compute partition_size when fix num_splits
// int partition_size = params_partition_size > MLA_MAX_SPLITS ? splitkv_get_partitionsize_of_fix_numsplits(actual_seqlen_k, params.num_splits): params_partition_size; // int partition_size = params_partition_size > MLA_MAX_SPLITS ? splitkv_get_partitionsize_of_fix_numsplits(actual_seqlen_k, params.num_splits): params_partition_size;
// const int true_num_splits = Tail ? max(1, floor_div(actual_seqlen_k, partition_size)): ceil_div(actual_seqlen_k, partition_size); // const int true_num_splits = Tail ? max(1, flash::floor_div(actual_seqlen_k, partition_size)): flash::ceil_div(actual_seqlen_k, partition_size);
// const int true_num_splits = num_splits; // const int true_num_splits = num_splits;
bool exceed_split = (tx >= true_num_splits); // process boundary bool exceed_split = (tx >= true_num_splits); // process boundary
...@@ -552,20 +553,20 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel( ...@@ -552,20 +553,20 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
float lse_max_local = lse_local; float lse_max_local = lse_local;
#pragma unroll #pragma unroll
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) { for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
lse_max_local = max(lse_max_local, __shfl_xor_tmp(lse_max_local, step)); lse_max_local = max(lse_max_local, flash::__shfl_xor_tmp(lse_max_local, step));
} }
// reduce sum lse // reduce sum lse
float lse_local_logsum = __expf(lse_local - lse_max_local); float lse_local_logsum = __expf(lse_local - lse_max_local);
#pragma unroll #pragma unroll
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) { for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
lse_local_logsum = lse_local_logsum + __shfl_xor_tmp(lse_local_logsum, step); lse_local_logsum = lse_local_logsum + flash::__shfl_xor_tmp(lse_local_logsum, step);
} }
lse_local_logsum = __logf(lse_local_logsum) + lse_max_local; lse_local_logsum = __logf(lse_local_logsum) + lse_max_local;
float attn_sink_o_scale = 1.0f; float attn_sink_o_scale = 1.0f;
if (params.attn_sink != nullptr) { if (params.attn_sink != nullptr) {
// 当前 reduce kernel 的 block_x 是按 b,h,s 展开的,所以 bidh 就是 head id。 int attn_sink_idx = h == 1 ? in_batch_offset % params.ngroups : block_x % 64;
float rAttn_sink = params.attn_sink[block_x % 64]; float rAttn_sink = params.attn_sink[attn_sink_idx];
if (rAttn_sink == INFINITY) { if (rAttn_sink == INFINITY) {
attn_sink_o_scale = 0.0f; attn_sink_o_scale = 0.0f;
...@@ -588,11 +589,11 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel( ...@@ -588,11 +589,11 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
#pragma unroll #pragma unroll
for (int t = 0; t < tx_float_count; t += 2) { for (int t = 0; t < tx_float_count; t += 2) {
// half -> float32, reduce precision loss // half -> float32, reduce precision loss
float a_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load[i >> 2][t >> 1][0]): 0.f; float a_f32 = within_splits? flash::splitkv_upcast_to_f32<accumType>(load[i >> 2][t >> 1][0]): 0.f;
float b_f32 = within_splits? splitkv_upcast_to_f32<accumType>(load[i >> 2][t >> 1][1]): 0.f; float b_f32 = within_splits? flash::splitkv_upcast_to_f32<accumType>(load[i >> 2][t >> 1][1]): 0.f;
// do rescale and sum // do rescale and sum
tx_accum[t] = __llvm_fma_f32(a_f32, s_scale, tx_accum[t]); tx_accum[t] = flash::__llvm_fma_f32(a_f32, s_scale, tx_accum[t]);
tx_accum[t + 1] = __llvm_fma_f32(b_f32, s_scale, tx_accum[t + 1]); tx_accum[t + 1] = flash::__llvm_fma_f32(b_f32, s_scale, tx_accum[t + 1]);
} }
} }
// reduce across 4 waves // reduce across 4 waves
...@@ -616,10 +617,10 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel( ...@@ -616,10 +617,10 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
// cvt // cvt
vec2_Element<reduceType> accum_result; vec2_Element<reduceType> accum_result;
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__) #if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
accum_result = DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]); accum_result = flash::DownCastPairNoPack<float, reduceType>(tx_accum[t], tx_accum[t + 1]);
#else #else
accum_result[0] = DownCast<float, reduceType, true>(tx_accum[t]); accum_result[0] = flash::DownCast<float, reduceType, true>(tx_accum[t]);
accum_result[1] = DownCast<float, reduceType, true>(tx_accum[t + 1]); accum_result[1] = flash::DownCast<float, reduceType, true>(tx_accum[t + 1]);
#endif #endif
// storation // storation
*(vec2_Element<reduceType>*)(output_ptr + t) = accum_result; *(vec2_Element<reduceType>*)(output_ptr + t) = accum_result;
......
...@@ -151,7 +151,10 @@ def flash_mla_with_kvcache( ...@@ -151,7 +151,10 @@ def flash_mla_with_kvcache(
if topk is not None: if topk is not None:
# Sparse attention # Sparse attention
assert not causal, "causal must be False when sparse attention is enabled" assert not causal, "causal must be False when sparse attention is enabled"
assert is_fp8_kvcache, "is_fp8_kvcache must be True when sparse attention is enabled" if not is_fp8_kvcache:
assert k_cache.dtype == torch.bfloat16, "BF16 sparse attention requires k_cache dtype to be torch.bfloat16 when is_fp8_kvcache is False"
if extra_k_cache is not None:
assert extra_k_cache.dtype == torch.bfloat16, "BF16 sparse attention requires extra_k_cache dtype to be torch.bfloat16 when is_fp8_kvcache is False"
out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd( out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd(
q, k_cache, indices_in_kvcache, topk_length, attn_sink, q, k_cache, indices_in_kvcache, topk_length, attn_sink,
sched_meta.tile_scheduler_metadata, sched_meta.num_splits, sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
......
...@@ -93,6 +93,7 @@ ext_modules.append( ...@@ -93,6 +93,7 @@ ext_modules.append(
"csrc/gfx93/decode/sparse_fp8/instantiations/v32_persistent_h16.cu", "csrc/gfx93/decode/sparse_fp8/instantiations/v32_persistent_h16.cu",
"csrc/gfx93/decode/sparse_fp8/instantiations/v32_persistent_h64.cu", "csrc/gfx93/decode/sparse_fp8/instantiations/v32_persistent_h64.cu",
"csrc/gfx93/decode/sparse_fp8/instantiations/v32_persistent_h128.cu", "csrc/gfx93/decode/sparse_fp8/instantiations/v32_persistent_h128.cu",
"csrc/gfx93/decode/sparse_bf16_dsa/fwd.cu",
# # gfx93 sparse prefill # # gfx93 sparse prefill
"csrc/gfx93/prefill/sparse/fwd.cu", "csrc/gfx93/prefill/sparse/fwd.cu",
......
...@@ -14,6 +14,9 @@ class TestTarget(enum.Enum): ...@@ -14,6 +14,9 @@ class TestTarget(enum.Enum):
FWD = 0 FWD = 0
DECODE = 1 DECODE = 1
def is_decode_bf16_kvcache() -> bool:
return os.environ.get("FLASH_MLA_DECODE_BF16", "").lower() in ["1", "true", "yes", "y", "bf16"]
@dataclasses.dataclass @dataclasses.dataclass
class ExtraTestParamForDecode: class ExtraTestParamForDecode:
b: int b: int
...@@ -42,6 +45,21 @@ class TestParam: ...@@ -42,6 +45,21 @@ class TestParam:
have_topk_length: bool = False have_topk_length: bool = False
decode: Optional[ExtraTestParamForDecode] = None decode: Optional[ExtraTestParamForDecode] = None
def is_bf16_decode_supported_param(t: TestParam) -> bool:
if t.decode is None:
return False
if t.is_all_indices_invalid or t.decode.have_zero_seqlen_k:
return False
if t.h_kv != 1 or t.d_v != 512:
return False
if t.h_q not in [64, 128]:
return False
if t.d_qk not in [512, 576]:
return False
if t.decode.extra_topk is None:
return t.topk <= 1024
return t.topk <= 256 and t.decode.extra_topk <= 1024
@dataclasses.dataclass @dataclasses.dataclass
class RawTestParamForDecode: class RawTestParamForDecode:
""" """
...@@ -289,6 +307,7 @@ def generate_testcase_for_decode(t: TestParam) -> TestcaseForDecode: ...@@ -289,6 +307,7 @@ def generate_testcase_for_decode(t: TestParam) -> TestcaseForDecode:
return KVScope(t, cache_seqlens, block_table, blocked_k, abs_indices, indices_in_kvcache, topk_length) return KVScope(t, cache_seqlens, block_table, blocked_k, abs_indices, indices_in_kvcache, topk_length)
kv_scope0 = generate_one_k_scope(t.s_kv, t.decode.block_size, t.topk, t.decode.is_varlen, t.decode.have_zero_seqlen_k, t.is_all_indices_invalid, t.have_topk_length) kv_scope0 = generate_one_k_scope(t.s_kv, t.decode.block_size, t.topk, t.decode.is_varlen, t.decode.have_zero_seqlen_k, t.is_all_indices_invalid, t.have_topk_length)
if not is_decode_bf16_kvcache():
kv_scope0.quant_and_dequant_() kv_scope0.quant_and_dequant_()
if t.decode.extra_topk is not None: if t.decode.extra_topk is not None:
if t.decode.extra_s_k is None: if t.decode.extra_s_k is None:
...@@ -296,6 +315,7 @@ def generate_testcase_for_decode(t: TestParam) -> TestcaseForDecode: ...@@ -296,6 +315,7 @@ def generate_testcase_for_decode(t: TestParam) -> TestcaseForDecode:
if t.decode.extra_block_size is None: if t.decode.extra_block_size is None:
t.decode.extra_block_size = t.decode.block_size t.decode.extra_block_size = t.decode.block_size
kv_scope1 = generate_one_k_scope(t.decode.extra_s_k, t.decode.extra_block_size, t.decode.extra_topk, t.decode.is_varlen, t.decode.have_zero_seqlen_k, t.is_all_indices_invalid, t.decode.have_extra_topk_length) kv_scope1 = generate_one_k_scope(t.decode.extra_s_k, t.decode.extra_block_size, t.decode.extra_topk, t.decode.is_varlen, t.decode.have_zero_seqlen_k, t.is_all_indices_invalid, t.decode.have_extra_topk_length)
if not is_decode_bf16_kvcache():
kv_scope1.quant_and_dequant_() kv_scope1.quant_and_dequant_()
else: else:
assert t.decode.extra_block_size is None and t.decode.extra_s_k is None and not t.decode.have_extra_topk_length assert t.decode.extra_block_size is None and t.decode.extra_s_k is None and not t.decode.have_extra_topk_length
...@@ -318,16 +338,17 @@ def run_flash_mla_sparse_fwd(p: TestParam, t: Testcase, return_p_sum: bool): ...@@ -318,16 +338,17 @@ def run_flash_mla_sparse_fwd(p: TestParam, t: Testcase, return_p_sum: bool):
def run_flash_mla_decode(p: TestParam, t: TestcaseForDecode, tile_scheduler_metadata, num_splits): def run_flash_mla_decode(p: TestParam, t: TestcaseForDecode, tile_scheduler_metadata, num_splits):
assert p.decode is not None assert p.decode is not None
is_fp8_kvcache = not is_decode_bf16_kvcache()
return flash_mla.flash_mla_with_kvcache( return flash_mla.flash_mla_with_kvcache(
t.q, t.q,
t.kv_scope.get_kvcache_for_flash_mla(), t.kv_scope.get_kvcache_for_flash_mla() if is_fp8_kvcache else t.kv_scope.blocked_k,
None, None, p.d_v, None, None, p.d_v,
tile_scheduler_metadata, num_splits, tile_scheduler_metadata, num_splits,
t.sm_scale, False, True, t.sm_scale, False, is_fp8_kvcache,
t.kv_scope.indices_in_kvcache, t.kv_scope.indices_in_kvcache,
t.attn_sink, t.attn_sink,
t.extra_kv_scope.get_kvcache_for_flash_mla() if t.extra_kv_scope is not None else None, (t.extra_kv_scope.get_kvcache_for_flash_mla() if is_fp8_kvcache else t.extra_kv_scope.blocked_k) if t.extra_kv_scope is not None else None,
t.extra_kv_scope.indices_in_kvcache if t.extra_kv_scope is not None else None, t.extra_kv_scope.indices_in_kvcache if t.extra_kv_scope is not None else None,
t.kv_scope.topk_length, t.kv_scope.topk_length,
t.extra_kv_scope.topk_length if t.extra_kv_scope is not None and t.extra_kv_scope.topk_length is not None else None t.extra_kv_scope.topk_length if t.extra_kv_scope is not None and t.extra_kv_scope.topk_length is not None else None
......
...@@ -172,6 +172,10 @@ def test_flash_mla(p: TestParam) -> Result: ...@@ -172,6 +172,10 @@ def test_flash_mla(p: TestParam) -> Result:
else: else:
result = kk.bench_kineto(run_decode, p.num_runs) result = kk.bench_kineto(run_decode, p.num_runs)
if lib.is_decode_bf16_kvcache():
splitkv_kernel_name = "flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64_splitkv"
combine_kernel_name = "flash_mla_splitkv_reduce_kernel"
else:
splitkv_kernel_name = "flash_fwd_splitkv_mla_fp8_sparse_kernel" splitkv_kernel_name = "flash_fwd_splitkv_mla_fp8_sparse_kernel"
combine_kernel_name = "flash_fwd_mla_combine_kernel" combine_kernel_name = "flash_fwd_mla_combine_kernel"
...@@ -226,6 +230,9 @@ def test_flash_mla(p: TestParam) -> Result: ...@@ -226,6 +230,9 @@ def test_flash_mla(p: TestParam) -> Result:
out_ref, lse_ref = ref.ref_sparse_attn_decode(p, t) out_ref, lse_ref = ref.ref_sparse_attn_decode(p, t)
is_out_correct = kk.check_is_allclose("out", out_ans, out_ref, abs_tol=1e-3, rel_tol=2.01/128, cos_diff_tol=5e-6) is_out_correct = kk.check_is_allclose("out", out_ans, out_ref, abs_tol=1e-3, rel_tol=2.01/128, cos_diff_tol=5e-6)
if lib.is_decode_bf16_kvcache():
is_correct &= is_out_correct
else:
is_lse_correct = kk.check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536) is_lse_correct = kk.check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536)
is_correct &= is_out_correct and is_lse_correct is_correct &= is_out_correct and is_lse_correct
...@@ -250,6 +257,22 @@ def main(): ...@@ -250,6 +257,22 @@ def main():
raw_testcases = gen_testcase() raw_testcases = gen_testcase()
testcases = [t.to_test_param() for t in raw_testcases] testcases = [t.to_test_param() for t in raw_testcases]
if lib.is_decode_bf16_kvcache():
bf16_testcases = []
seen_bf16_cases = set()
for t in testcases:
if not lib.is_bf16_decode_supported_param(t):
continue
if t.num_runs > 0 and t.decode.b > 16:
t = dataclasses.replace(t, decode=dataclasses.replace(t.decode, b=16))
key = dataclasses.asdict(t)
key["decode"] = tuple(key["decode"].items()) if key["decode"] is not None else None
key = tuple(key.items())
if key in seen_bf16_cases:
continue
seen_bf16_cases.add(key)
bf16_testcases.append(t)
testcases = bf16_testcases
print(f"{kk.colors['CYAN_BG']}{len(testcases)} testcases to run{kk.colors['CLEAR']}") print(f"{kk.colors['CYAN_BG']}{len(testcases)} testcases to run{kk.colors['CLEAR']}")
......
...@@ -10,6 +10,26 @@ import ref ...@@ -10,6 +10,26 @@ import ref
_counter = kk.Counter() _counter = kk.Counter()
def is_dsa_mls_prefill_case(p: TestParam) -> bool:
if p.d_v != 512:
return False
if p.d_qk not in [512, 576]:
return False
if p.h_kv != 1:
return False
if p.h_q not in [64, 128]:
return False
if not (p.topk <= 1024 or p.topk == 2048):
return False
if p.topk == 2048 and (p.have_attn_sink or p.have_topk_length):
return False
if p.d_qk == 512 and ((p.h_q == 64 and p.topk == 512) or (p.h_q == 128 and p.topk == 1024)):
return True
if p.d_qk == 576 and p.h_q == 64 and p.topk == 2048 and p.s_kv >= 32768:
return True
return False
@torch.inference_mode() @torch.inference_mode()
def run_test(p: TestParam) -> bool: def run_test(p: TestParam) -> bool:
if p.seed == -1: if p.seed == -1:
...@@ -31,7 +51,12 @@ def run_test(p: TestParam) -> bool: ...@@ -31,7 +51,12 @@ def run_test(p: TestParam) -> bool:
if p.num_runs > 0: if p.num_runs > 0:
flops_and_mem_vol = lib.count_flop_and_mem_vol(p, t) flops_and_mem_vol = lib.count_flop_and_mem_vol(p, t)
prefill_ans_time = kk.bench_kineto(run_prefill, num_tests=p.num_runs).get_kernel_time("sparse_attn_fwd") bench_result = kk.bench_kineto(run_prefill, num_tests=p.num_runs)
kernel_names = bench_result.get_kernel_names()
prefill_kernel_name = "sparse_attn_fwd"
if not any(prefill_kernel_name in name for name in kernel_names):
prefill_kernel_name = "flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64"
prefill_ans_time = bench_result.get_kernel_time(prefill_kernel_name)
prefill_flops = flops_and_mem_vol.fwd_flop/prefill_ans_time/1e12 prefill_flops = flops_and_mem_vol.fwd_flop/prefill_ans_time/1e12
prefill_mem_bw = flops_and_mem_vol.fwd_mem_vol/prefill_ans_time/1e12 prefill_mem_bw = flops_and_mem_vol.fwd_mem_vol/prefill_ans_time/1e12
print(f"Prefill: {prefill_ans_time*1e6:4.0f} us, {prefill_flops:6.1f} TFlops, {prefill_mem_bw:4.2f} TBps") print(f"Prefill: {prefill_ans_time*1e6:4.0f} us, {prefill_flops:6.1f} TFlops, {prefill_mem_bw:4.2f} TBps")
...@@ -44,6 +69,7 @@ def run_test(p: TestParam) -> bool: ...@@ -44,6 +69,7 @@ def run_test(p: TestParam) -> bool:
is_correct = True is_correct = True
is_correct &= kk.check_is_allclose("out", prefill_ans_out.float(), ref_out_fp32, abs_tol=8e-4, rel_tol=3.01/128, cos_diff_tol=7e-6) is_correct &= kk.check_is_allclose("out", prefill_ans_out.float(), ref_out_fp32, abs_tol=8e-4, rel_tol=3.01/128, cos_diff_tol=7e-6)
if not is_dsa_mls_prefill_case(p):
is_correct &= kk.check_is_allclose("max_logits", prefill_ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01/65536) is_correct &= kk.check_is_allclose("max_logits", prefill_ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01/65536)
is_correct &= kk.check_is_allclose("lse", prefill_ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01/65536) is_correct &= kk.check_is_allclose("lse", prefill_ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01/65536)
...@@ -187,4 +213,3 @@ if __name__ == '__main__': ...@@ -187,4 +213,3 @@ if __name__ == '__main__':
sys.exit(1) sys.exit(1)
else: else:
print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m") print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment