Commit 3d3a277f authored by zhushuang's avatar zhushuang
Browse files

issue/1001 - feat: add paged attention decode kernels for Moore GPU, referencing the NVIDIA implementation

parent c312f175
......@@ -17,6 +17,8 @@ using cuda_bfloat16 = mt_bfloat16;
using cuda_bfloat162 = mt_bfloat162;
using cuda_fp8_e4m3 = __mt_fp8_e4m3;
using __nv_bfloat16 = __mt_bfloat16;
namespace device::moore {
// get the memory offset of the given element in a tensor given its flat index
......
#include <musa_runtime.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_kernel_common.h"
#include "../cuda/kernel_v2.cuh"
namespace op::paged_attention::moore {
namespace {

// Hard cap on the number of KV splits; the split-kv workspace is sized for this.
constexpr int kMaxSplits = 8;

// Integer ceiling division for unsigned sizes.
constexpr size_t ceilDiv(size_t a, size_t b) {
    return (a + b - 1) / b;
}

// Queries the multiprocessor (SM) count of the current MUSA device.
// Returns 0 when either runtime query fails, which callers treat as
// "no heuristic information available".
inline int getSmCount() {
    int device = 0;
    if (musaGetDevice(&device) != musaSuccess) {
        return 0;
    }
    int count = 0;
    if (musaDeviceGetAttribute(&count, musaDevAttrMultiProcessorCount, device) != musaSuccess) {
        return 0;
    }
    return count;
}

// A lightweight FA2-style "waves" heuristic.
//
// Important: our split-kv kernel shards the KV sequence length, so the main "work"
// dimension is tokens, not the number of pages. We use an upper bound for seqlen_k
// (max pages * page size), which matches common decode microbench where all seqs
// share the same cache length.
inline int chooseNumSplitsHeuristic(size_t num_heads, size_t num_seqs, size_t seqlen_k, int sm_count) {
    // Degenerate or short-sequence configurations: splitting cannot pay off.
    if (sm_count <= 0 || num_heads == 0 || num_seqs == 0 || seqlen_k <= 256) {
        return 1;
    }
    const size_t sms = static_cast<size_t>(sm_count);
    const size_t base_blocks = num_heads * num_seqs;
    // Baseline: one kernel, base_blocks CTAs, each scanning seqlen_k tokens.
    size_t best_score = ceilDiv(base_blocks, sms) * seqlen_k;
    int best_splits = 1;
    size_t prev_work = seqlen_k;
    for (int s = 2; s <= kMaxSplits; ++s) {
        const size_t work_per_block = ceilDiv(seqlen_k, static_cast<size_t>(s));
        // If this split count doesn't reduce per-block work vs the previous split,
        // it's effectively redundant.
        if (work_per_block == prev_work) {
            continue;
        }
        prev_work = work_per_block;
        const size_t waves_split = ceilDiv(base_blocks * static_cast<size_t>(s), sms);
        // Combine is one extra kernel with base_blocks blocks; approximate it as
        // one more wave unit.
        const size_t waves_combine = ceilDiv(base_blocks, sms);
        const size_t score = waves_split * work_per_block + waves_combine;
        if (score < best_score) {
            best_score = score;
            best_splits = s;
        }
    }
    return best_splits;
}
} // namespace
// Returns true iff environment variable `name` is set to exactly "1" or "true".
// Unset variables, and any other value (including "TRUE"), yield false.
inline bool envBool(const char *name) {
    const char *value = std::getenv(name);
    if (value == nullptr) {
        return false;
    }
    return std::strcmp(value, "1") == 0 || std::strcmp(value, "true") == 0;
}
// Kernel wrapper: warp-based paged-attention decode for head_dim=128.
// Thin forwarder — all arguments are passed unchanged to the shared
// CUDA-side kernel body (flashAttentionDecodeWarpKernel<..., 128>).
// Launched by the dispatcher with 32 threads per block.
template <typename Tindex, typename Tdata>
INFINIOP_MOORE_KERNEL flashAttentionDecodeHd128Warp(
    Tdata *out,
    const Tdata *q,
    const Tdata *k_cache,
    const Tdata *v_cache,
    const Tindex *block_tables,
    const Tindex *cache_lens,
    const float *alibi_slopes,
    size_t num_kv_heads,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    ptrdiff_t o_stride) {
    op::paged_attention::cuda::flashAttentionDecodeWarpKernel<Tindex, Tdata, 128>(
        out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
        k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
        v_head_stride, o_stride);
}
// Kernel wrapper: CTA-based decode, template config <128, 64, 8>
// (head_dim=128; 64 matches the dispatcher's cta_threads knob, 8 its cta_tile knob).
// Thin forwarder to the shared CUDA-side CTA kernel.
template <typename Tindex, typename Tdata>
INFINIOP_MOORE_KERNEL flashAttentionDecodeHd128Cta(
    Tdata *out,
    const Tdata *q,
    const Tdata *k_cache,
    const Tdata *v_cache,
    const Tindex *block_tables,
    const Tindex *cache_lens,
    const float *alibi_slopes,
    size_t num_kv_heads,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    ptrdiff_t o_stride) {
    // Default CTA variant (lower overhead).
    op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 128, 64, 8>(
        out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
        k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
        v_head_stride, o_stride);
}
// Kernel wrapper: CTA-based decode with the larger tile, config <128, 64, 16>
// (selected when INFINIOP_FLASH_CTA_TILE=16 and cta_threads=64).
template <typename Tindex, typename Tdata>
INFINIOP_MOORE_KERNEL flashAttentionDecodeHd128CtaTile16(
    Tdata *out,
    const Tdata *q,
    const Tdata *k_cache,
    const Tdata *v_cache,
    const Tindex *block_tables,
    const Tindex *cache_lens,
    const float *alibi_slopes,
    size_t num_kv_heads,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    ptrdiff_t o_stride) {
    op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 128, 64, 16>(
        out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
        k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
        v_head_stride, o_stride);
}
// Kernel wrapper: 32-thread CTA decode, config <128, 32, 8>
// (selected when INFINIOP_FLASH_CTA_THREADS=32).
template <typename Tindex, typename Tdata>
INFINIOP_MOORE_KERNEL flashAttentionDecodeHd128Cta32(
    Tdata *out,
    const Tdata *q,
    const Tdata *k_cache,
    const Tdata *v_cache,
    const Tindex *block_tables,
    const Tindex *cache_lens,
    const float *alibi_slopes,
    size_t num_kv_heads,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    ptrdiff_t o_stride) {
    // Experimental 1-warp CTA variant for head_dim=128 (kPack=4).
    op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 128, 32, 8>(
        out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
        k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
        v_head_stride, o_stride);
}
// Kernel wrapper: 32-thread CTA decode with the larger tile, config <128, 32, 16>
// (selected when cta_threads=32 and cta_tile=16).
template <typename Tindex, typename Tdata>
INFINIOP_MOORE_KERNEL flashAttentionDecodeHd128Cta32Tile16(
    Tdata *out,
    const Tdata *q,
    const Tdata *k_cache,
    const Tdata *v_cache,
    const Tindex *block_tables,
    const Tindex *cache_lens,
    const float *alibi_slopes,
    size_t num_kv_heads,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    ptrdiff_t o_stride) {
    op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 128, 32, 16>(
        out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
        k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
        v_head_stride, o_stride);
}
// Kernel wrapper: GQA-fused CTA decode — one CTA serves 4 query heads per KV head
// (config <128, 64, 8, 4>). The dispatcher only takes this path when
// num_heads == num_kv_heads * 4 and alibi_slopes is null.
template <typename Tindex, typename Tdata>
INFINIOP_MOORE_KERNEL flashAttentionDecodeHd128CtaGqa4(
    Tdata *out,
    const Tdata *q,
    const Tdata *k_cache,
    const Tdata *v_cache,
    const Tindex *block_tables,
    const Tindex *cache_lens,
    const float *alibi_slopes,
    size_t num_kv_heads,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    ptrdiff_t o_stride) {
    // GQA fused kernel: CTA computes 4 query heads for one KV head (head_dim=128).
    op::paged_attention::cuda::flashAttentionDecodeCtaGqaKernel<Tindex, Tdata, 128, 64, 8, 4>(
        out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
        k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
        v_head_stride, o_stride);
}
// Kernel wrapper: split-KV warp decode. Instead of the final output, writes
// float partials (accumulator, running max m, running sum l) per split; a
// separate Combine kernel reduces them into `out`.
template <typename Tindex, typename Tdata>
INFINIOP_MOORE_KERNEL flashAttentionDecodeHd128SplitKv(
    float *partial_acc,
    float *partial_m,
    float *partial_l,
    const Tdata *q,
    const Tdata *k_cache,
    const Tdata *v_cache,
    const Tindex *block_tables,
    const Tindex *cache_lens,
    const float *alibi_slopes,
    size_t num_kv_heads,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    int num_splits) {
    op::paged_attention::cuda::flashAttentionDecodeSplitKvWarpKernel<Tindex, Tdata, 128>(
        partial_acc, partial_m, partial_l,
        q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
        k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
        v_head_stride, num_splits);
}
// Kernel wrapper: split-KV CTA decode, config <128, 64, 8>. Writes float
// partials (acc, m, l) per split for the Combine kernel to reduce.
template <typename Tindex, typename Tdata>
INFINIOP_MOORE_KERNEL flashAttentionDecodeHd128SplitKvCta(
    float *partial_acc,
    float *partial_m,
    float *partial_l,
    const Tdata *q,
    const Tdata *k_cache,
    const Tdata *v_cache,
    const Tindex *block_tables,
    const Tindex *cache_lens,
    const float *alibi_slopes,
    size_t num_kv_heads,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    int num_splits) {
    op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel<Tindex, Tdata, 128, 64, 8>(
        partial_acc, partial_m, partial_l,
        q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
        k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
        v_head_stride, num_splits);
}
// Kernel wrapper: split-KV CTA decode with tile=16, config <128, 64, 16>.
template <typename Tindex, typename Tdata>
INFINIOP_MOORE_KERNEL flashAttentionDecodeHd128SplitKvCtaTile16(
    float *partial_acc,
    float *partial_m,
    float *partial_l,
    const Tdata *q,
    const Tdata *k_cache,
    const Tdata *v_cache,
    const Tindex *block_tables,
    const Tindex *cache_lens,
    const float *alibi_slopes,
    size_t num_kv_heads,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    int num_splits) {
    op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel<Tindex, Tdata, 128, 64, 16>(
        partial_acc, partial_m, partial_l,
        q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
        k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
        v_head_stride, num_splits);
}
// Kernel wrapper: split-KV 32-thread CTA decode, config <128, 32, 8>.
template <typename Tindex, typename Tdata>
INFINIOP_MOORE_KERNEL flashAttentionDecodeHd128SplitKvCta32(
    float *partial_acc,
    float *partial_m,
    float *partial_l,
    const Tdata *q,
    const Tdata *k_cache,
    const Tdata *v_cache,
    const Tindex *block_tables,
    const Tindex *cache_lens,
    const float *alibi_slopes,
    size_t num_kv_heads,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    int num_splits) {
    op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel<Tindex, Tdata, 128, 32, 8>(
        partial_acc, partial_m, partial_l,
        q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
        k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
        v_head_stride, num_splits);
}
// Kernel wrapper: split-KV 32-thread CTA decode with tile=16, config <128, 32, 16>.
template <typename Tindex, typename Tdata>
INFINIOP_MOORE_KERNEL flashAttentionDecodeHd128SplitKvCta32Tile16(
    float *partial_acc,
    float *partial_m,
    float *partial_l,
    const Tdata *q,
    const Tdata *k_cache,
    const Tdata *v_cache,
    const Tindex *block_tables,
    const Tindex *cache_lens,
    const float *alibi_slopes,
    size_t num_kv_heads,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    int num_splits) {
    op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel<Tindex, Tdata, 128, 32, 16>(
        partial_acc, partial_m, partial_l,
        q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
        k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
        v_head_stride, num_splits);
}
// Kernel wrapper: reduces the `num_splits` float partials (acc, m, l) written
// by a split-KV kernel into the final Tdata output for head_dim=128.
// Launched by the dispatcher with 32 threads per block.
template <typename Tdata>
INFINIOP_MOORE_KERNEL flashAttentionDecodeHd128SplitKvCombine(
    Tdata *out,
    const float *partial_acc,
    const float *partial_m,
    const float *partial_l,
    int num_splits,
    ptrdiff_t o_stride) {
    op::paged_attention::cuda::flashAttentionDecodeSplitKvCombineWarpKernel<Tdata, 128>(
        out, partial_acc, partial_m, partial_l, num_splits, o_stride);
}
// Dispatches the head_dim=128 paged-attention decode launch for one index type.
//
// Selection between warp/CTA kernels, the GQA-fused kernel, and the split-KV
// path (partial kernels + combine kernel) is driven by environment knobs:
//   INFINIOP_FLASH_DECODE_KERNEL  : "cta" (default) or any other value for "warp"
//   INFINIOP_FLASH_GQA_FUSED      : enable/disable the fused GQA-4 kernel (default on)
//   INFINIOP_FLASH_CTA_TILE       : 8 (default) or 16
//   INFINIOP_FLASH_CTA_THREADS    : 64 (default) or 32
//   INFINIOP_FLASH_DECODE_SPLITKV : "auto", off ("0"/"false"), or on (default)
//   INFINIOP_FLASH_NUM_SPLITS     : fixed split count (default 4) or "auto"
//   INFINIOP_FLASH_DEBUG_DISPATCH / INFINIOP_FLASH_DEBUG_SPLITS : stderr logging
//
// The split-KV path requires `workspace` to hold float partials sized for
// kMaxSplits (see needed_bytes below). Only F16 and BF16 dtypes are supported.
// Returns INFINI_STATUS_SUCCESS, INFINI_STATUS_INSUFFICIENT_WORKSPACE, or
// INFINI_STATUS_BAD_TENSOR_DTYPE.
template <typename Tindex>
infiniStatus_t launch_decode_hd128_impl(
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k_cache,
    const void *v_cache,
    infiniDtype_t dtype,
    const Tindex *block_tables,
    const Tindex *cache_lens,
    const float *alibi_slopes,
    size_t num_heads,
    size_t num_seqs,
    size_t num_kv_heads,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    ptrdiff_t o_stride,
    musaStream_t stream) {
    // Default decode config (2026-01-22):
    // decode_flash_cta8_64_gqa_splitkv_4
    // Users can override any knob via the corresponding INFINIOP_FLASH_* env vars.
    bool use_cta = true;
    if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_KERNEL")) {
        // Backward-compatible: any non-"cta" value means "warp".
        use_cta = (std::strcmp(env, "cta") == 0);
    }
    bool use_gqa_fused = true;
    if (const char *env = std::getenv("INFINIOP_FLASH_GQA_FUSED")) {
        // Anything other than "1"/"true" disables fusion.
        if (std::strcmp(env, "0") == 0 || std::strcmp(env, "false") == 0) {
            use_gqa_fused = false;
        } else {
            use_gqa_fused = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
        }
    }
    // KV tile size per CTA iteration; only 8 and 16 are instantiated.
    int cta_tile = 8;
    if (const char *env = std::getenv("INFINIOP_FLASH_CTA_TILE")) {
        const int v = std::atoi(env);
        if (v == 8 || v == 16) {
            cta_tile = v;
        }
    }
    // Threads per CTA; only 32 and 64 are instantiated.
    int cta_threads = 64;
    if (const char *env = std::getenv("INFINIOP_FLASH_CTA_THREADS")) {
        const int v = std::atoi(env);
        if (v == 32 || v == 64) {
            cta_threads = v;
        }
    }
    // Warp kernels always launch with 32 threads; CTA kernels with cta_threads.
    dim3 block(use_cta ? static_cast<uint32_t>(cta_threads) : 32);
    bool use_split = true;
    bool use_split_auto = false;
    if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_SPLITKV")) {
        if (std::strcmp(env, "auto") == 0) {
            // "auto": defer the split decision to the waves heuristic below.
            use_split_auto = true;
            use_split = false;
        } else {
            if (std::strcmp(env, "0") == 0 || std::strcmp(env, "false") == 0) {
                use_split = false;
            } else {
                use_split = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
            }
        }
    }
    int num_splits = 4;
    bool fixed_num_splits = true;
    if (const char *env = std::getenv("INFINIOP_FLASH_NUM_SPLITS")) {
        if (std::strcmp(env, "auto") == 0) {
            fixed_num_splits = false;
        } else {
            num_splits = std::atoi(env);
            // Non-positive (or unparsable) values fall back to the heuristic.
            fixed_num_splits = (num_splits > 0);
        }
    }
    // Clamp to [1, kMaxSplits]; the split-kv workspace is sized for kMaxSplits.
    if (num_splits < 1) {
        num_splits = 1;
    }
    if (num_splits > kMaxSplits) {
        num_splits = kMaxSplits;
    }
    const bool debug_dispatch = envBool("INFINIOP_FLASH_DEBUG_DISPATCH");
    // Logs the chosen dispatch path to stderr, deduplicated across calls.
    auto dump_dispatch = [&](const char *path) {
        if (!debug_dispatch) {
            return;
        }
        // Avoid spamming: only print when the key dispatch signature changes.
        struct Sig {
            const char *path;
            int dtype;
            size_t heads;
            size_t kv_heads;
            size_t seqs;
            size_t pbs;
            size_t max_blocks;
            int cta_tile;
            int cta_threads;
            int split;
            int split_auto;
            int num_splits;
            int fixed;
            int gqa_fused;
        };
        // NOTE: function-local statics — dedup state is shared per template
        // instantiation, across all calls/streams.
        static Sig last{};
        static bool has_last = false;
        Sig cur{
            path,
            static_cast<int>(dtype),
            num_heads,
            num_kv_heads,
            num_seqs,
            page_block_size,
            max_num_blocks_per_seq,
            cta_tile,
            cta_threads,
            static_cast<int>(use_split),
            static_cast<int>(use_split_auto),
            num_splits,
            static_cast<int>(fixed_num_splits),
            static_cast<int>(use_gqa_fused),
        };
        // `path` is compared by pointer: callers pass string literals only.
        if (has_last && cur.path == last.path && cur.dtype == last.dtype && cur.heads == last.heads && cur.kv_heads == last.kv_heads && cur.seqs == last.seqs && cur.pbs == last.pbs && cur.max_blocks == last.max_blocks && cur.cta_tile == last.cta_tile && cur.cta_threads == last.cta_threads && cur.split == last.split && cur.split_auto == last.split_auto && cur.num_splits == last.num_splits && cur.fixed == last.fixed && cur.gqa_fused == last.gqa_fused) {
            return;
        }
        last = cur;
        has_last = true;
        fprintf(stderr,
                "[INFINIOP][paged_attention][hd128] dispatch: path=%s dtype=%d heads=%zu kv_heads=%zu seqs=%zu "
                "pbs=%zu max_blocks=%zu cta_tile=%d cta_threads=%d split=%d split_auto=%d num_splits=%d fixed=%d gqa_fused=%d\n",
                path, static_cast<int>(dtype), num_heads, num_kv_heads, num_seqs,
                page_block_size, max_num_blocks_per_seq, cta_tile, cta_threads,
                static_cast<int>(use_split), static_cast<int>(use_split_auto), num_splits, static_cast<int>(fixed_num_splits),
                static_cast<int>(use_gqa_fused));
    };
    // Split-kv auto mode: decide whether to split based on a heuristic.
    if (use_split_auto) {
        // Approximate seqlen_k by the per-seq KV capacity (paged KV upper bound).
        const size_t seqlen_k = max_num_blocks_per_seq * page_block_size;
        const int sm_count = getSmCount();
        num_splits = chooseNumSplitsHeuristic(num_heads, num_seqs, seqlen_k, sm_count);
        if (const char *dbg = std::getenv("INFINIOP_FLASH_DEBUG_SPLITS")) {
            if (std::strcmp(dbg, "1") == 0 || std::strcmp(dbg, "true") == 0) {
                static size_t last_seqlen_k = 0;
                if (last_seqlen_k != seqlen_k) {
                    last_seqlen_k = seqlen_k;
                    fprintf(stderr,
                            "[INFINIOP][paged_attention] splitkv auto(mode): sm=%d heads=%zu seqs=%zu seqlen_k~%zu -> num_splits=%d\n",
                            sm_count, num_heads, num_seqs, seqlen_k, num_splits);
                }
            }
        }
        // If auto picks 1, fall back to non-split to avoid extra workspace and kernel overhead.
        use_split = (num_splits > 1);
    }
    // Optional: fuse GQA groups (4) when seqlen_q=1 decode and alibi is disabled.
    // This reuses K/V loads across query heads that share the same KV head.
    // Controlled by INFINIOP_FLASH_GQA_FUSED (default: enabled).
    if (use_gqa_fused && use_cta && !use_split && alibi_slopes == nullptr && num_kv_heads > 0 && num_heads == num_kv_heads * 4) {
        dump_dispatch("cta_gqa_fused");
        // One CTA per (kv_head, seq); each CTA covers 4 query heads.
        dim3 grid_gqa(static_cast<uint64_t>(num_kv_heads), static_cast<uint64_t>(num_seqs), 1);
        if (dtype == INFINI_DTYPE_F16) {
            flashAttentionDecodeHd128CtaGqa4<Tindex, half><<<grid_gqa, 64, 0, stream>>>(
                static_cast<half *>(out),
                static_cast<const half *>(q),
                static_cast<const half *>(k_cache),
                static_cast<const half *>(v_cache),
                block_tables, cache_lens, nullptr,
                num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                q_stride, k_batch_stride, k_row_stride, k_head_stride,
                v_batch_stride, v_row_stride, v_head_stride, o_stride);
            return INFINI_STATUS_SUCCESS;
        }
        if (dtype == INFINI_DTYPE_BF16) {
            flashAttentionDecodeHd128CtaGqa4<Tindex, __mt_bfloat16><<<grid_gqa, 64, 0, stream>>>(
                static_cast<__mt_bfloat16 *>(out),
                static_cast<const __mt_bfloat16 *>(q),
                static_cast<const __mt_bfloat16 *>(k_cache),
                static_cast<const __mt_bfloat16 *>(v_cache),
                block_tables, cache_lens, nullptr,
                num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                q_stride, k_batch_stride, k_row_stride, k_head_stride,
                v_batch_stride, v_row_stride, v_head_stride, o_stride);
            return INFINI_STATUS_SUCCESS;
        }
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    // Non-fused grids: one block per (query head, seq).
    dim3 grid(static_cast<uint64_t>(num_heads), static_cast<uint64_t>(num_seqs), 1);
    if (use_split) {
        dump_dispatch(use_cta ? "splitkv_cta" : "splitkv_warp");
        if (!fixed_num_splits) {
            // Approximate seqlen_k by the per-seq KV capacity (paged KV upper bound).
            const size_t seqlen_k = max_num_blocks_per_seq * page_block_size;
            const int sm_count = getSmCount();
            num_splits = chooseNumSplitsHeuristic(num_heads, num_seqs, seqlen_k, sm_count);
            if (const char *dbg = std::getenv("INFINIOP_FLASH_DEBUG_SPLITS")) {
                if (std::strcmp(dbg, "1") == 0 || std::strcmp(dbg, "true") == 0) {
                    static size_t last_seqlen_k = 0;
                    if (last_seqlen_k != seqlen_k) {
                        last_seqlen_k = seqlen_k;
                        fprintf(stderr,
                                "[INFINIOP][paged_attention] splitkv auto: sm=%d heads=%zu seqs=%zu seqlen_k~%zu -> num_splits=%d\n",
                                sm_count, num_heads, num_seqs, seqlen_k, num_splits);
                    }
                }
            }
        }
        // Workspace layout: [acc: kMaxSplits*n*128][m: kMaxSplits*n][l: kMaxSplits*n],
        // all float, sized for kMaxSplits regardless of the chosen num_splits.
        const size_t n = num_seqs * num_heads;
        const size_t acc_elems = static_cast<size_t>(kMaxSplits) * n * 128;
        const size_t m_elems = static_cast<size_t>(kMaxSplits) * n;
        const size_t l_elems = static_cast<size_t>(kMaxSplits) * n;
        const size_t needed_bytes = (acc_elems + m_elems + l_elems) * sizeof(float);
        if (workspace == nullptr || workspace_size < needed_bytes) {
            return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
        }
        float *ws = static_cast<float *>(workspace);
        float *partial_acc = ws;
        float *partial_m = partial_acc + acc_elems;
        float *partial_l = partial_m + m_elems;
        // Split kernels add a z-dimension over splits; combine uses the base grid.
        dim3 grid_split(static_cast<uint64_t>(num_heads), static_cast<uint64_t>(num_seqs), static_cast<uint64_t>(num_splits));
        dim3 block_split(use_cta ? static_cast<uint32_t>(cta_threads) : 32);
        if (dtype == INFINI_DTYPE_F16) {
            if (use_cta) {
                if (cta_threads == 32) {
                    if (cta_tile == 16) {
                        flashAttentionDecodeHd128SplitKvCta32Tile16<Tindex, half><<<grid_split, block_split, 0, stream>>>(
                            partial_acc, partial_m, partial_l,
                            static_cast<const half *>(q),
                            static_cast<const half *>(k_cache),
                            static_cast<const half *>(v_cache),
                            block_tables, cache_lens, alibi_slopes,
                            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                            q_stride, k_batch_stride, k_row_stride, k_head_stride,
                            v_batch_stride, v_row_stride, v_head_stride, num_splits);
                    } else {
                        flashAttentionDecodeHd128SplitKvCta32<Tindex, half><<<grid_split, block_split, 0, stream>>>(
                            partial_acc, partial_m, partial_l,
                            static_cast<const half *>(q),
                            static_cast<const half *>(k_cache),
                            static_cast<const half *>(v_cache),
                            block_tables, cache_lens, alibi_slopes,
                            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                            q_stride, k_batch_stride, k_row_stride, k_head_stride,
                            v_batch_stride, v_row_stride, v_head_stride, num_splits);
                    }
                } else {
                    if (cta_tile == 16) {
                        flashAttentionDecodeHd128SplitKvCtaTile16<Tindex, half><<<grid_split, block_split, 0, stream>>>(
                            partial_acc, partial_m, partial_l,
                            static_cast<const half *>(q),
                            static_cast<const half *>(k_cache),
                            static_cast<const half *>(v_cache),
                            block_tables, cache_lens, alibi_slopes,
                            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                            q_stride, k_batch_stride, k_row_stride, k_head_stride,
                            v_batch_stride, v_row_stride, v_head_stride, num_splits);
                    } else {
                        flashAttentionDecodeHd128SplitKvCta<Tindex, half><<<grid_split, block_split, 0, stream>>>(
                            partial_acc, partial_m, partial_l,
                            static_cast<const half *>(q),
                            static_cast<const half *>(k_cache),
                            static_cast<const half *>(v_cache),
                            block_tables, cache_lens, alibi_slopes,
                            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                            q_stride, k_batch_stride, k_row_stride, k_head_stride,
                            v_batch_stride, v_row_stride, v_head_stride, num_splits);
                    }
                }
            } else {
                flashAttentionDecodeHd128SplitKv<Tindex, half><<<grid_split, block_split, 0, stream>>>(
                    partial_acc, partial_m, partial_l,
                    static_cast<const half *>(q),
                    static_cast<const half *>(k_cache),
                    static_cast<const half *>(v_cache),
                    block_tables, cache_lens, alibi_slopes,
                    num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                    q_stride, k_batch_stride, k_row_stride, k_head_stride,
                    v_batch_stride, v_row_stride, v_head_stride, num_splits);
            }
            // Second pass: reduce the partials into the final output.
            flashAttentionDecodeHd128SplitKvCombine<half><<<grid, 32, 0, stream>>>(
                static_cast<half *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride);
            return INFINI_STATUS_SUCCESS;
        }
        if (dtype == INFINI_DTYPE_BF16) {
            if (use_cta) {
                if (cta_threads == 32) {
                    if (cta_tile == 16) {
                        flashAttentionDecodeHd128SplitKvCta32Tile16<Tindex, __mt_bfloat16><<<grid_split, block_split, 0, stream>>>(
                            partial_acc, partial_m, partial_l,
                            static_cast<const __mt_bfloat16 *>(q),
                            static_cast<const __mt_bfloat16 *>(k_cache),
                            static_cast<const __mt_bfloat16 *>(v_cache),
                            block_tables, cache_lens, alibi_slopes,
                            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                            q_stride, k_batch_stride, k_row_stride, k_head_stride,
                            v_batch_stride, v_row_stride, v_head_stride, num_splits);
                    } else {
                        flashAttentionDecodeHd128SplitKvCta32<Tindex, __mt_bfloat16><<<grid_split, block_split, 0, stream>>>(
                            partial_acc, partial_m, partial_l,
                            static_cast<const __mt_bfloat16 *>(q),
                            static_cast<const __mt_bfloat16 *>(k_cache),
                            static_cast<const __mt_bfloat16 *>(v_cache),
                            block_tables, cache_lens, alibi_slopes,
                            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                            q_stride, k_batch_stride, k_row_stride, k_head_stride,
                            v_batch_stride, v_row_stride, v_head_stride, num_splits);
                    }
                } else {
                    if (cta_tile == 16) {
                        flashAttentionDecodeHd128SplitKvCtaTile16<Tindex, __mt_bfloat16><<<grid_split, block_split, 0, stream>>>(
                            partial_acc, partial_m, partial_l,
                            static_cast<const __mt_bfloat16 *>(q),
                            static_cast<const __mt_bfloat16 *>(k_cache),
                            static_cast<const __mt_bfloat16 *>(v_cache),
                            block_tables, cache_lens, alibi_slopes,
                            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                            q_stride, k_batch_stride, k_row_stride, k_head_stride,
                            v_batch_stride, v_row_stride, v_head_stride, num_splits);
                    } else {
                        flashAttentionDecodeHd128SplitKvCta<Tindex, __mt_bfloat16><<<grid_split, block_split, 0, stream>>>(
                            partial_acc, partial_m, partial_l,
                            static_cast<const __mt_bfloat16 *>(q),
                            static_cast<const __mt_bfloat16 *>(k_cache),
                            static_cast<const __mt_bfloat16 *>(v_cache),
                            block_tables, cache_lens, alibi_slopes,
                            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                            q_stride, k_batch_stride, k_row_stride, k_head_stride,
                            v_batch_stride, v_row_stride, v_head_stride, num_splits);
                    }
                }
            } else {
                flashAttentionDecodeHd128SplitKv<Tindex, __mt_bfloat16><<<grid_split, block_split, 0, stream>>>(
                    partial_acc, partial_m, partial_l,
                    static_cast<const __mt_bfloat16 *>(q),
                    static_cast<const __mt_bfloat16 *>(k_cache),
                    static_cast<const __mt_bfloat16 *>(v_cache),
                    block_tables, cache_lens, alibi_slopes,
                    num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                    q_stride, k_batch_stride, k_row_stride, k_head_stride,
                    v_batch_stride, v_row_stride, v_head_stride, num_splits);
            }
            // Second pass: reduce the partials into the final output.
            flashAttentionDecodeHd128SplitKvCombine<__mt_bfloat16><<<grid, 32, 0, stream>>>(
                static_cast<__mt_bfloat16 *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride);
            return INFINI_STATUS_SUCCESS;
        }
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    // Non-split path: single kernel writes the final output directly.
    dump_dispatch(use_cta ? "cta_nosplit" : "warp_nosplit");
    if (dtype == INFINI_DTYPE_F16) {
        if (use_cta) {
            if (cta_tile == 16) {
                if (cta_threads == 32) {
                    flashAttentionDecodeHd128Cta32Tile16<Tindex, half><<<grid, block, 0, stream>>>(
                        static_cast<half *>(out),
                        static_cast<const half *>(q),
                        static_cast<const half *>(k_cache),
                        static_cast<const half *>(v_cache),
                        block_tables, cache_lens, alibi_slopes,
                        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                        q_stride, k_batch_stride, k_row_stride, k_head_stride,
                        v_batch_stride, v_row_stride, v_head_stride, o_stride);
                } else {
                    flashAttentionDecodeHd128CtaTile16<Tindex, half><<<grid, block, 0, stream>>>(
                        static_cast<half *>(out),
                        static_cast<const half *>(q),
                        static_cast<const half *>(k_cache),
                        static_cast<const half *>(v_cache),
                        block_tables, cache_lens, alibi_slopes,
                        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                        q_stride, k_batch_stride, k_row_stride, k_head_stride,
                        v_batch_stride, v_row_stride, v_head_stride, o_stride);
                }
            } else {
                if (cta_threads == 32) {
                    flashAttentionDecodeHd128Cta32<Tindex, half><<<grid, block, 0, stream>>>(
                        static_cast<half *>(out),
                        static_cast<const half *>(q),
                        static_cast<const half *>(k_cache),
                        static_cast<const half *>(v_cache),
                        block_tables, cache_lens, alibi_slopes,
                        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                        q_stride, k_batch_stride, k_row_stride, k_head_stride,
                        v_batch_stride, v_row_stride, v_head_stride, o_stride);
                } else {
                    flashAttentionDecodeHd128Cta<Tindex, half><<<grid, block, 0, stream>>>(
                        static_cast<half *>(out),
                        static_cast<const half *>(q),
                        static_cast<const half *>(k_cache),
                        static_cast<const half *>(v_cache),
                        block_tables, cache_lens, alibi_slopes,
                        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                        q_stride, k_batch_stride, k_row_stride, k_head_stride,
                        v_batch_stride, v_row_stride, v_head_stride, o_stride);
                }
            }
        } else {
            flashAttentionDecodeHd128Warp<Tindex, half><<<grid, block, 0, stream>>>(
                static_cast<half *>(out),
                static_cast<const half *>(q),
                static_cast<const half *>(k_cache),
                static_cast<const half *>(v_cache),
                block_tables, cache_lens, alibi_slopes,
                num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                q_stride, k_batch_stride, k_row_stride, k_head_stride,
                v_batch_stride, v_row_stride, v_head_stride, o_stride);
        }
        return INFINI_STATUS_SUCCESS;
    }
    if (dtype == INFINI_DTYPE_BF16) {
        if (use_cta) {
            if (cta_tile == 16) {
                if (cta_threads == 32) {
                    flashAttentionDecodeHd128Cta32Tile16<Tindex, __mt_bfloat16><<<grid, block, 0, stream>>>(
                        static_cast<__mt_bfloat16 *>(out),
                        static_cast<const __mt_bfloat16 *>(q),
                        static_cast<const __mt_bfloat16 *>(k_cache),
                        static_cast<const __mt_bfloat16 *>(v_cache),
                        block_tables, cache_lens, alibi_slopes,
                        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                        q_stride, k_batch_stride, k_row_stride, k_head_stride,
                        v_batch_stride, v_row_stride, v_head_stride, o_stride);
                } else {
                    flashAttentionDecodeHd128CtaTile16<Tindex, __mt_bfloat16><<<grid, block, 0, stream>>>(
                        static_cast<__mt_bfloat16 *>(out),
                        static_cast<const __mt_bfloat16 *>(q),
                        static_cast<const __mt_bfloat16 *>(k_cache),
                        static_cast<const __mt_bfloat16 *>(v_cache),
                        block_tables, cache_lens, alibi_slopes,
                        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                        q_stride, k_batch_stride, k_row_stride, k_head_stride,
                        v_batch_stride, v_row_stride, v_head_stride, o_stride);
                }
            } else {
                if (cta_threads == 32) {
                    flashAttentionDecodeHd128Cta32<Tindex, __mt_bfloat16><<<grid, block, 0, stream>>>(
                        static_cast<__mt_bfloat16 *>(out),
                        static_cast<const __mt_bfloat16 *>(q),
                        static_cast<const __mt_bfloat16 *>(k_cache),
                        static_cast<const __mt_bfloat16 *>(v_cache),
                        block_tables, cache_lens, alibi_slopes,
                        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                        q_stride, k_batch_stride, k_row_stride, k_head_stride,
                        v_batch_stride, v_row_stride, v_head_stride, o_stride);
                } else {
                    flashAttentionDecodeHd128Cta<Tindex, __mt_bfloat16><<<grid, block, 0, stream>>>(
                        static_cast<__mt_bfloat16 *>(out),
                        static_cast<const __mt_bfloat16 *>(q),
                        static_cast<const __mt_bfloat16 *>(k_cache),
                        static_cast<const __mt_bfloat16 *>(v_cache),
                        block_tables, cache_lens, alibi_slopes,
                        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                        q_stride, k_batch_stride, k_row_stride, k_head_stride,
                        v_batch_stride, v_row_stride, v_head_stride, o_stride);
                }
            }
        } else {
            flashAttentionDecodeHd128Warp<Tindex, __mt_bfloat16><<<grid, block, 0, stream>>>(
                static_cast<__mt_bfloat16 *>(out),
                static_cast<const __mt_bfloat16 *>(q),
                static_cast<const __mt_bfloat16 *>(k_cache),
                static_cast<const __mt_bfloat16 *>(v_cache),
                block_tables, cache_lens, alibi_slopes,
                num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                q_stride, k_batch_stride, k_row_stride, k_head_stride,
                v_batch_stride, v_row_stride, v_head_stride, o_stride);
        }
        return INFINI_STATUS_SUCCESS;
    }
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
// Public entry point for int64 block tables / cache lens: forwards all
// arguments unchanged to the templated dispatcher instantiated with int64_t.
infiniStatus_t launch_decode_hd128_i64(
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k_cache,
    const void *v_cache,
    infiniDtype_t dtype,
    const int64_t *block_tables,
    const int64_t *cache_lens,
    const float *alibi_slopes,
    size_t num_heads,
    size_t num_seqs,
    size_t num_kv_heads,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    ptrdiff_t o_stride,
    musaStream_t stream) {
    return launch_decode_hd128_impl<int64_t>(
        workspace, workspace_size,
        out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride,
        k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}
// head_dim=128 decode entry point for int32 block tables / cache lengths.
// Thin forwarding wrapper over launch_decode_hd128_impl<int32_t>.
infiniStatus_t launch_decode_hd128_i32(
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k_cache,
    const void *v_cache,
    infiniDtype_t dtype,
    const int32_t *block_tables,
    const int32_t *cache_lens,
    const float *alibi_slopes,
    size_t num_heads,
    size_t num_seqs,
    size_t num_kv_heads,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    ptrdiff_t o_stride,
    musaStream_t stream) {
    return launch_decode_hd128_impl<int32_t>(
        workspace, workspace_size,
        out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride,
        k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}
// head_dim=128 decode entry point for uint32 block tables / cache lengths.
// Thin forwarding wrapper over launch_decode_hd128_impl<uint32_t>.
infiniStatus_t launch_decode_hd128_u32(
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k_cache,
    const void *v_cache,
    infiniDtype_t dtype,
    const uint32_t *block_tables,
    const uint32_t *cache_lens,
    const float *alibi_slopes,
    size_t num_heads,
    size_t num_seqs,
    size_t num_kv_heads,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    ptrdiff_t o_stride,
    musaStream_t stream) {
    return launch_decode_hd128_impl<uint32_t>(
        workspace, workspace_size,
        out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride,
        k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}
} // namespace op::paged_attention::moore
#include <musa_runtime.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_kernel_common.h"
#include "../cuda/kernel_v2.cuh"
namespace op::paged_attention::moore {
namespace {
// Hard upper bound on split-kv partitions. The split workspace is always
// sized for this many splits (see launch_decode_hd64_impl), and runtime
// split counts are clamped to it, so it must stay in sync with the
// kMaxSplits constant used in Descriptor::create.
constexpr int kMaxSplits = 8;
// Integer ceiling division: smallest q such that q * b >= a.
// Written as quotient-plus-remainder instead of (a + b - 1) / b so the
// intermediate sum cannot overflow for values of `a` near SIZE_MAX
// (the heuristic feeds token counts here, which can be large).
// Precondition: b != 0.
constexpr size_t ceilDiv(size_t a, size_t b) {
    return a / b + (a % b != 0 ? 1 : 0);
}
// Number of multiprocessors on the current MUSA device, or 0 when the
// device/attribute query fails (callers treat 0 as "heuristic unavailable").
inline int getSmCount() {
    int dev = 0;
    int sms = 0;
    const bool ok = (musaGetDevice(&dev) == musaSuccess)
                 && (musaDeviceGetAttribute(&sms, musaDevAttrMultiProcessorCount, dev) == musaSuccess);
    return ok ? sms : 0;
}
// A lightweight FA2-style "waves" heuristic.
//
// Important: our split-kv kernel shards the KV sequence length, so the main "work"
// dimension is tokens, not the number of pages. We use an upper bound for seqlen_k
// (max pages * page size), which matches common decode microbench where all seqs
// share the same cache length.
// Pick a split-kv factor in [1, kMaxSplits] minimizing an approximate cost:
//   waves(split kernel) * tokens-per-block + waves(combine kernel)
// where a "wave" is ceil(blocks / sm_count). Returns 1 (no split) when the
// device is unknown, the problem is empty, or the KV length is too short for
// splitting to amortize the extra combine pass.
inline int chooseNumSplitsHeuristic(size_t num_heads, size_t num_seqs, size_t seqlen_k, int sm_count) {
    if (sm_count <= 0) {
        return 1;
    }
    if (num_heads == 0 || num_seqs == 0) {
        return 1;
    }
    if (seqlen_k <= 256) {
        return 1;
    }
    const size_t sms = static_cast<size_t>(sm_count);
    const size_t base_blocks = num_heads * num_seqs;
    // Loop-invariant (hoisted): waves of the non-split grid. This is both the
    // baseline occupancy factor and the approximate cost of the combine
    // kernel, which launches base_blocks CTAs regardless of the split count.
    const size_t base_waves = ceilDiv(base_blocks, sms);
    int best_splits = 1;
    // Baseline: one kernel, base_blocks CTAs, each scanning seqlen_k tokens.
    size_t best_score = base_waves * seqlen_k;
    size_t prev_work_per_block = seqlen_k;
    for (int s = 2; s <= kMaxSplits; ++s) {
        const size_t work_per_block = ceilDiv(seqlen_k, static_cast<size_t>(s));
        // If this split count doesn't reduce per-block work vs the previous
        // split, it's effectively redundant — skip before computing anything else.
        if (work_per_block == prev_work_per_block) {
            continue;
        }
        prev_work_per_block = work_per_block;
        const size_t waves_split = ceilDiv(base_blocks * static_cast<size_t>(s), sms);
        // Combine is one extra kernel with base_blocks blocks; approximate as
        // one more wave unit (base_waves).
        const size_t score = waves_split * work_per_block + base_waves;
        if (score < best_score) {
            best_score = score;
            best_splits = s;
        }
    }
    return best_splits;
}
} // namespace
// head_dim=64 decode kernel: one warp per (head, sequence) pair.
// Launched by launch_decode_hd64_impl with grid = (num_heads, num_seqs) and
// block = 32 threads; forwards verbatim to the shared CUDA-layer device
// implementation (head-dim template argument 64).
template <typename Tindex, typename Tdata>
INFINIOP_MOORE_KERNEL flashAttentionDecodeHd64Warp(
    Tdata *out,
    const Tdata *q,
    const Tdata *k_cache,
    const Tdata *v_cache,
    const Tindex *block_tables,
    const Tindex *cache_lens,
    const float *alibi_slopes,
    size_t num_kv_heads,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    ptrdiff_t o_stride) {
    op::paged_attention::cuda::flashAttentionDecodeWarpKernel<Tindex, Tdata, 64>(
        out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
        k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
        v_head_stride, o_stride);
}
// head_dim=64 decode kernel, CTA-cooperative variant with template arguments
// <64, 32, 8> — presumably head_dim, CTA threads, KV tile — TODO confirm
// against flashAttentionDecodeCtaKernel. Selected when
// INFINIOP_FLASH_DECODE_KERNEL=cta and the tile env is not 16.
template <typename Tindex, typename Tdata>
INFINIOP_MOORE_KERNEL flashAttentionDecodeHd64Cta(
    Tdata *out,
    const Tdata *q,
    const Tdata *k_cache,
    const Tdata *v_cache,
    const Tindex *block_tables,
    const Tindex *cache_lens,
    const float *alibi_slopes,
    size_t num_kv_heads,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    ptrdiff_t o_stride) {
    // Default CTA variant (lower overhead).
    op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 64, 32, 8>(
        out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
        k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
        v_head_stride, o_stride);
}
// head_dim=64 CTA decode kernel with the larger tile template argument (16
// instead of 8). Selected at launch time via INFINIOP_FLASH_CTA_TILE=16;
// otherwise identical in shape to flashAttentionDecodeHd64Cta.
template <typename Tindex, typename Tdata>
INFINIOP_MOORE_KERNEL flashAttentionDecodeHd64CtaTile16(
    Tdata *out,
    const Tdata *q,
    const Tdata *k_cache,
    const Tdata *v_cache,
    const Tindex *block_tables,
    const Tindex *cache_lens,
    const float *alibi_slopes,
    size_t num_kv_heads,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    ptrdiff_t o_stride) {
    op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 64, 32, 16>(
        out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
        k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
        v_head_stride, o_stride);
}
// Split-KV pass for head_dim=64: launched with
// grid = (num_heads, num_seqs, num_splits), 32 threads per CTA (see
// launch_decode_hd64_impl). Each grid-z slice processes one shard of the KV
// length and writes float partials (partial_acc / running-max partial_m /
// running-sum partial_l) instead of the final output; no o_stride is needed
// here. Results are merged by flashAttentionDecodeHd64SplitKvCombine.
template <typename Tindex, typename Tdata>
INFINIOP_MOORE_KERNEL flashAttentionDecodeHd64SplitKv(
    float *partial_acc,
    float *partial_m,
    float *partial_l,
    const Tdata *q,
    const Tdata *k_cache,
    const Tdata *v_cache,
    const Tindex *block_tables,
    const Tindex *cache_lens,
    const float *alibi_slopes,
    size_t num_kv_heads,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    int num_splits) {
    op::paged_attention::cuda::flashAttentionDecodeSplitKvWarpKernel<Tindex, Tdata, 64>(
        partial_acc, partial_m, partial_l,
        q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
        k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
        v_head_stride, num_splits);
}
// Combine pass for split-kv decode: reduces the num_splits partial
// (acc, m, l) triples produced by flashAttentionDecodeHd64SplitKv into the
// final Tdata output. Launched with grid = (num_heads, num_seqs) and
// 32 threads (see launch_decode_hd64_impl).
template <typename Tdata>
INFINIOP_MOORE_KERNEL flashAttentionDecodeHd64SplitKvCombine(
    Tdata *out,
    const float *partial_acc,
    const float *partial_m,
    const float *partial_l,
    int num_splits,
    ptrdiff_t o_stride) {
    op::paged_attention::cuda::flashAttentionDecodeSplitKvCombineWarpKernel<Tdata, 64>(
        out, partial_acc, partial_m, partial_l, num_splits, o_stride);
}
// Host-side dispatcher for head_dim=64 paged-attention decode (F16/BF16).
//
// Kernel selection is driven by environment variables:
//   INFINIOP_FLASH_DECODE_KERNEL=cta     -> CTA kernel family (else warp kernel)
//   INFINIOP_FLASH_CTA_TILE=8|16         -> which CTA tile variant
//   INFINIOP_FLASH_DECODE_SPLITKV=1|true -> split-kv path (warp kernels only)
//   INFINIOP_FLASH_NUM_SPLITS=<n>|auto   -> fixed split count, or heuristic
// The split-kv path writes float partials (acc + running max + running sum)
// into `workspace` and then runs a combine kernel; the non-split paths never
// touch the workspace.
//
// Returns INFINI_STATUS_INSUFFICIENT_WORKSPACE when split-kv is requested
// without enough scratch, INFINI_STATUS_BAD_TENSOR_DTYPE for dtypes other
// than F16/BF16, INFINI_STATUS_SUCCESS otherwise. Kernel launches are
// asynchronous on `stream`; launch errors are not checked here.
template <typename Tindex>
infiniStatus_t launch_decode_hd64_impl(
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k_cache,
    const void *v_cache,
    infiniDtype_t dtype,
    const Tindex *block_tables,
    const Tindex *cache_lens,
    const float *alibi_slopes,
    size_t num_heads,
    size_t num_seqs,
    size_t num_kv_heads,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    ptrdiff_t o_stride,
    musaStream_t stream) {
    // One CTA per (head, seq) pair for the non-split kernels and the combine pass.
    dim3 grid(static_cast<uint64_t>(num_heads), static_cast<uint64_t>(num_seqs), 1);
    bool use_cta = false;
    if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_KERNEL")) {
        use_cta = (std::strcmp(env, "cta") == 0);
    }
    // CTA KV-tile size; only 8 and 16 are implemented.
    int cta_tile = 8;
    if (const char *env = std::getenv("INFINIOP_FLASH_CTA_TILE")) {
        const int v = std::atoi(env);
        if (v == 8 || v == 16) {
            cta_tile = v;
        }
    }
    // For head_dim=64 we use a 1-warp CTA (32 threads) with packed loads.
    dim3 block(32);
    // NOTE: unlike Descriptor::calculate's workspace gating, "auto" does NOT
    // enable split-kv here; only explicit "1"/"true" do.
    bool use_split = false;
    if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_SPLITKV")) {
        use_split = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
    }
    int num_splits = 4;
    bool fixed_num_splits = false;
    if (const char *env = std::getenv("INFINIOP_FLASH_NUM_SPLITS")) {
        if (std::strcmp(env, "auto") == 0) {
            fixed_num_splits = false;
        } else {
            num_splits = std::atoi(env);
            fixed_num_splits = (num_splits > 0);
        }
    }
    // Clamp to [1, kMaxSplits]; the workspace is only sized for kMaxSplits.
    if (num_splits < 1) {
        num_splits = 1;
    }
    if (num_splits > kMaxSplits) {
        num_splits = kMaxSplits;
    }
    if (use_split) {
        if (use_cta) {
            // We currently only implement the split-kv path with warp kernels.
            // The CTA kernel is a separate non-split implementation.
            static bool warned = false;
            if (!warned) {
                warned = true;
                fprintf(stderr,
                        "[INFINIOP][paged_attention] split-kv is enabled; ignoring INFINIOP_FLASH_DECODE_KERNEL=cta "
                        "(CTA split-kv not implemented yet)\n");
            }
        }
        if (!fixed_num_splits) {
            // Approximate seqlen_k by the per-seq KV capacity (paged KV upper bound).
            const size_t seqlen_k = max_num_blocks_per_seq * page_block_size;
            const int sm_count = getSmCount();
            num_splits = chooseNumSplitsHeuristic(num_heads, num_seqs, seqlen_k, sm_count);
            if (const char *dbg = std::getenv("INFINIOP_FLASH_DEBUG_SPLITS")) {
                if (std::strcmp(dbg, "1") == 0 || std::strcmp(dbg, "true") == 0) {
                    // Only log when the approximated KV length changes, to avoid
                    // flooding stderr on every decode step.
                    static size_t last_seqlen_k = 0;
                    if (last_seqlen_k != seqlen_k) {
                        last_seqlen_k = seqlen_k;
                        fprintf(stderr,
                                "[INFINIOP][paged_attention] splitkv auto: sm=%d heads=%zu seqs=%zu seqlen_k~%zu -> num_splits=%d\n",
                                sm_count, num_heads, num_seqs, seqlen_k, num_splits);
                    }
                }
            }
        }
        // Workspace layout (floats): acc[kMaxSplits][n][64] | m[kMaxSplits][n] | l[kMaxSplits][n],
        // where n = num_seqs * num_heads. Sized for kMaxSplits (not num_splits)
        // to match the reservation made in Descriptor::create.
        const size_t n = num_seqs * num_heads;
        const size_t acc_elems = static_cast<size_t>(kMaxSplits) * n * 64;
        const size_t m_elems = static_cast<size_t>(kMaxSplits) * n;
        const size_t l_elems = static_cast<size_t>(kMaxSplits) * n;
        const size_t needed_bytes = (acc_elems + m_elems + l_elems) * sizeof(float);
        if (workspace == nullptr || workspace_size < needed_bytes) {
            return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
        }
        float *ws = static_cast<float *>(workspace);
        float *partial_acc = ws;
        float *partial_m = partial_acc + acc_elems;
        float *partial_l = partial_m + m_elems;
        // Split pass shards the KV length across grid.z; combine pass reduces it.
        dim3 grid_split(static_cast<uint64_t>(num_heads), static_cast<uint64_t>(num_seqs), static_cast<uint64_t>(num_splits));
        dim3 block_split(32);
        if (dtype == INFINI_DTYPE_F16) {
            flashAttentionDecodeHd64SplitKv<Tindex, half><<<grid_split, block_split, 0, stream>>>(
                partial_acc, partial_m, partial_l,
                static_cast<const half *>(q),
                static_cast<const half *>(k_cache),
                static_cast<const half *>(v_cache),
                block_tables, cache_lens, alibi_slopes,
                num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                q_stride, k_batch_stride, k_row_stride, k_head_stride,
                v_batch_stride, v_row_stride, v_head_stride, num_splits);
            flashAttentionDecodeHd64SplitKvCombine<half><<<grid, 32, 0, stream>>>(
                static_cast<half *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride);
            return INFINI_STATUS_SUCCESS;
        }
        if (dtype == INFINI_DTYPE_BF16) {
            flashAttentionDecodeHd64SplitKv<Tindex, __mt_bfloat16><<<grid_split, block_split, 0, stream>>>(
                partial_acc, partial_m, partial_l,
                static_cast<const __mt_bfloat16 *>(q),
                static_cast<const __mt_bfloat16 *>(k_cache),
                static_cast<const __mt_bfloat16 *>(v_cache),
                block_tables, cache_lens, alibi_slopes,
                num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                q_stride, k_batch_stride, k_row_stride, k_head_stride,
                v_batch_stride, v_row_stride, v_head_stride, num_splits);
            flashAttentionDecodeHd64SplitKvCombine<__mt_bfloat16><<<grid, 32, 0, stream>>>(
                static_cast<__mt_bfloat16 *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride);
            return INFINI_STATUS_SUCCESS;
        }
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    // Non-split paths: dtype x (cta-tile-16 | cta | warp).
    if (dtype == INFINI_DTYPE_F16) {
        if (use_cta) {
            if (cta_tile == 16) {
                flashAttentionDecodeHd64CtaTile16<Tindex, half><<<grid, block, 0, stream>>>(
                    static_cast<half *>(out),
                    static_cast<const half *>(q),
                    static_cast<const half *>(k_cache),
                    static_cast<const half *>(v_cache),
                    block_tables, cache_lens, alibi_slopes,
                    num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                    q_stride, k_batch_stride, k_row_stride, k_head_stride,
                    v_batch_stride, v_row_stride, v_head_stride, o_stride);
            } else {
                flashAttentionDecodeHd64Cta<Tindex, half><<<grid, block, 0, stream>>>(
                    static_cast<half *>(out),
                    static_cast<const half *>(q),
                    static_cast<const half *>(k_cache),
                    static_cast<const half *>(v_cache),
                    block_tables, cache_lens, alibi_slopes,
                    num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                    q_stride, k_batch_stride, k_row_stride, k_head_stride,
                    v_batch_stride, v_row_stride, v_head_stride, o_stride);
            }
        } else {
            flashAttentionDecodeHd64Warp<Tindex, half><<<grid, block, 0, stream>>>(
                static_cast<half *>(out),
                static_cast<const half *>(q),
                static_cast<const half *>(k_cache),
                static_cast<const half *>(v_cache),
                block_tables, cache_lens, alibi_slopes,
                num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                q_stride, k_batch_stride, k_row_stride, k_head_stride,
                v_batch_stride, v_row_stride, v_head_stride, o_stride);
        }
        return INFINI_STATUS_SUCCESS;
    }
    if (dtype == INFINI_DTYPE_BF16) {
        if (use_cta) {
            if (cta_tile == 16) {
                flashAttentionDecodeHd64CtaTile16<Tindex, __mt_bfloat16><<<grid, block, 0, stream>>>(
                    static_cast<__mt_bfloat16 *>(out),
                    static_cast<const __mt_bfloat16 *>(q),
                    static_cast<const __mt_bfloat16 *>(k_cache),
                    static_cast<const __mt_bfloat16 *>(v_cache),
                    block_tables, cache_lens, alibi_slopes,
                    num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                    q_stride, k_batch_stride, k_row_stride, k_head_stride,
                    v_batch_stride, v_row_stride, v_head_stride, o_stride);
            } else {
                flashAttentionDecodeHd64Cta<Tindex, __mt_bfloat16><<<grid, block, 0, stream>>>(
                    static_cast<__mt_bfloat16 *>(out),
                    static_cast<const __mt_bfloat16 *>(q),
                    static_cast<const __mt_bfloat16 *>(k_cache),
                    static_cast<const __mt_bfloat16 *>(v_cache),
                    block_tables, cache_lens, alibi_slopes,
                    num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                    q_stride, k_batch_stride, k_row_stride, k_head_stride,
                    v_batch_stride, v_row_stride, v_head_stride, o_stride);
            }
        } else {
            flashAttentionDecodeHd64Warp<Tindex, __mt_bfloat16><<<grid, block, 0, stream>>>(
                static_cast<__mt_bfloat16 *>(out),
                static_cast<const __mt_bfloat16 *>(q),
                static_cast<const __mt_bfloat16 *>(k_cache),
                static_cast<const __mt_bfloat16 *>(v_cache),
                block_tables, cache_lens, alibi_slopes,
                num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                q_stride, k_batch_stride, k_row_stride, k_head_stride,
                v_batch_stride, v_row_stride, v_head_stride, o_stride);
        }
        return INFINI_STATUS_SUCCESS;
    }
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
// head_dim=64 decode entry for int64 block tables / cache lengths.
// Thin forwarder; env handling and kernel selection live in
// launch_decode_hd64_impl<int64_t>.
infiniStatus_t launch_decode_hd64_i64(
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k_cache,
    const void *v_cache,
    infiniDtype_t dtype,
    const int64_t *block_tables,
    const int64_t *cache_lens,
    const float *alibi_slopes,
    size_t num_heads,
    size_t num_seqs,
    size_t num_kv_heads,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    ptrdiff_t o_stride,
    musaStream_t stream) {
    return launch_decode_hd64_impl<int64_t>(
        workspace, workspace_size,
        out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride,
        k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}
// head_dim=64 decode entry for int32 block tables / cache lengths.
// Thin forwarder over launch_decode_hd64_impl<int32_t>.
infiniStatus_t launch_decode_hd64_i32(
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k_cache,
    const void *v_cache,
    infiniDtype_t dtype,
    const int32_t *block_tables,
    const int32_t *cache_lens,
    const float *alibi_slopes,
    size_t num_heads,
    size_t num_seqs,
    size_t num_kv_heads,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    ptrdiff_t o_stride,
    musaStream_t stream) {
    return launch_decode_hd64_impl<int32_t>(
        workspace, workspace_size,
        out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride,
        k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}
// head_dim=64 decode entry for uint32 block tables / cache lengths.
// Thin forwarder over launch_decode_hd64_impl<uint32_t>.
infiniStatus_t launch_decode_hd64_u32(
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k_cache,
    const void *v_cache,
    infiniDtype_t dtype,
    const uint32_t *block_tables,
    const uint32_t *cache_lens,
    const float *alibi_slopes,
    size_t num_heads,
    size_t num_seqs,
    size_t num_kv_heads,
    float scale,
    size_t max_num_blocks_per_seq,
    size_t page_block_size,
    ptrdiff_t q_stride,
    ptrdiff_t k_batch_stride,
    ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride,
    ptrdiff_t v_row_stride,
    ptrdiff_t v_head_stride,
    ptrdiff_t o_stride,
    musaStream_t stream) {
    return launch_decode_hd64_impl<uint32_t>(
        workspace, workspace_size,
        out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride,
        k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}
} // namespace op::paged_attention::moore
// Public header for the Moore (MUSA) paged-attention backend.
// DESCRIPTOR(moore) presumably expands to the op::paged_attention::moore
// Descriptor class declaration — see the macro in ../paged_attention.h.
#ifndef __PAGED_ATTENTION_MOORE_H__
#define __PAGED_ATTENTION_MOORE_H__
#include "../paged_attention.h"
DESCRIPTOR(moore)
#endif // __PAGED_ATTENTION_MOORE_H__
#include <musa_runtime.h>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include "../../../devices/moore/moore_common.h"
#include "paged_attention_moore.h"
namespace op::paged_attention::moore {
// Forward declarations of the launch entry points implemented in the
// per-head-dim translation units. Naming: launch_decode_hd{64,128}_{i64,i32,u32}
// = head_dim x index dtype of block_tables/cache_lens. All six share one
// signature; `workspace` is only consumed on the split-kv decode path.
infiniStatus_t launch_decode_hd64_i64(
    void *workspace, size_t workspace_size,
    void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const int64_t *block_tables, const int64_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    musaStream_t stream);
infiniStatus_t launch_decode_hd64_i32(
    void *workspace, size_t workspace_size,
    void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const int32_t *block_tables, const int32_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    musaStream_t stream);
infiniStatus_t launch_decode_hd64_u32(
    void *workspace, size_t workspace_size,
    void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const uint32_t *block_tables, const uint32_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    musaStream_t stream);
infiniStatus_t launch_decode_hd128_i64(
    void *workspace, size_t workspace_size,
    void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const int64_t *block_tables, const int64_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    musaStream_t stream);
infiniStatus_t launch_decode_hd128_i32(
    void *workspace, size_t workspace_size,
    void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const int32_t *block_tables, const int32_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    musaStream_t stream);
infiniStatus_t launch_decode_hd128_u32(
    void *workspace, size_t workspace_size,
    void *out, const void *q, const void *k_cache, const void *v_cache,
    infiniDtype_t dtype, const uint32_t *block_tables, const uint32_t *cache_lens, const float *alibi_slopes,
    size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
    ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
    ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
    musaStream_t stream);
// Backend-private state attached to the public Descriptor: holds a shared
// reference to the Moore device handle internals so they outlive the handle.
struct Descriptor::Opaque {
    std::shared_ptr<device::moore::Handle::Internal> internal;
};
// Releases the backend-private Opaque state owned by this descriptor.
Descriptor::~Descriptor() {
    delete _opaque;
}
// Validates the tensor descriptors via PagedAttentionInfo::create and builds
// a Moore paged-attention descriptor, pre-computing the worst-case workspace
// size for the split-kv decode path.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t block_tables_desc,
    infiniopTensorDescriptor_t cache_lens_desc,
    const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
    float scale) {
    auto info_res = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, cache_lens_desc, alibi_slopes_desc, scale);
    CHECK_RESULT(info_res);
    auto info = info_res.take();
    // Reserve workspace for optional split-kv decode (partial acc + m/l).
    // Workspace is independent of runtime env toggles; kernels will clamp num_splits <= kMaxSplits.
    // NOTE: must stay in sync with kMaxSplits in the launcher translation units.
    constexpr size_t kMaxSplits = 8;
    // Per split, each (seq, head) pair needs head_size floats of accumulator
    // plus one running-max (m) and one running-sum (l) float — hence "+ 2".
    const size_t per_split = info.num_seqs * info.num_heads * (info.head_size + 2) * sizeof(float);
    const size_t workspace_bytes = kMaxSplits * per_split;
    *desc_ptr = new Descriptor(
        new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
        info, workspace_bytes, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Runs paged-attention decode: dispatches on (index_dtype, head_size) to the
// matching launch_decode_* entry point. Workspace sufficiency is checked here
// conservatively (whenever the split-kv path MIGHT run); the actual split
// decision happens inside the launchers.
// Returns INFINI_STATUS_BAD_TENSOR_SHAPE for unsupported head sizes and
// INFINI_STATUS_BAD_TENSOR_DTYPE for unsupported index dtypes.
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *out, const void *q, const void *k_cache, const void *v_cache,
    const void *block_tables, const void *cache_lens, const void *alibi_slopes,
    void *stream_) const {
    bool need_workspace = false;
    if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_SPLITKV")) {
        // "auto" may enable split-kv depending on the runtime heuristic.
        need_workspace = (std::strcmp(env, "auto") == 0) || (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
    } else {
        // Keep hd64 behavior unchanged, but for hd128 we default to split-kv decode, which needs workspace.
        need_workspace = (_info.head_size == 128);
    }
    if (need_workspace && workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    auto stream = static_cast<musaStream_t>(stream_);
    // alibi_slopes is optional; launchers accept nullptr to mean "no ALiBi".
    const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast<const float *>(alibi_slopes);
    if (_info.index_dtype == INFINI_DTYPE_I64) {
        const auto *block_table_i64 = static_cast<const int64_t *>(block_tables);
        const auto *cache_lens_i64 = static_cast<const int64_t *>(cache_lens);
        switch (_info.head_size) {
        case 64:
            return launch_decode_hd64_i64(
                workspace, workspace_size,
                out, q, k_cache, v_cache, _info.dtype,
                block_table_i64, cache_lens_i64, alibi_ptr,
                _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
                _info.max_num_blocks_per_seq, _info.page_block_size,
                _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
                _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
                _info.o_stride, stream);
        case 128:
            return launch_decode_hd128_i64(
                workspace, workspace_size,
                out, q, k_cache, v_cache, _info.dtype,
                block_table_i64, cache_lens_i64, alibi_ptr,
                _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
                _info.max_num_blocks_per_seq, _info.page_block_size,
                _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
                _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
                _info.o_stride, stream);
        default:
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
    }
    if (_info.index_dtype == INFINI_DTYPE_I32) {
        const auto *block_table_i32 = static_cast<const int32_t *>(block_tables);
        const auto *cache_lens_i32 = static_cast<const int32_t *>(cache_lens);
        switch (_info.head_size) {
        case 64:
            return launch_decode_hd64_i32(
                workspace, workspace_size,
                out, q, k_cache, v_cache, _info.dtype,
                block_table_i32, cache_lens_i32, alibi_ptr,
                _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
                _info.max_num_blocks_per_seq, _info.page_block_size,
                _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
                _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
                _info.o_stride, stream);
        case 128:
            return launch_decode_hd128_i32(
                workspace, workspace_size,
                out, q, k_cache, v_cache, _info.dtype,
                block_table_i32, cache_lens_i32, alibi_ptr,
                _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
                _info.max_num_blocks_per_seq, _info.page_block_size,
                _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
                _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
                _info.o_stride, stream);
        default:
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
    }
    if (_info.index_dtype == INFINI_DTYPE_U32) {
        const auto *block_table_u32 = static_cast<const uint32_t *>(block_tables);
        const auto *cache_lens_u32 = static_cast<const uint32_t *>(cache_lens);
        switch (_info.head_size) {
        case 64:
            return launch_decode_hd64_u32(
                workspace, workspace_size,
                out, q, k_cache, v_cache, _info.dtype,
                block_table_u32, cache_lens_u32, alibi_ptr,
                _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
                _info.max_num_blocks_per_seq, _info.page_block_size,
                _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
                _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
                _info.o_stride, stream);
        case 128:
            return launch_decode_hd128_u32(
                workspace, workspace_size,
                out, q, k_cache, v_cache, _info.dtype,
                block_table_u32, cache_lens_u32, alibi_ptr,
                _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
                _info.max_num_blocks_per_seq, _info.page_block_size,
                _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
                _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
                _info.o_stride, stream);
        default:
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
    }
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::paged_attention::moore
......@@ -5,6 +5,9 @@
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
#include "nvidia/paged_attention_nvidia.cuh"
#endif
#ifdef ENABLE_MOORE_API
#include "moore/paged_attention_moore.h"
#endif
#ifdef ENABLE_METAX_API
#include "metax/paged_attention_metax.h"
#endif
......@@ -40,6 +43,9 @@ __C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore)
#endif
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
......@@ -67,6 +73,9 @@ __C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore)
#endif
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
......@@ -98,6 +107,9 @@ __C infiniStatus_t infiniopPagedAttention(
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore)
#endif
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
......@@ -124,6 +136,9 @@ __C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
#ifdef ENABLE_ALI_API
DESTROY(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, moore)
#endif
#ifdef ENABLE_ILUVATAR_API
DESTROY(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment