Commit 7a18d241 authored by wooway777

issue/983 - adapted the optimized paged attention to metax

parent 4cd1f688
@@ -85,6 +85,9 @@
#define hcclSuccess mcclSuccess
#define hcclCommDestroy mcclCommDestroy
#define hcclAllReduce mcclAllReduce
#define hcGetDevice mcGetDevice
#define hcDeviceAttributeMultiProcessorCount mcDeviceAttributeMultiProcessorCount
#define hcDeviceGetAttribute mcDeviceGetAttribute
#define hcStreamCaptureMode mcStreamCaptureMode
#define hcStreamCaptureModeGlobal mcStreamCaptureModeGlobal
#define hcStreamCaptureModeThreadLocal mcStreamCaptureModeThreadLocal
@@ -19,6 +19,12 @@ using cuda_bfloat16 = hpcc_bfloat16;
using cuda_bfloat162 = hpcc_bfloat162;
using cuda_fp8_e4m3 = __hpcc_fp8_e4m3;
#ifdef ENABLE_METAX_MC_API
using __nv_bfloat16 = __maca_bfloat16;
#else
using __nv_bfloat16 = __hpcc_bfloat16;
#endif
namespace device::metax {
// get the memory offset of the given element in a tensor given its flat index
#ifdef ENABLE_METAX_MC_API
#include <mc_runtime.h>
#else
#include <hc_runtime.h>
#endif
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../cuda/kernel_v2.cuh"
namespace op::paged_attention::metax {
namespace {
constexpr int kMaxSplits = 8;
constexpr size_t ceilDiv(size_t a, size_t b) {
return (a + b - 1) / b;
}
inline int getSmCount() {
int device = 0;
if (hcGetDevice(&device) != hcSuccess) {
return 0;
}
int sm_count = 0;
if (hcDeviceGetAttribute(&sm_count, hcDeviceAttributeMultiProcessorCount, device) != hcSuccess) {
return 0;
}
return sm_count;
}
// A lightweight FA2-style "waves" heuristic.
//
// Important: our split-kv kernel shards the KV sequence length, so the main "work"
// dimension is tokens, not the number of pages. We use an upper bound for seqlen_k
// (max pages * page size), which matches common decode microbenchmarks where all
// sequences share the same cache length.
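// Worked example (hypothetical numbers): sm_count=64, num_heads=32, num_seqs=1,
// seqlen_k=4096. Baseline score: ceilDiv(32, 64) = 1 wave * 4096 tokens = 4096.
// s=2: 1 wave * 2048 tokens + 1 combine wave = 2049 -> new best.
// s=4: 2 waves * 1024 tokens + 1 = 2049, no improvement, so the loop returns 2.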
inline int chooseNumSplitsHeuristic(size_t num_heads, size_t num_seqs, size_t seqlen_k, int sm_count) {
if (sm_count <= 0) {
return 1;
}
if (num_heads == 0 || num_seqs == 0) {
return 1;
}
if (seqlen_k <= 256) {
return 1;
}
const size_t base_blocks = num_heads * num_seqs;
int best_splits = 1;
// Baseline: one kernel, base_blocks CTAs, each scanning seqlen_k tokens.
size_t best_score = (ceilDiv(base_blocks, static_cast<size_t>(sm_count)) * seqlen_k);
size_t prev_work_per_block = seqlen_k;
for (int s = 2; s <= kMaxSplits; ++s) {
const size_t blocks = base_blocks * static_cast<size_t>(s);
const size_t waves_split = ceilDiv(blocks, static_cast<size_t>(sm_count));
const size_t work_per_block = ceilDiv(seqlen_k, static_cast<size_t>(s));
// If this split count doesn't reduce per-block work vs the previous split, it's effectively redundant.
if (work_per_block == prev_work_per_block) {
continue;
}
prev_work_per_block = work_per_block;
// Combine is one extra kernel with base_blocks blocks; approximate as one more wave unit.
const size_t waves_combine = ceilDiv(base_blocks, static_cast<size_t>(sm_count));
const size_t score = waves_split * work_per_block + waves_combine;
if (score < best_score) {
best_score = score;
best_splits = s;
}
}
return best_splits;
}
} // namespace
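// Env helper: treats "1" and "true" as enabled; anything else (or unset) as disabled.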
inline bool envBool(const char *name) {
if (const char *env = std::getenv(name)) {
return (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
}
return false;
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL flashAttentionDecodeHd128Warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride) {
op::paged_attention::cuda::flashAttentionDecodeWarpKernel<Tindex, Tdata, 128>(
out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL flashAttentionDecodeHd128Cta(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride) {
// Default CTA variant (lower overhead).
op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 128, 64, 8>(
out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL flashAttentionDecodeHd128CtaTile16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride) {
op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 128, 64, 16>(
out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL flashAttentionDecodeHd128Cta32(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride) {
// Experimental 1-warp CTA variant for head_dim=128 (kPack=4).
op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 128, 32, 8>(
out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL flashAttentionDecodeHd128Cta32Tile16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride) {
op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 128, 32, 16>(
out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL flashAttentionDecodeHd128CtaGqa4(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride) {
// GQA fused kernel: CTA computes 4 query heads for one KV head (head_dim=128).
op::paged_attention::cuda::flashAttentionDecodeCtaGqaKernel<Tindex, Tdata, 128, 64, 8, 4>(
out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL flashAttentionDecodeHd128SplitKv(
float *partial_acc,
float *partial_m,
float *partial_l,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
int num_splits) {
op::paged_attention::cuda::flashAttentionDecodeSplitKvWarpKernel<Tindex, Tdata, 128>(
partial_acc, partial_m, partial_l,
q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, num_splits);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL flashAttentionDecodeHd128SplitKvCta(
float *partial_acc,
float *partial_m,
float *partial_l,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
int num_splits) {
op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel<Tindex, Tdata, 128, 64, 8>(
partial_acc, partial_m, partial_l,
q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, num_splits);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL flashAttentionDecodeHd128SplitKvCtaTile16(
float *partial_acc,
float *partial_m,
float *partial_l,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
int num_splits) {
op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel<Tindex, Tdata, 128, 64, 16>(
partial_acc, partial_m, partial_l,
q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, num_splits);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL flashAttentionDecodeHd128SplitKvCta32(
float *partial_acc,
float *partial_m,
float *partial_l,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
int num_splits) {
op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel<Tindex, Tdata, 128, 32, 8>(
partial_acc, partial_m, partial_l,
q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, num_splits);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL flashAttentionDecodeHd128SplitKvCta32Tile16(
float *partial_acc,
float *partial_m,
float *partial_l,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
int num_splits) {
op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel<Tindex, Tdata, 128, 32, 16>(
partial_acc, partial_m, partial_l,
q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, num_splits);
}
template <typename Tdata>
INFINIOP_METAX_KERNEL flashAttentionDecodeHd128SplitKvCombine(
Tdata *out,
const float *partial_acc,
const float *partial_m,
const float *partial_l,
int num_splits,
ptrdiff_t o_stride) {
op::paged_attention::cuda::flashAttentionDecodeSplitKvCombineWarpKernel<Tdata, 128>(
out, partial_acc, partial_m, partial_l, num_splits, o_stride);
}
template <typename Tindex>
infiniStatus_t launch_decode_hd128_impl(
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
infiniDtype_t dtype,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
hcStream_t stream) {
// Default decode config (2026-01-22):
// decode_flash_cta8_64_gqa_splitkv_4
// Users can override any knob via the corresponding INFINIOP_FLASH_* env vars.
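// Knobs read below (all optional):
//   INFINIOP_FLASH_DECODE_KERNEL  - "cta" (default here); any other value means "warp"
//   INFINIOP_FLASH_GQA_FUSED      - "1"/"true" keeps fusion on (default); anything else turns it off
//   INFINIOP_FLASH_CTA_TILE       - 8 (default) or 16
//   INFINIOP_FLASH_CTA_THREADS    - 64 (default) or 32
//   INFINIOP_FLASH_DECODE_SPLITKV - "auto", "1"/"true" (default on), or anything else for off
//   INFINIOP_FLASH_NUM_SPLITS     - "auto" or a fixed count (default 4, clamped to [1, kMaxSplits])
//   INFINIOP_FLASH_DEBUG_DISPATCH / INFINIOP_FLASH_DEBUG_SPLITS - "1"/"true" enables logging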
bool use_cta = true;
if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_KERNEL")) {
// Backward-compatible: any non-"cta" value means "warp".
use_cta = (std::strcmp(env, "cta") == 0);
}
bool use_gqa_fused = true;
if (const char *env = std::getenv("INFINIOP_FLASH_GQA_FUSED")) {
if (std::strcmp(env, "0") == 0 || std::strcmp(env, "false") == 0) {
use_gqa_fused = false;
} else {
use_gqa_fused = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
}
}
int cta_tile = 8;
if (const char *env = std::getenv("INFINIOP_FLASH_CTA_TILE")) {
const int v = std::atoi(env);
if (v == 8 || v == 16) {
cta_tile = v;
}
}
int cta_threads = 64;
if (const char *env = std::getenv("INFINIOP_FLASH_CTA_THREADS")) {
const int v = std::atoi(env);
if (v == 32 || v == 64) {
cta_threads = v;
}
}
dim3 block(use_cta ? static_cast<uint32_t>(cta_threads) : 32);
bool use_split = true;
bool use_split_auto = false;
if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_SPLITKV")) {
if (std::strcmp(env, "auto") == 0) {
use_split_auto = true;
use_split = false;
} else {
if (std::strcmp(env, "0") == 0 || std::strcmp(env, "false") == 0) {
use_split = false;
} else {
use_split = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
}
}
}
int num_splits = 4;
bool fixed_num_splits = true;
if (const char *env = std::getenv("INFINIOP_FLASH_NUM_SPLITS")) {
if (std::strcmp(env, "auto") == 0) {
fixed_num_splits = false;
} else {
num_splits = std::atoi(env);
fixed_num_splits = (num_splits > 0);
}
}
if (num_splits < 1) {
num_splits = 1;
}
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
const bool debug_dispatch = envBool("INFINIOP_FLASH_DEBUG_DISPATCH");
auto dump_dispatch = [&](const char *path) {
if (!debug_dispatch) {
return;
}
// Avoid spamming: only print when the key dispatch signature changes.
struct Sig {
const char *path;
int dtype;
size_t heads;
size_t kv_heads;
size_t seqs;
size_t pbs;
size_t max_blocks;
int cta_tile;
int cta_threads;
int split;
int split_auto;
int num_splits;
int fixed;
int gqa_fused;
};
static Sig last{};
static bool has_last = false;
Sig cur{
path,
static_cast<int>(dtype),
num_heads,
num_kv_heads,
num_seqs,
page_block_size,
max_num_blocks_per_seq,
cta_tile,
cta_threads,
static_cast<int>(use_split),
static_cast<int>(use_split_auto),
num_splits,
static_cast<int>(fixed_num_splits),
static_cast<int>(use_gqa_fused),
};
if (has_last && cur.path == last.path && cur.dtype == last.dtype
    && cur.heads == last.heads && cur.kv_heads == last.kv_heads && cur.seqs == last.seqs
    && cur.pbs == last.pbs && cur.max_blocks == last.max_blocks
    && cur.cta_tile == last.cta_tile && cur.cta_threads == last.cta_threads
    && cur.split == last.split && cur.split_auto == last.split_auto
    && cur.num_splits == last.num_splits && cur.fixed == last.fixed
    && cur.gqa_fused == last.gqa_fused) {
return;
}
last = cur;
has_last = true;
fprintf(stderr,
"[INFINIOP][paged_attention][hd128] dispatch: path=%s dtype=%d heads=%zu kv_heads=%zu seqs=%zu "
"pbs=%zu max_blocks=%zu cta_tile=%d cta_threads=%d split=%d split_auto=%d num_splits=%d fixed=%d gqa_fused=%d\n",
path, static_cast<int>(dtype), num_heads, num_kv_heads, num_seqs,
page_block_size, max_num_blocks_per_seq, cta_tile, cta_threads,
static_cast<int>(use_split), static_cast<int>(use_split_auto), num_splits, static_cast<int>(fixed_num_splits),
static_cast<int>(use_gqa_fused));
};
// Split-kv auto mode: decide whether to split based on a heuristic.
if (use_split_auto) {
// Approximate seqlen_k by the per-seq KV capacity (paged KV upper bound).
const size_t seqlen_k = max_num_blocks_per_seq * page_block_size;
const int sm_count = getSmCount();
num_splits = chooseNumSplitsHeuristic(num_heads, num_seqs, seqlen_k, sm_count);
if (envBool("INFINIOP_FLASH_DEBUG_SPLITS")) {
static size_t last_seqlen_k = 0;
if (last_seqlen_k != seqlen_k) {
last_seqlen_k = seqlen_k;
fprintf(stderr,
"[INFINIOP][paged_attention] splitkv auto(mode): sm=%d heads=%zu seqs=%zu seqlen_k~%zu -> num_splits=%d\n",
sm_count, num_heads, num_seqs, seqlen_k, num_splits);
}
}
// If auto picks 1, fall back to non-split to avoid extra workspace and kernel overhead.
use_split = (num_splits > 1);
}
// Optional: fuse GQA groups (4) when seqlen_q=1 decode and alibi is disabled.
// This reuses K/V loads across query heads that share the same KV head.
// Controlled by INFINIOP_FLASH_GQA_FUSED (default: enabled).
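// Example shape (hypothetical): num_heads=32, num_kv_heads=8 -> groups of 4 query
// heads per KV head; the grid is (8, num_seqs) CTAs instead of (32, num_seqs), and
// each CTA loads K/V once for its 4 query heads.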
if (use_gqa_fused && use_cta && !use_split && alibi_slopes == nullptr && num_kv_heads > 0 && num_heads == num_kv_heads * 4) {
dump_dispatch("cta_gqa_fused");
dim3 grid_gqa(static_cast<uint64_t>(num_kv_heads), static_cast<uint64_t>(num_seqs), 1);
if (dtype == INFINI_DTYPE_F16) {
flashAttentionDecodeHd128CtaGqa4<Tindex, half><<<grid_gqa, 64, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, nullptr,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
return INFINI_STATUS_SUCCESS;
}
if (dtype == INFINI_DTYPE_BF16) {
flashAttentionDecodeHd128CtaGqa4<Tindex, __nv_bfloat16><<<grid_gqa, 64, 0, stream>>>(
static_cast<__nv_bfloat16 *>(out),
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, nullptr,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
return INFINI_STATUS_SUCCESS;
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
dim3 grid(static_cast<uint64_t>(num_heads), static_cast<uint64_t>(num_seqs), 1);
if (use_split) {
dump_dispatch(use_cta ? "splitkv_cta" : "splitkv_warp");
if (!fixed_num_splits) {
// Approximate seqlen_k by the per-seq KV capacity (paged KV upper bound).
const size_t seqlen_k = max_num_blocks_per_seq * page_block_size;
const int sm_count = getSmCount();
num_splits = chooseNumSplitsHeuristic(num_heads, num_seqs, seqlen_k, sm_count);
if (envBool("INFINIOP_FLASH_DEBUG_SPLITS")) {
static size_t last_seqlen_k = 0;
if (last_seqlen_k != seqlen_k) {
last_seqlen_k = seqlen_k;
fprintf(stderr,
"[INFINIOP][paged_attention] splitkv auto: sm=%d heads=%zu seqs=%zu seqlen_k~%zu -> num_splits=%d\n",
sm_count, num_heads, num_seqs, seqlen_k, num_splits);
}
}
}
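// Workspace layout (floats, contiguous): [acc: kMaxSplits*n*128][m: kMaxSplits*n][l: kMaxSplits*n].
// Sized with kMaxSplits rather than num_splits so a runtime change of
// INFINIOP_FLASH_NUM_SPLITS can never under-run the buffer reserved at create time.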
const size_t n = num_seqs * num_heads;
const size_t acc_elems = static_cast<size_t>(kMaxSplits) * n * 128;
const size_t m_elems = static_cast<size_t>(kMaxSplits) * n;
const size_t l_elems = static_cast<size_t>(kMaxSplits) * n;
const size_t needed_bytes = (acc_elems + m_elems + l_elems) * sizeof(float);
if (workspace == nullptr || workspace_size < needed_bytes) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
float *ws = static_cast<float *>(workspace);
float *partial_acc = ws;
float *partial_m = partial_acc + acc_elems;
float *partial_l = partial_m + m_elems;
dim3 grid_split(static_cast<uint64_t>(num_heads), static_cast<uint64_t>(num_seqs), static_cast<uint64_t>(num_splits));
dim3 block_split(use_cta ? static_cast<uint32_t>(cta_threads) : 32);
if (dtype == INFINI_DTYPE_F16) {
if (use_cta) {
if (cta_threads == 32) {
if (cta_tile == 16) {
flashAttentionDecodeHd128SplitKvCta32Tile16<Tindex, half><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
} else {
flashAttentionDecodeHd128SplitKvCta32<Tindex, half><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
}
} else {
if (cta_tile == 16) {
flashAttentionDecodeHd128SplitKvCtaTile16<Tindex, half><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
} else {
flashAttentionDecodeHd128SplitKvCta<Tindex, half><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
}
}
} else {
flashAttentionDecodeHd128SplitKv<Tindex, half><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
}
flashAttentionDecodeHd128SplitKvCombine<half><<<grid, 32, 0, stream>>>(
static_cast<half *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride);
return INFINI_STATUS_SUCCESS;
}
if (dtype == INFINI_DTYPE_BF16) {
if (use_cta) {
if (cta_threads == 32) {
if (cta_tile == 16) {
flashAttentionDecodeHd128SplitKvCta32Tile16<Tindex, __nv_bfloat16><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
} else {
flashAttentionDecodeHd128SplitKvCta32<Tindex, __nv_bfloat16><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
}
} else {
if (cta_tile == 16) {
flashAttentionDecodeHd128SplitKvCtaTile16<Tindex, __nv_bfloat16><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
} else {
flashAttentionDecodeHd128SplitKvCta<Tindex, __nv_bfloat16><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
}
}
} else {
flashAttentionDecodeHd128SplitKv<Tindex, __nv_bfloat16><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
}
flashAttentionDecodeHd128SplitKvCombine<__nv_bfloat16><<<grid, 32, 0, stream>>>(
static_cast<__nv_bfloat16 *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride);
return INFINI_STATUS_SUCCESS;
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
dump_dispatch(use_cta ? "cta_nosplit" : "warp_nosplit");
if (dtype == INFINI_DTYPE_F16) {
if (use_cta) {
if (cta_tile == 16) {
if (cta_threads == 32) {
flashAttentionDecodeHd128Cta32Tile16<Tindex, half><<<grid, block, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
} else {
flashAttentionDecodeHd128CtaTile16<Tindex, half><<<grid, block, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
} else {
if (cta_threads == 32) {
flashAttentionDecodeHd128Cta32<Tindex, half><<<grid, block, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
} else {
flashAttentionDecodeHd128Cta<Tindex, half><<<grid, block, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
}
} else {
flashAttentionDecodeHd128Warp<Tindex, half><<<grid, block, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
return INFINI_STATUS_SUCCESS;
}
if (dtype == INFINI_DTYPE_BF16) {
if (use_cta) {
if (cta_tile == 16) {
if (cta_threads == 32) {
flashAttentionDecodeHd128Cta32Tile16<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
static_cast<__nv_bfloat16 *>(out),
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
} else {
flashAttentionDecodeHd128CtaTile16<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
static_cast<__nv_bfloat16 *>(out),
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
} else {
if (cta_threads == 32) {
flashAttentionDecodeHd128Cta32<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
static_cast<__nv_bfloat16 *>(out),
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
} else {
flashAttentionDecodeHd128Cta<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
static_cast<__nv_bfloat16 *>(out),
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
}
} else {
flashAttentionDecodeHd128Warp<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
static_cast<__nv_bfloat16 *>(out),
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
return INFINI_STATUS_SUCCESS;
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
infiniStatus_t launch_decode_hd128_i64(
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
infiniDtype_t dtype,
const int64_t *block_tables,
const int64_t *cache_lens,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
hcStream_t stream) {
return launch_decode_hd128_impl<int64_t>(
workspace, workspace_size,
out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride,
k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}
infiniStatus_t launch_decode_hd128_i32(
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
infiniDtype_t dtype,
const int32_t *block_tables,
const int32_t *cache_lens,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
hcStream_t stream) {
return launch_decode_hd128_impl<int32_t>(
workspace, workspace_size,
out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride,
k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}
infiniStatus_t launch_decode_hd128_u32(
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
infiniDtype_t dtype,
const uint32_t *block_tables,
const uint32_t *cache_lens,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
hcStream_t stream) {
return launch_decode_hd128_impl<uint32_t>(
workspace, workspace_size,
out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride,
k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}
} // namespace op::paged_attention::metax
#ifdef ENABLE_METAX_MC_API
#include <mc_runtime.h>
#else
#include <hc_runtime.h>
#endif
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../cuda/kernel_v2.cuh"
namespace op::paged_attention::metax {
namespace {
constexpr int kMaxSplits = 8;
constexpr size_t ceilDiv(size_t a, size_t b) {
return (a + b - 1) / b;
}
inline int getSmCount() {
int device = 0;
if (hcGetDevice(&device) != hcSuccess) {
return 0;
}
int sm_count = 0;
if (hcDeviceGetAttribute(&sm_count, hcDeviceAttributeMultiProcessorCount, device) != hcSuccess) {
return 0;
}
return sm_count;
}
// A lightweight FA2-style "waves" heuristic.
//
// Important: our split-kv kernel shards the KV sequence length, so the main "work"
// dimension is tokens, not the number of pages. We use an upper bound for seqlen_k
// (max pages * page size), which matches common decode microbenchmarks where all
// sequences share the same cache length.
inline int chooseNumSplitsHeuristic(size_t num_heads, size_t num_seqs, size_t seqlen_k, int sm_count) {
if (sm_count <= 0) {
return 1;
}
if (num_heads == 0 || num_seqs == 0) {
return 1;
}
if (seqlen_k <= 256) {
return 1;
}
const size_t base_blocks = num_heads * num_seqs;
int best_splits = 1;
// Baseline: one kernel, base_blocks CTAs, each scanning seqlen_k tokens.
size_t best_score = (ceilDiv(base_blocks, static_cast<size_t>(sm_count)) * seqlen_k);
size_t prev_work_per_block = seqlen_k;
for (int s = 2; s <= kMaxSplits; ++s) {
const size_t blocks = base_blocks * static_cast<size_t>(s);
const size_t waves_split = ceilDiv(blocks, static_cast<size_t>(sm_count));
const size_t work_per_block = ceilDiv(seqlen_k, static_cast<size_t>(s));
// If this split count doesn't reduce per-block work vs the previous split, it's effectively redundant.
if (work_per_block == prev_work_per_block) {
continue;
}
prev_work_per_block = work_per_block;
// Combine is one extra kernel with base_blocks blocks; approximate as one more wave unit.
const size_t waves_combine = ceilDiv(base_blocks, static_cast<size_t>(sm_count));
const size_t score = waves_split * work_per_block + waves_combine;
if (score < best_score) {
best_score = score;
best_splits = s;
}
}
return best_splits;
}
} // namespace
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL flashAttentionDecodeHd64Warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride) {
op::paged_attention::cuda::flashAttentionDecodeWarpKernel<Tindex, Tdata, 64>(
out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL flashAttentionDecodeHd64Cta(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride) {
// Default CTA variant (lower overhead).
op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 64, 32, 8>(
out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL flashAttentionDecodeHd64CtaTile16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride) {
op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 64, 32, 16>(
out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL flashAttentionDecodeHd64SplitKv(
float *partial_acc,
float *partial_m,
float *partial_l,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
int num_splits) {
op::paged_attention::cuda::flashAttentionDecodeSplitKvWarpKernel<Tindex, Tdata, 64>(
partial_acc, partial_m, partial_l,
q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, num_splits);
}
template <typename Tdata>
INFINIOP_METAX_KERNEL flashAttentionDecodeHd64SplitKvCombine(
Tdata *out,
const float *partial_acc,
const float *partial_m,
const float *partial_l,
int num_splits,
ptrdiff_t o_stride) {
op::paged_attention::cuda::flashAttentionDecodeSplitKvCombineWarpKernel<Tdata, 64>(
out, partial_acc, partial_m, partial_l, num_splits, o_stride);
}
template <typename Tindex>
infiniStatus_t launch_decode_hd64_impl(
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
infiniDtype_t dtype,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
hcStream_t stream) {
dim3 grid(static_cast<uint64_t>(num_heads), static_cast<uint64_t>(num_seqs), 1);
bool use_cta = false;
if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_KERNEL")) {
use_cta = (std::strcmp(env, "cta") == 0);
}
int cta_tile = 8;
if (const char *env = std::getenv("INFINIOP_FLASH_CTA_TILE")) {
const int v = std::atoi(env);
if (v == 8 || v == 16) {
cta_tile = v;
}
}
// For head_dim=64 we use a 1-warp CTA (32 threads) with packed loads.
dim3 block(32);
bool use_split = false;
if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_SPLITKV")) {
use_split = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
}
int num_splits = 4;
bool fixed_num_splits = false;
if (const char *env = std::getenv("INFINIOP_FLASH_NUM_SPLITS")) {
if (std::strcmp(env, "auto") == 0) {
fixed_num_splits = false;
} else {
num_splits = std::atoi(env);
fixed_num_splits = (num_splits > 0);
}
}
if (num_splits < 1) {
num_splits = 1;
}
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
if (use_split) {
if (use_cta) {
// We currently only implement the split-kv path with warp kernels.
// The CTA kernel is a separate non-split implementation.
static bool warned = false;
if (!warned) {
warned = true;
fprintf(stderr,
"[INFINIOP][paged_attention] split-kv is enabled; ignoring INFINIOP_FLASH_DECODE_KERNEL=cta "
"(CTA split-kv not implemented yet)\n");
}
}
if (!fixed_num_splits) {
// Approximate seqlen_k by the per-seq KV capacity (paged KV upper bound).
const size_t seqlen_k = max_num_blocks_per_seq * page_block_size;
const int sm_count = getSmCount();
num_splits = chooseNumSplitsHeuristic(num_heads, num_seqs, seqlen_k, sm_count);
if (const char *dbg = std::getenv("INFINIOP_FLASH_DEBUG_SPLITS")) {
if (std::strcmp(dbg, "1") == 0 || std::strcmp(dbg, "true") == 0) {
static size_t last_seqlen_k = 0;
if (last_seqlen_k != seqlen_k) {
last_seqlen_k = seqlen_k;
fprintf(stderr,
"[INFINIOP][paged_attention] splitkv auto: sm=%d heads=%zu seqs=%zu seqlen_k~%zu -> num_splits=%d\n",
sm_count, num_heads, num_seqs, seqlen_k, num_splits);
}
}
}
}
const size_t n = num_seqs * num_heads;
const size_t acc_elems = static_cast<size_t>(kMaxSplits) * n * 64;
const size_t m_elems = static_cast<size_t>(kMaxSplits) * n;
const size_t l_elems = static_cast<size_t>(kMaxSplits) * n;
const size_t needed_bytes = (acc_elems + m_elems + l_elems) * sizeof(float);
if (workspace == nullptr || workspace_size < needed_bytes) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
float *ws = static_cast<float *>(workspace);
float *partial_acc = ws;
float *partial_m = partial_acc + acc_elems;
float *partial_l = partial_m + m_elems;
dim3 grid_split(static_cast<uint64_t>(num_heads), static_cast<uint64_t>(num_seqs), static_cast<uint64_t>(num_splits));
dim3 block_split(32);
if (dtype == INFINI_DTYPE_F16) {
flashAttentionDecodeHd64SplitKv<Tindex, half><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
flashAttentionDecodeHd64SplitKvCombine<half><<<grid, 32, 0, stream>>>(
static_cast<half *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride);
return INFINI_STATUS_SUCCESS;
}
if (dtype == INFINI_DTYPE_BF16) {
flashAttentionDecodeHd64SplitKv<Tindex, __nv_bfloat16><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
flashAttentionDecodeHd64SplitKvCombine<__nv_bfloat16><<<grid, 32, 0, stream>>>(
static_cast<__nv_bfloat16 *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride);
return INFINI_STATUS_SUCCESS;
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (dtype == INFINI_DTYPE_F16) {
if (use_cta) {
if (cta_tile == 16) {
flashAttentionDecodeHd64CtaTile16<Tindex, half><<<grid, block, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
} else {
flashAttentionDecodeHd64Cta<Tindex, half><<<grid, block, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
} else {
flashAttentionDecodeHd64Warp<Tindex, half><<<grid, block, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
return INFINI_STATUS_SUCCESS;
}
if (dtype == INFINI_DTYPE_BF16) {
if (use_cta) {
if (cta_tile == 16) {
flashAttentionDecodeHd64CtaTile16<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
static_cast<__nv_bfloat16 *>(out),
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
} else {
flashAttentionDecodeHd64Cta<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
static_cast<__nv_bfloat16 *>(out),
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
} else {
flashAttentionDecodeHd64Warp<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
static_cast<__nv_bfloat16 *>(out),
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
return INFINI_STATUS_SUCCESS;
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
infiniStatus_t launch_decode_hd64_i64(
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
infiniDtype_t dtype,
const int64_t *block_tables,
const int64_t *cache_lens,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
hcStream_t stream) {
return launch_decode_hd64_impl<int64_t>(
workspace, workspace_size,
out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride,
k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}
infiniStatus_t launch_decode_hd64_i32(
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
infiniDtype_t dtype,
const int32_t *block_tables,
const int32_t *cache_lens,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
hcStream_t stream) {
return launch_decode_hd64_impl<int32_t>(
workspace, workspace_size,
out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride,
k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}
infiniStatus_t launch_decode_hd64_u32(
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
infiniDtype_t dtype,
const uint32_t *block_tables,
const uint32_t *cache_lens,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
hcStream_t stream) {
return launch_decode_hd64_impl<uint32_t>(
workspace, workspace_size,
out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride,
k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}
} // namespace op::paged_attention::metax
#ifndef __PAGED_ATTENTION_METAX_H__
#define __PAGED_ATTENTION_METAX_H__
#include "../paged_attention.h"
DESCRIPTOR(metax)
#endif // __PAGED_ATTENTION_METAX_H__
#ifdef ENABLE_METAX_MC_API
#include <mc_runtime.h>
#else
#include <hc_runtime.h>
#endif
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include "../../../devices/metax/metax_common.h"
#include "paged_attention_metax.h"
namespace op::paged_attention::metax {
infiniStatus_t launch_decode_hd64_i64(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
infiniDtype_t dtype, const int64_t *block_tables, const int64_t *cache_lens, const float *alibi_slopes,
size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
hcStream_t stream);
infiniStatus_t launch_decode_hd64_i32(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
infiniDtype_t dtype, const int32_t *block_tables, const int32_t *cache_lens, const float *alibi_slopes,
size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
hcStream_t stream);
infiniStatus_t launch_decode_hd64_u32(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
infiniDtype_t dtype, const uint32_t *block_tables, const uint32_t *cache_lens, const float *alibi_slopes,
size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
hcStream_t stream);
infiniStatus_t launch_decode_hd128_i64(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
infiniDtype_t dtype, const int64_t *block_tables, const int64_t *cache_lens, const float *alibi_slopes,
size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
hcStream_t stream);
infiniStatus_t launch_decode_hd128_i32(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
infiniDtype_t dtype, const int32_t *block_tables, const int32_t *cache_lens, const float *alibi_slopes,
size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
hcStream_t stream);
infiniStatus_t launch_decode_hd128_u32(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
infiniDtype_t dtype, const uint32_t *block_tables, const uint32_t *cache_lens, const float *alibi_slopes,
size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
hcStream_t stream);
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t cache_lens_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto info_res = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, cache_lens_desc, alibi_slopes_desc, scale);
CHECK_RESULT(info_res);
auto info = info_res.take();
// Reserve workspace for optional split-kv decode (partial acc + m/l).
// Workspace is independent of runtime env toggles; kernels will clamp num_splits <= kMaxSplits.
constexpr size_t kMaxSplits = 8;
const size_t per_split = info.num_seqs * info.num_heads * (info.head_size + 2) * sizeof(float);
const size_t workspace_bytes = kMaxSplits * per_split;
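// Illustrative sizing (hypothetical shapes): num_seqs = 32, num_heads = 32, head_size = 128
// => per_split = 32 * 32 * (128 + 2) * 4 B = 520 KiB, workspace = 8 * 520 KiB ~= 4.1 MiB.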
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
info, workspace_bytes, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables, const void *cache_lens, const void *alibi_slopes,
void *stream_) const {
bool need_workspace = false;
if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_SPLITKV")) {
// "auto" may enable split-kv depending on the runtime heuristic.
need_workspace = (std::strcmp(env, "auto") == 0) || (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
} else {
// Keep hd64 behavior unchanged, but for hd128 we default to split-kv decode, which needs workspace.
need_workspace = (_info.head_size == 128);
}
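// Example: INFINIOP_FLASH_DECODE_SPLITKV=auto (or =1/true) sizes the check for split-kv;
// any other value, or unset with head_size == 64, keeps the single-pass path workspace-free.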
if (need_workspace && workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
auto stream = static_cast<hcStream_t>(stream_);
const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast<const float *>(alibi_slopes);
if (_info.index_dtype == INFINI_DTYPE_I64) {
const auto *block_table_i64 = static_cast<const int64_t *>(block_tables);
const auto *cache_lens_i64 = static_cast<const int64_t *>(cache_lens);
switch (_info.head_size) {
case 64:
return launch_decode_hd64_i64(
workspace, workspace_size,
out, q, k_cache, v_cache, _info.dtype,
block_table_i64, cache_lens_i64, alibi_ptr,
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
_info.max_num_blocks_per_seq, _info.page_block_size,
_info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
_info.o_stride, stream);
case 128:
return launch_decode_hd128_i64(
workspace, workspace_size,
out, q, k_cache, v_cache, _info.dtype,
block_table_i64, cache_lens_i64, alibi_ptr,
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
_info.max_num_blocks_per_seq, _info.page_block_size,
_info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
_info.o_stride, stream);
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
if (_info.index_dtype == INFINI_DTYPE_I32) {
const auto *block_table_i32 = static_cast<const int32_t *>(block_tables);
const auto *cache_lens_i32 = static_cast<const int32_t *>(cache_lens);
switch (_info.head_size) {
case 64:
return launch_decode_hd64_i32(
workspace, workspace_size,
out, q, k_cache, v_cache, _info.dtype,
block_table_i32, cache_lens_i32, alibi_ptr,
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
_info.max_num_blocks_per_seq, _info.page_block_size,
_info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
_info.o_stride, stream);
case 128:
return launch_decode_hd128_i32(
workspace, workspace_size,
out, q, k_cache, v_cache, _info.dtype,
block_table_i32, cache_lens_i32, alibi_ptr,
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
_info.max_num_blocks_per_seq, _info.page_block_size,
_info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
_info.o_stride, stream);
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
if (_info.index_dtype == INFINI_DTYPE_U32) {
const auto *block_table_u32 = static_cast<const uint32_t *>(block_tables);
const auto *cache_lens_u32 = static_cast<const uint32_t *>(cache_lens);
switch (_info.head_size) {
case 64:
return launch_decode_hd64_u32(
workspace, workspace_size,
out, q, k_cache, v_cache, _info.dtype,
block_table_u32, cache_lens_u32, alibi_ptr,
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
_info.max_num_blocks_per_seq, _info.page_block_size,
_info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
_info.o_stride, stream);
case 128:
return launch_decode_hd128_u32(
workspace, workspace_size,
out, q, k_cache, v_cache, _info.dtype,
block_table_u32, cache_lens_u32, alibi_ptr,
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
_info.max_num_blocks_per_seq, _info.page_block_size,
_info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
_info.o_stride, stream);
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::paged_attention::metax
......@@ -5,9 +5,9 @@
#ifdef ENABLE_NVIDIA_API
#include "nvidia/paged_attention_nvidia.cuh"
#endif
// #ifdef ENABLE_METAX_API
// #include "metax/paged_attention_metax.h"
// #endif
#ifdef ENABLE_METAX_API
#include "metax/paged_attention_metax.h"
#endif
__C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
infiniopHandle_t handle,
......@@ -34,9 +34,9 @@ __C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// CREATE(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -55,9 +55,9 @@ __C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// GET(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -80,9 +80,9 @@ __C infiniStatus_t infiniopPagedAttention(
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// CALCULATE(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -100,9 +100,9 @@ __C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// DESTROY(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......
#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__
#define __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__
#ifdef ENABLE_NVIDIA_API
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <mma.h>
#endif
#include <cstdint>
#include <type_traits>
......
#ifndef __PAGED_ATTENTION_PREFILL_METAX_H__
#define __PAGED_ATTENTION_PREFILL_METAX_H__
#include "../paged_attention_prefill.h"
DESCRIPTOR(metax)
#endif // __PAGED_ATTENTION_PREFILL_METAX_H__
#ifdef ENABLE_METAX_MC_API
#include <mc_runtime.h>
#else
#include <hc_runtime.h>
#endif
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
// #include "paged_attention_prefill_fa2.cuh"
#include "paged_attention_prefill_metax.h"
#include "../cuda/kernel_v2.cuh"
namespace op::paged_attention_prefill::metax {
namespace {
constexpr size_t ceilDiv(size_t a, size_t b) {
return (a + b - 1) / b;
}
inline const char *default_prefill_kernel(const PagedAttentionPrefillInfo &info) {
// Heuristic auto-dispatch (v0.4):
// - Prefer the pipelined + tile-wise softmax kernel on FA2-compatible block_size=256.
// - Keep a conservative fallback for other shapes / older GPUs (cp.async is a no-op below SM80).
//
// Users can always override via INFINIOP_FLASH_PREFILL_KERNEL.
if (info.page_block_size == 256 && (info.dtype == INFINI_DTYPE_F16 || info.dtype == INFINI_DTYPE_BF16)) {
if (info.head_size == 128) {
return "warpcta8pipe";
}
// For head_size=64 we keep the previous default until we have broader perf coverage.
}
return "warpcta8";
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128Warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel).
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel<Tindex, Tdata, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64Warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel).
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel<Tindex, Tdata, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 4 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 4, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 4 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 4, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 8, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8N128(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token, tile_n=128 for fewer K stages.
// Note: we keep K in shared memory but load V from global to stay within the per-block shared limit.
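// Shared-memory budget sketch (fp16): the K tile alone is 128 rows x 128 dims x 2 B = 32 KiB;
// staging V as well would double that to 64 KiB, past a typical 48 KiB per-block default.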
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelKOnly<Tindex, Tdata, 128, 8, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 8, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8Pipe(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token, with cp.async pipelining.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined<Tindex, Tdata, 128, 8, 32, 2>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8Mma(
half *out,
const half *q,
const half *k_cache,
const half *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCta8MmaHd128Kernel<Tindex>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8Pipe(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token, with cp.async pipelining.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined<Tindex, Tdata, 64, 8, 32, 2>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8PipeSplitKv(
float *partial_acc,
float *partial_m,
float *partial_l,
int num_splits,
size_t total_q_tokens,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride) {
// Encode (split_idx, m_block) into blockIdx.z to allow a single kernel launch:
// blockIdx.z in [0, num_splits * num_m_blocks).
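// The divisor 8 matches the launcher's warps-per-CTA (one m-block = 8 query tokens).
// Example: num_splits = 4, num_m_blocks = 10 -> bz = 23 decodes to split_idx = 2, m_block = 3.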
const int num_m_blocks = static_cast<int>((total_q_tokens + 8 - 1) / 8);
const int bz = static_cast<int>(blockIdx.z);
const int split_idx = bz / num_m_blocks;
const int m_block = bz - split_idx * num_m_blocks;
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv<Tindex, Tdata, 128, 8, 32, 2>(
partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8PipeSplitKv(
float *partial_acc,
float *partial_m,
float *partial_l,
int num_splits,
size_t total_q_tokens,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride) {
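// Same (split_idx, m_block) decoding as the hd128 variant:
// blockIdx.z in [0, num_splits * num_m_blocks).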
const int num_m_blocks = static_cast<int>((total_q_tokens + 8 - 1) / 8);
const int bz = static_cast<int>(blockIdx.z);
const int split_idx = bz / num_m_blocks;
const int m_block = bz - split_idx * num_m_blocks;
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv<Tindex, Tdata, 64, 8, 32, 2>(
partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
}
template <typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128SplitKvCombine(
Tdata *out,
const float *partial_acc,
const float *partial_m,
const float *partial_l,
int num_splits,
size_t total_q_tokens,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel<Tdata, 128>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
}
template <typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64SplitKvCombine(
Tdata *out,
const float *partial_acc,
const float *partial_m,
const float *partial_l,
int num_splits,
size_t total_q_tokens,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel<Tdata, 64>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 16 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 16, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 16 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 16, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata, typename Tcompute>
infiniStatus_t launch_prefill_ref(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
const dim3 grid(static_cast<uint32_t>(total_q_tokens), static_cast<uint32_t>(num_heads), 1);
const dim3 block(static_cast<uint32_t>(head_size), 1, 1);
if (head_size == 64) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel<Tindex, Tdata, Tcompute, 64>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride, q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, num_seqs);
return INFINI_STATUS_SUCCESS;
}
if (head_size == 128) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel<Tindex, Tdata, Tcompute, 128>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride, q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, num_seqs);
return INFINI_STATUS_SUCCESS;
}
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
const dim3 block(32, 1, 1);
// Global-token launch:
// - dramatically reduces grid size vs the legacy (num_seqs * total_q_tokens) launch
// - better matches the PagedAttention varlen (cu_seqlens) mental model
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(total_q_tokens),
1);
switch (head_size) {
case 64:
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel<Tindex, Tdata, 64>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, scale, max_num_blocks_per_seq,
page_block_size, block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel<Tindex, Tdata, 128>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, scale, max_num_blocks_per_seq,
page_block_size, block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
constexpr int kWarps = 4;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
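// Grid mapping: x = head, y = seq, z = m_block (kWarps query tokens per CTA).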
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta8<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta8<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8pipe(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta8Pipe<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta8Pipe<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8mma(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
// Current WMMA kernel only supports fp16 + head_dim=128.
if constexpr (!std::is_same_v<Tdata, half>) {
return launch_prefill_warpcta8pipe<Tindex, Tdata>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, stream);
}
if (head_size != 128) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// Guardrail: the current WMMA-score kernel is correctness-first and can be extremely slow on long prompts.
// Allow power users to force it via INFINIOP_FLASH_PREFILL_MMA_FORCE=1.
const char *force_env = std::getenv("INFINIOP_FLASH_PREFILL_MMA_FORCE");
const bool force_mma = (force_env != nullptr) && (std::strcmp(force_env, "1") == 0);
const size_t seqlen_k_est = max_num_blocks_per_seq * page_block_size;
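// e.g. 32 blocks x 256 tokens/page gives seqlen_k_est = 8192 (> 4096), triggering the fallback below.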
if (!force_mma && seqlen_k_est > 4096) {
static bool warned = false;
if (!warned) {
std::fprintf(stderr,
"[infiniop][paged_attention_prefill] warpcta8mma is experimental and very slow for long seqlen_k (est=%zu). "
"Falling back to warpcta8pipe. Set INFINIOP_FLASH_PREFILL_MMA_FORCE=1 to override.\n",
seqlen_k_est);
warned = true;
}
return launch_prefill_warpcta8pipe<Tindex, Tdata>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, stream);
}
// WMMA requires SM70+. If the device reports SM < 7, fall back to the pipelined SIMT kernel
// (if the property query itself fails, we optimistically keep the MMA path).
int device = 0;
hcDeviceProp_t prop{};
if (hcGetDevice(&device) == hcSuccess && hcGetDeviceProperties(&prop, device) == hcSuccess) {
if (prop.major < 7) {
return launch_prefill_warpcta8pipe<Tindex, Tdata>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, stream);
}
}
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(16))));
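// Note: grid.z steps by 16 rather than kWarps, i.e. each CTA covers a 16-row m-block
// (presumably matching the 16x16 WMMA tile height in the underlying kernel).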
PagedAttentionPrefillHd128WarpCta8Mma<Tindex>
<<<grid, block, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8pipe_splitkv(
float *partial_acc,
float *partial_m,
float *partial_l,
int num_splits,
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
constexpr int kMaxSplits = 8;
if (num_splits < 1) {
num_splits = 1;
}
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const size_t num_m_blocks = ceilDiv(total_q_tokens, static_cast<size_t>(kWarps));
// Single kernel launch with split_idx encoded in grid.z:
// blockIdx.z in [0, num_splits * num_m_blocks).
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(num_m_blocks * static_cast<size_t>(num_splits)));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta8PipeSplitKv<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
partial_acc, partial_m, partial_l, num_splits, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
break;
case 128:
PagedAttentionPrefillHd128WarpCta8PipeSplitKv<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
partial_acc, partial_m, partial_l, num_splits, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
break;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// Combine: one warp per (token, head).
const dim3 block2(32);
const dim3 grid2(static_cast<uint32_t>(num_heads), static_cast<uint32_t>(total_q_tokens), 1);
switch (head_size) {
case 64:
PagedAttentionPrefillHd64SplitKvCombine<Tdata>
<<<grid2, block2, 0, stream>>>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128SplitKvCombine<Tdata>
<<<grid2, block2, 0, stream>>>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8n128(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
// Only meaningful for head_dim=128; validate before setting up the launch.
if (head_size != 128) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
PagedAttentionPrefillHd128WarpCta8N128<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
constexpr int kWarps = 16;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta16<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta16<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
} // namespace
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t total_kv_lens_desc,
infiniopTensorDescriptor_t cum_seqlens_q_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto info = PagedAttentionPrefillInfo::create(
out_desc, q_desc, k_cache_desc, v_cache_desc,
block_tables_desc, total_kv_lens_desc, cum_seqlens_q_desc,
alibi_slopes_desc, scale);
CHECK_RESULT(info);
// Optional split-kv prefill requires workspace for partial (m, l, acc).
// IMPORTANT: Unlike decode, prefill's total_q_tokens can be very large, so we must NOT reserve
// a huge workspace unless the user explicitly enables split-kv.
bool use_splitkv = false;
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
}
int num_splits = 1;
if (use_splitkv) {
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) {
const int v = std::atoi(env);
if (v > 0) {
num_splits = v;
}
} else {
num_splits = 4;
}
constexpr int kMaxSplits = 8;
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
}
const size_t n = info->total_q_tokens * info->num_heads;
const size_t splitkv_workspace_bytes = use_splitkv ? (static_cast<size_t>(num_splits) * n * (info->head_size + 2) * sizeof(float)) : 0;
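// Illustrative (hypothetical shapes): total_q_tokens = 4096, num_heads = 32, head_size = 128,
// num_splits = 4 => 4 * (4096 * 32) * 130 * 4 B = 260 MiB, hence the explicit opt-in above.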
const size_t workspace_bytes = splitkv_workspace_bytes;
// const size_t workspace_bytes = splitkv_workspace_bytes + fa2_workspace_bytes;
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
info.take(), workspace_bytes, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables,
const void *total_kv_lens,
const void *cum_seqlens_q,
const void *alibi_slopes,
void *stream_) const {
auto stream = static_cast<hcStream_t>(stream_);
const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast<const float *>(alibi_slopes);
const auto *total_kv_lens_i64 = static_cast<const int64_t *>(total_kv_lens);
const auto *cu_seqlens_q_i64 = static_cast<const int64_t *>(cum_seqlens_q);
bool use_splitkv = false;
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
}
int num_splits = 1;
if (use_splitkv) {
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) {
const int v = std::atoi(env);
if (v > 0) {
num_splits = v;
}
} else {
// Conservative default; users can override.
num_splits = 4;
}
constexpr int kMaxSplits = 8;
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
const size_t n = _info.total_q_tokens * _info.num_heads;
const size_t required = static_cast<size_t>(num_splits) * n * (_info.head_size + 2) * sizeof(float);
if (workspace_size < required) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
}
if (use_splitkv) {
const size_t n = _info.total_q_tokens * _info.num_heads;
float *partial_acc = static_cast<float *>(workspace);
float *partial_m = partial_acc + static_cast<size_t>(num_splits) * n * _info.head_size;
float *partial_l = partial_m + static_cast<size_t>(num_splits) * n;
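// Workspace layout: [num_splits][n][head_size] floats for acc, then [num_splits][n] for m,
// then [num_splits][n] for l, where n = total_q_tokens * num_heads.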
// Dispatch by (Tdata, Tindex); total_kv_lens and cu_seqlens_q are currently always int64.
#define DISPATCH_SPLITKV(Tindex, Tdata, BT_PTR) \
return launch_prefill_warpcta8pipe_splitkv<Tindex, Tdata>( \
partial_acc, partial_m, partial_l, num_splits, \
static_cast<Tdata *>(out), \
static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), \
static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(BT_PTR), \
total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream)
if (_info.dtype == INFINI_DTYPE_F16) {
if (_info.index_dtype == INFINI_DTYPE_I64) {
DISPATCH_SPLITKV(int64_t, half, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_I32) {
DISPATCH_SPLITKV(int32_t, half, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_U32) {
DISPATCH_SPLITKV(uint32_t, half, block_tables);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (_info.dtype == INFINI_DTYPE_BF16) {
if (_info.index_dtype == INFINI_DTYPE_I64) {
DISPATCH_SPLITKV(int64_t, __nv_bfloat16, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_I32) {
DISPATCH_SPLITKV(int32_t, __nv_bfloat16, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_U32) {
DISPATCH_SPLITKV(uint32_t, __nv_bfloat16, block_tables);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
#undef DISPATCH_SPLITKV
}
// Default to the fastest validated kernel for supported shapes.
// "ref" is still available for debugging/correctness bisecting.
#define DISPATCH_KERNEL(Tindex, Tdata, Tcompute) \
do { \
const char *k_env = std::getenv("INFINIOP_FLASH_PREFILL_KERNEL"); \
const char *k = (k_env == nullptr) ? default_prefill_kernel(_info) : k_env; \
if (k_env != nullptr) { \
const bool known = (std::strcmp(k, "warp") == 0) || (std::strcmp(k, "warpcta") == 0) || (std::strcmp(k, "warpcta8") == 0) || (std::strcmp(k, "warpcta8pipe") == 0) || (std::strcmp(k, "warpcta8mma") == 0) || (std::strcmp(k, "warpcta8n128") == 0) || (std::strcmp(k, "warpcta16") == 0) || (std::strcmp(k, "ref") == 0); \
if (!known) { \
const char *fallback = default_prefill_kernel(_info); \
std::fprintf(stderr, \
"[infiniop][paged_attention_prefill] WARNING: unknown kernel '%s', falling back to '%s'\n", \
k, fallback); \
k = fallback; \
} \
} \
const char *dbg = std::getenv("INFINIOP_DEBUG_PREFILL_DISPATCH"); \
static bool printed_dispatch = false; \
if (!printed_dispatch && dbg != nullptr && std::strcmp(dbg, "1") == 0) { \
std::fprintf(stderr, \
"[infiniop][paged_attention_prefill] kernel=%s (override=%s head_size=%zu block=%zu dtype=%zu)\n", \
k, \
(k_env == nullptr ? "auto" : "env"), \
static_cast<size_t>(_info.head_size), \
static_cast<size_t>(_info.page_block_size), \
static_cast<size_t>(_info.dtype)); \
printed_dispatch = true; \
} \
if (std::strcmp(k, "warp") == 0) { \
return launch_prefill_warp<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta") == 0) { \
return launch_prefill<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta8") == 0) { \
return launch_prefill_warpcta8<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta8pipe") == 0) { \
return launch_prefill_warpcta8pipe<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if constexpr (std::is_same_v<Tdata, half>) { \
if (std::strcmp(k, "warpcta8mma") == 0) { \
return launch_prefill_warpcta8mma<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
} \
if (std::strcmp(k, "warpcta8n128") == 0) { \
return launch_prefill_warpcta8n128<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta16") == 0) { \
return launch_prefill_warpcta16<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "ref") == 0) { \
return launch_prefill_ref<Tindex, Tdata, Tcompute>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
return INFINI_STATUS_BAD_PARAM; \
} while (false)
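// Illustrative note (not part of this commit): the variant key `k` compared in
// DISPATCH_KERNEL above is typically resolved from an environment knob before
// the dispatch runs. The variable name and default below are assumptions:
//   const char *k = std::getenv("INFINIOP_PREFILL_KERNEL"); // hypothetical knob
//   if (k == nullptr || k[0] == '\0') { k = "warpcta8"; }   // hypothetical default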
#define DISPATCH_INDEX(Tindex) \
do { \
if (_info.dtype == INFINI_DTYPE_F16) { \
DISPATCH_KERNEL(Tindex, half, float); \
} \
if (_info.dtype == INFINI_DTYPE_BF16) { \
DISPATCH_KERNEL(Tindex, __nv_bfloat16, float); \
} \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
} while (false)
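// Reading aid (illustrative): the two macros compose into a three-level
// dispatch — index dtype, then data dtype, then the variant key, e.g.
//   DISPATCH_INDEX(int64_t)
//     -> DISPATCH_KERNEL(int64_t, half, float)         // when dtype == F16
//       -> launch_prefill_warpcta8<int64_t, half>(...) // when k == "warpcta8"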
if (_info.index_dtype == INFINI_DTYPE_I64) {
DISPATCH_INDEX(int64_t);
} else if (_info.index_dtype == INFINI_DTYPE_I32) {
DISPATCH_INDEX(int32_t);
} else if (_info.index_dtype == INFINI_DTYPE_U32) {
DISPATCH_INDEX(uint32_t);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::paged_attention_prefill::metax
@@ -5,6 +5,9 @@
#ifdef ENABLE_NVIDIA_API
#include "nvidia/paged_attention_prefill_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/paged_attention_prefill_metax.h"
#endif
__C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
infiniopHandle_t handle,
@@ -32,6 +35,9 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
switch (handle->device) {
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -50,6 +56,9 @@ __C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -75,6 +84,9 @@ __C infiniStatus_t infiniopPagedAttentionPrefill(
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -92,6 +104,9 @@ __C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#ifndef __PAGED_CACHING_METAX_H__
#define __PAGED_CACHING_METAX_H__
#include "../paged_caching.h"
DESCRIPTOR(metax)
#endif // __PAGED_CACHING_METAX_H__
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "paged_caching_metax.h"
template <typename Tdata, int NUM_THREADS>
INFINIOP_METAX_KERNEL pagedCaching(
Tdata *k_cache, Tdata *v_cache,
const Tdata *k, const Tdata *v,
const int64_t *slot_mapping,
const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
}
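// Host-side reference of the copy the device kernel is expected to perform,
// for reading along. This is a plausible sketch, not the kernel.cuh code; the
// cache layout [num_blocks, num_kv_heads, block_size, head_size] and the
// "negative slot means skip" convention are assumptions. V is analogous to K.
template <typename Tdata>
void pagedCachingHostSketch(
Tdata *k_cache, const Tdata *k, const int64_t *slot_mapping,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t k_cache_block_stride) {
for (size_t t = 0; t < num_tokens; ++t) {
const int64_t slot = slot_mapping[t];
if (slot < 0) {
continue; // assumed convention: negative slot = token not cached
}
const int64_t block_idx = slot / static_cast<int64_t>(block_size);
const int64_t block_off = slot % static_cast<int64_t>(block_size);
for (size_t h = 0; h < num_kv_heads; ++h) {
const Tdata *src = k + t * k_src_stride + h * head_size; // assumed source layout
Tdata *dst = k_cache + block_idx * k_cache_block_stride
+ (h * block_size + static_cast<size_t>(block_off)) * head_size;
for (size_t e = 0; e < head_size; ++e) {
dst[e] = src[e];
}
}
}
}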
namespace op::paged_caching::metax {
// PIMPL struct definition
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
// Destructor implementation
Descriptor::~Descriptor() {
delete _opaque;
}
// Static factory method implementation
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t slot_mapping_desc) {
auto info = PagedCachingInfo::create(k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc);
CHECK_RESULT(info);
// Create and return the Descriptor instance.
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// The launchKernel function is a templated helper to encapsulate the kernel launch.
// It sets up grid/block dimensions and calls the device-side kernel.
template <int NUM_THREADS>
infiniStatus_t launchKernel(const PagedCachingInfo &info,
void *k_cache, void *v_cache,
infiniDtype_t dtype,
const void *k, const void *v,
const void *slot_mapping,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
hcStream_t stream) {
// Grid is 2D: one block per (kv head, token) pair.
dim3 grid(uint64_t(num_kv_heads), uint64_t(num_tokens), 1);
// Block dimension is 1D, using the number of threads specified at compile time.
dim3 block(NUM_THREADS);
// This kernel does not require dynamic shared memory.
size_t shared_mem_size = 0;
// Launch the device-side kernel.
if (dtype == INFINI_DTYPE_F16) {
pagedCaching<half, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(half *)k_cache,
(half *)v_cache,
(const half *)k,
(const half *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<cuda_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(cuda_bfloat16 *)k_cache,
(cuda_bfloat16 *)v_cache,
(const cuda_bfloat16 *)k,
(const cuda_bfloat16 *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(float *)k_cache,
(float *)v_cache,
(const float *)k,
(const float *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
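// Example call shape (illustrative; the layouts below are assumptions — in
// practice the strides come from PagedCachingInfo): for contiguous
// [num_tokens, num_kv_heads, head_size] k/v and a
// [num_blocks, num_kv_heads, block_size, head_size] cache:
//   k_src_stride         = num_kv_heads * head_size;              // per token
//   k_cache_block_stride = num_kv_heads * block_size * head_size; // per block
//   launchKernel<METAX_BLOCK_SIZE_512>(info, k_cache, v_cache, INFINI_DTYPE_F16,
//       k, v, slot_mapping, num_tokens, num_kv_heads, head_size, block_size,
//       k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride,
//       stream);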
// Execution method implementation
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *k_cache, void *v_cache,
const void *k, const void *v,
const void *slot_mapping,
void *stream_) const {
hcStream_t stream = (hcStream_t)stream_;
// Dispatch logic based on the device's maximum threads per block.
// This allows selecting the largest, most efficient block size the hardware supports.
int max_threads = _opaque->internal->maxThreadsPerBlock();
if (max_threads >= METAX_BLOCK_SIZE_1024) {
// Dispatch on data type for a 1024-thread block; propagate launchKernel's
// status so a dtype error is not silently discarded.
return launchKernel<METAX_BLOCK_SIZE_1024>(
_info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
stream);
}
if (max_threads >= METAX_BLOCK_SIZE_512) {
return launchKernel<METAX_BLOCK_SIZE_512>(
_info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
stream);
}
// If the device supports fewer threads per block, report it as unsupported.
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
} // namespace op::paged_caching::metax
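// Host-side call sequence through the public C API (sketch; the full argument
// lists are elided in this diff, so the placeholders below are not signatures):
//   infiniopCreatePagedCachingDescriptor(handle, &desc, /* tensor descriptors */);
//   infiniopGetPagedCachingWorkspaceSize(desc, /* &size */);
//   infiniopPagedCaching(desc, /* workspace, k_cache, v_cache, k, v, slot_mapping, stream */);
//   infiniopDestroyPagedCachingDescriptor(desc);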
@@ -5,9 +5,9 @@
#ifdef ENABLE_NVIDIA_API
#include "nvidia/paged_caching_nvidia.cuh"
#endif
// #ifdef ENABLE_METAX_API
// #include "metax/paged_caching_metax.h"
// #endif
#ifdef ENABLE_METAX_API
#include "metax/paged_caching_metax.h"
#endif
__C infiniStatus_t infiniopCreatePagedCachingDescriptor(
infiniopHandle_t handle,
@@ -29,9 +29,9 @@ __C infiniStatus_t infiniopCreatePagedCachingDescriptor(
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// CREATE(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
@@ -50,9 +50,9 @@ __C infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// GET(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
@@ -75,9 +75,9 @@ __C infiniStatus_t infiniopPagedCaching(
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// CALCULATE(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
@@ -95,9 +95,9 @@ __C infiniStatus_t infiniopDestroyPagedCachingDescriptor(
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// DESTROY(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}