Unverified Commit 8d09630a authored by gongchensu, committed by GitHub

Merge branch 'demo131' into Issue/862

parents ab52dead 012df56c
#ifdef ENABLE_METAX_MC_API
#include <mc_runtime.h>
#else
#include <hc_runtime.h>
#endif
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
// #include "paged_attention_prefill_fa2.cuh"
#include "paged_attention_prefill_metax.h"
#include "../cuda/kernel_v2.cuh"
namespace op::paged_attention_prefill::metax {
namespace {
constexpr size_t ceilDiv(size_t a, size_t b) {
return (a + b - 1) / b;
}
inline const char *default_prefill_kernel(const PagedAttentionPrefillInfo &info) {
// Heuristic auto-dispatch (v0.4):
// - Prefer the pipelined + tile-wise softmax kernel on FA2-compatible block_size=256.
// - Keep a conservative fallback for other shapes / older GPUs (cp.async is a no-op below SM80).
//
// Users can always override via INFINIOP_FLASH_PREFILL_KERNEL.
if (info.page_block_size == 256 && (info.dtype == INFINI_DTYPE_F16 || info.dtype == INFINI_DTYPE_BF16)) {
if (info.head_size == 128) {
return "warpcta8pipe";
}
// For head_size=64 we keep the previous default until we have broader perf coverage.
}
return "warpcta8";
}
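// Example override (hypothetical invocation; the env var is read in Descriptor::calculate below
// and takes precedence over this heuristic):
//   INFINIOP_FLASH_PREFILL_KERNEL=warpcta8pipe ./app   // force the pipelined kernel
//   INFINIOP_FLASH_PREFILL_KERNEL=ref          ./app   // reference kernel for correctness checks
// Leaving the variable unset falls back to default_prefill_kernel().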
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128Warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel).
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel<Tindex, Tdata, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64Warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel).
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel<Tindex, Tdata, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 4 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 4, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 4 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 4, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 8, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8N128(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token, tile_n=128 for fewer K stages.
// Note: we keep K in shared memory but load V from global to stay within the per-block shared limit.
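// Rough shared-memory budget behind that choice (sketch, assuming fp16 and the
// <head_dim=128, tile_n=128> instantiation below):
//   K tile: 128 rows * 128 dims * 2 B = 32 KiB
// A matching V tile would double this to 64 KiB, past a typical 48 KiB per-block
// static shared-memory budget, hence V stays in global memory here.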
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelKOnly<Tindex, Tdata, 128, 8, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 8, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8Pipe(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token, with cp.async pipelining.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined<Tindex, Tdata, 128, 8, 32, 2>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8Mma(
half *out,
const half *q,
const half *k_cache,
const half *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCta8MmaHd128Kernel<Tindex>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8Pipe(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token, with cp.async pipelining.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined<Tindex, Tdata, 64, 8, 32, 2>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8PipeSplitKv(
float *partial_acc,
float *partial_m,
float *partial_l,
int num_splits,
size_t total_q_tokens,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride) {
// Encode (split_idx, m_block) into blockIdx.z to allow a single kernel launch:
// blockIdx.z in [0, num_splits * num_m_blocks).
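// Worked example (hypothetical numbers): total_q_tokens = 20 -> num_m_blocks = ceil(20 / 8) = 3;
// with num_splits = 4 the launch uses gridDim.z = 4 * 3 = 12, and e.g. blockIdx.z = 7
// decodes to split_idx = 7 / 3 = 2, m_block = 7 - 2 * 3 = 1.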
const int num_m_blocks = static_cast<int>((total_q_tokens + 8 - 1) / 8);
const int bz = static_cast<int>(blockIdx.z);
const int split_idx = bz / num_m_blocks;
const int m_block = bz - split_idx * num_m_blocks;
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv<Tindex, Tdata, 128, 8, 32, 2>(
partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8PipeSplitKv(
float *partial_acc,
float *partial_m,
float *partial_l,
int num_splits,
size_t total_q_tokens,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride) {
const int num_m_blocks = static_cast<int>((total_q_tokens + 8 - 1) / 8);
const int bz = static_cast<int>(blockIdx.z);
const int split_idx = bz / num_m_blocks;
const int m_block = bz - split_idx * num_m_blocks;
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv<Tindex, Tdata, 64, 8, 32, 2>(
partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
}
template <typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128SplitKvCombine(
Tdata *out,
const float *partial_acc,
const float *partial_m,
const float *partial_l,
int num_splits,
size_t total_q_tokens,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel<Tdata, 128>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
}
template <typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64SplitKvCombine(
Tdata *out,
const float *partial_acc,
const float *partial_m,
const float *partial_l,
int num_splits,
size_t total_q_tokens,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel<Tdata, 64>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 16 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 16, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 16 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 16, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata, typename Tcompute>
infiniStatus_t launch_prefill_ref(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
const dim3 grid(static_cast<uint32_t>(total_q_tokens), static_cast<uint32_t>(num_heads), 1);
const dim3 block(static_cast<uint32_t>(head_size), 1, 1);
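// i.e. one block per (query token, head) pair and one thread per head dimension;
// this assumes head_size (64 or 128 here) stays within the 1024-thread block limit.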
if (head_size == 64) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel<Tindex, Tdata, Tcompute, 64>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride, q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, num_seqs);
return INFINI_STATUS_SUCCESS;
}
if (head_size == 128) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel<Tindex, Tdata, Tcompute, 128>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride, q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, num_seqs);
return INFINI_STATUS_SUCCESS;
}
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
const dim3 block(32, 1, 1);
// Global-token launch:
// - dramatically reduces grid size vs the legacy (num_seqs * total_q_tokens) launch
// - matches PagedAttention varlen (cu_seqlens) mental model better
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(total_q_tokens),
1);
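// i.e. blockIdx.x indexes the head, blockIdx.y the global query token, and each
// (head, token) pair is handled by a single 32-thread warp.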
switch (head_size) {
case 64:
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel<Tindex, Tdata, 64>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, scale, max_num_blocks_per_seq,
page_block_size, block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel<Tindex, Tdata, 128>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, scale, max_num_blocks_per_seq,
page_block_size, block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
constexpr int kWarps = 4;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta8<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta8<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8pipe(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta8Pipe<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta8Pipe<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8mma(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
// Current WMMA kernel only supports fp16 + head_dim=128.
if constexpr (!std::is_same_v<Tdata, half>) {
return launch_prefill_warpcta8pipe<Tindex, Tdata>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, stream);
}
if (head_size != 128) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// Guardrail: the current WMMA-score kernel is correctness-first and can be extremely slow on long prompts.
// Allow power users to force it via INFINIOP_FLASH_PREFILL_MMA_FORCE=1.
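// The estimate below is an upper bound (max_num_blocks_per_seq reflects the longest sequence
// in the batch). E.g., with hypothetical page_block_size = 256 and max_num_blocks_per_seq = 32,
// seqlen_k_est = 8192 > 4096, so the fallback triggers unless the force flag is set.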
const char *force_env = std::getenv("INFINIOP_FLASH_PREFILL_MMA_FORCE");
const bool force_mma = (force_env != nullptr) && (std::strcmp(force_env, "1") == 0);
const size_t seqlen_k_est = max_num_blocks_per_seq * page_block_size;
if (!force_mma && seqlen_k_est > 4096) {
static bool warned = false;
if (!warned) {
std::fprintf(stderr,
"[infiniop][paged_attention_prefill] warpcta8mma is experimental and very slow for long seqlen_k (est=%zu). "
"Falling back to warpcta8pipe. Set INFINIOP_FLASH_PREFILL_MMA_FORCE=1 to override.\n",
seqlen_k_est);
warned = true;
}
return launch_prefill_warpcta8pipe<Tindex, Tdata>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, stream);
}
// WMMA requires SM70+. If not supported (or if we can't query), fall back to the pipelined SIMT kernel.
int device = 0;
hcDeviceProp_t prop{};
if (hcGetDevice(&device) == hcSuccess && hcGetDeviceProperties(&prop, device) == hcSuccess) {
if (prop.major < 7) {
return launch_prefill_warpcta8pipe<Tindex, Tdata>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, stream);
}
}
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(16))));
PagedAttentionPrefillHd128WarpCta8Mma<Tindex>
<<<grid, block, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8pipe_splitkv(
float *partial_acc,
float *partial_m,
float *partial_l,
int num_splits,
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
constexpr int kMaxSplits = 8;
if (num_splits < 1) {
num_splits = 1;
}
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const size_t num_m_blocks = ceilDiv(total_q_tokens, static_cast<size_t>(kWarps));
// Single kernel launch with split_idx encoded in grid.z:
// blockIdx.z in [0, num_splits * num_m_blocks).
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(num_m_blocks * static_cast<size_t>(num_splits)));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta8PipeSplitKv<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
partial_acc, partial_m, partial_l, num_splits, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
break;
case 128:
PagedAttentionPrefillHd128WarpCta8PipeSplitKv<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
partial_acc, partial_m, partial_l, num_splits, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
break;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// Combine: one warp per (token, head).
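// Sketch of the reduction it performs (the exact per-split convention lives in the combine
// kernel): with per-split (m_s, l_s, acc_s) and m_max = max_s m_s,
//   out = (sum_s exp(m_s - m_max) * acc_s) / (sum_s exp(m_s - m_max) * l_s).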
const dim3 block2(32);
const dim3 grid2(static_cast<uint32_t>(num_heads), static_cast<uint32_t>(total_q_tokens), 1);
switch (head_size) {
case 64:
PagedAttentionPrefillHd64SplitKvCombine<Tdata>
<<<grid2, block2, 0, stream>>>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128SplitKvCombine<Tdata>
<<<grid2, block2, 0, stream>>>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8n128(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
// Only meaningful for head_dim=128.
if (head_size != 128) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
PagedAttentionPrefillHd128WarpCta8N128<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
constexpr int kWarps = 16;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta16<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta16<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
} // namespace
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t total_kv_lens_desc,
infiniopTensorDescriptor_t cum_seqlens_q_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto info = PagedAttentionPrefillInfo::create(
out_desc, q_desc, k_cache_desc, v_cache_desc,
block_tables_desc, total_kv_lens_desc, cum_seqlens_q_desc,
alibi_slopes_desc, scale);
CHECK_RESULT(info);
// Optional split-kv prefill requires workspace for partial (m, l, acc).
// IMPORTANT: Unlike decode, prefill's total_q_tokens can be very large, so we must NOT reserve
// a huge workspace unless the user explicitly enables split-kv.
bool use_splitkv = false;
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
}
int num_splits = 1;
if (use_splitkv) {
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) {
const int v = std::atoi(env);
if (v > 0) {
num_splits = v;
}
} else {
num_splits = 4;
}
constexpr int kMaxSplits = 8;
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
}
const size_t n = info->total_q_tokens * info->num_heads;
const size_t splitkv_workspace_bytes = use_splitkv ? (static_cast<size_t>(num_splits) * n * (info->head_size + 2) * sizeof(float)) : 0;
const size_t workspace_bytes = splitkv_workspace_bytes;
// const size_t workspace_bytes = splitkv_workspace_bytes + fa2_workspace_bytes;
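// Sizing example (hypothetical numbers): total_q_tokens = 4096, num_heads = 32, head_size = 128,
// num_splits = 4 gives n = 131072 and 4 * n * (128 + 2) * 4 B ~= 260 MiB, which is why
// split-kv stays strictly opt-in for prefill.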
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
info.take(), workspace_bytes, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables,
const void *total_kv_lens,
const void *cum_seqlens_q,
const void *alibi_slopes,
void *stream_) const {
auto stream = static_cast<hcStream_t>(stream_);
const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast<const float *>(alibi_slopes);
const auto *total_kv_lens_i64 = static_cast<const int64_t *>(total_kv_lens);
const auto *cu_seqlens_q_i64 = static_cast<const int64_t *>(cum_seqlens_q);
bool use_splitkv = false;
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
}
int num_splits = 1;
if (use_splitkv) {
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) {
const int v = std::atoi(env);
if (v > 0) {
num_splits = v;
}
} else {
// Conservative default; users can override.
num_splits = 4;
}
constexpr int kMaxSplits = 8;
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
const size_t n = _info.total_q_tokens * _info.num_heads;
const size_t required = static_cast<size_t>(num_splits) * n * (_info.head_size + 2) * sizeof(float);
if (workspace_size < required) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
}
if (use_splitkv) {
const size_t n = _info.total_q_tokens * _info.num_heads;
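// Workspace layout (contiguous floats, matching the sizing in Descriptor::create):
//   [0, S*n*head_size)        partial_acc  (per-split un-normalized accumulators)
//   next S*n floats           partial_m    (per-split running maxima)
//   next S*n floats           partial_l    (per-split softmax denominators)
// where S = num_splits and n = total_q_tokens * num_heads.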
float *partial_acc = static_cast<float *>(workspace);
float *partial_m = partial_acc + static_cast<size_t>(num_splits) * n * _info.head_size;
float *partial_l = partial_m + static_cast<size_t>(num_splits) * n;
// Dispatch by (Tdata, Tindex). total_kv_lens + cu_seqlens_q are currently always int64.
#define DISPATCH_SPLITKV(Tindex, Tdata, BT_PTR) \
return launch_prefill_warpcta8pipe_splitkv<Tindex, Tdata>( \
partial_acc, partial_m, partial_l, num_splits, \
static_cast<Tdata *>(out), \
static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), \
static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(BT_PTR), \
total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream)
if (_info.dtype == INFINI_DTYPE_F16) {
if (_info.index_dtype == INFINI_DTYPE_I64) {
DISPATCH_SPLITKV(int64_t, half, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_I32) {
DISPATCH_SPLITKV(int32_t, half, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_U32) {
DISPATCH_SPLITKV(uint32_t, half, block_tables);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (_info.dtype == INFINI_DTYPE_BF16) {
if (_info.index_dtype == INFINI_DTYPE_I64) {
DISPATCH_SPLITKV(int64_t, __nv_bfloat16, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_I32) {
DISPATCH_SPLITKV(int32_t, __nv_bfloat16, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_U32) {
DISPATCH_SPLITKV(uint32_t, __nv_bfloat16, block_tables);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
#undef DISPATCH_SPLITKV
}
// Default to the fastest validated kernel for supported shapes.
// "ref" is still available for debugging/correctness bisecting.
#define DISPATCH_KERNEL(Tindex, Tdata, Tcompute) \
do { \
const char *k_env = std::getenv("INFINIOP_FLASH_PREFILL_KERNEL"); \
const char *k = (k_env == nullptr) ? default_prefill_kernel(_info) : k_env; \
if (k_env != nullptr) { \
const bool known = (std::strcmp(k, "warp") == 0) || (std::strcmp(k, "warpcta") == 0) || (std::strcmp(k, "warpcta8") == 0) || (std::strcmp(k, "warpcta8pipe") == 0) || (std::strcmp(k, "warpcta8mma") == 0) || (std::strcmp(k, "warpcta8n128") == 0) || (std::strcmp(k, "warpcta16") == 0) || (std::strcmp(k, "ref") == 0); \
if (!known) { \
const char *fallback = default_prefill_kernel(_info); \
std::fprintf(stderr, \
"[infiniop][paged_attention_prefill] WARNING: unknown kernel '%s', falling back to '%s'\n", \
k, fallback); \
k = fallback; \
} \
} \
const char *dbg = std::getenv("INFINIOP_DEBUG_PREFILL_DISPATCH"); \
static bool printed_dispatch = false; \
if (!printed_dispatch && dbg != nullptr && std::strcmp(dbg, "1") == 0) { \
std::fprintf(stderr, \
"[infiniop][paged_attention_prefill] kernel=%s (override=%s head_size=%zu block=%zu dtype=%zu)\n", \
k, \
(k_env == nullptr ? "auto" : "env"), \
static_cast<size_t>(_info.head_size), \
static_cast<size_t>(_info.page_block_size), \
static_cast<size_t>(_info.dtype)); \
printed_dispatch = true; \
} \
if (std::strcmp(k, "warp") == 0) { \
return launch_prefill_warp<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta") == 0) { \
return launch_prefill<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta8") == 0) { \
return launch_prefill_warpcta8<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta8pipe") == 0) { \
return launch_prefill_warpcta8pipe<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if constexpr (std::is_same_v<Tdata, half>) { \
if (std::strcmp(k, "warpcta8mma") == 0) { \
return launch_prefill_warpcta8mma<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
} \
if (std::strcmp(k, "warpcta8n128") == 0) { \
return launch_prefill_warpcta8n128<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta16") == 0) { \
return launch_prefill_warpcta16<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "ref") == 0) { \
return launch_prefill_ref<Tindex, Tdata, Tcompute>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
return INFINI_STATUS_BAD_PARAM; \
} while (false)
#define DISPATCH_INDEX(Tindex) \
do { \
if (_info.dtype == INFINI_DTYPE_F16) { \
DISPATCH_KERNEL(Tindex, half, float); \
} \
if (_info.dtype == INFINI_DTYPE_BF16) { \
DISPATCH_KERNEL(Tindex, __nv_bfloat16, float); \
} \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
} while (false)
if (_info.index_dtype == INFINI_DTYPE_I64) {
DISPATCH_INDEX(int64_t);
} else if (_info.index_dtype == INFINI_DTYPE_I32) {
DISPATCH_INDEX(int32_t);
} else if (_info.index_dtype == INFINI_DTYPE_U32) {
DISPATCH_INDEX(uint32_t);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::paged_attention_prefill::metax
#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
#define __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
namespace op::paged_attention_prefill::cuda {
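// Maps a global query-token index to its sequence id via binary search over the
// (num_seqs + 1)-entry prefix sum cum_seq_lens_q, which is assumed to start at 0.
// Example: cum_seq_lens_q = [0, 3, 7, 12], num_seqs = 3 -> token_idx 5 lies in [3, 7) -> seq 1.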
__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *cum_seq_lens_q, size_t num_seqs) {
size_t low = 0, high = num_seqs - 1;
while (low <= high) {
size_t mid = (low + high) >> 1;
if (token_idx >= (size_t)cum_seq_lens_q[mid] && token_idx < (size_t)cum_seq_lens_q[mid + 1]) {
return mid;
} else if (token_idx < (size_t)cum_seq_lens_q[mid]) {
high = mid - 1;
} else {
low = mid + 1;
}
}
return 0;
}
// Warp-level sum reduction with an explicit active mask (safe for partial warps).
__device__ __forceinline__ float warpReduceSum(float val, unsigned mask) {
for (int offset = 16; offset > 0; offset >>= 1) {
val += __shfl_down_sync(mask, val, offset);
}
return val;
}
// Block-level sum reduction. Returns the sum to all threads in the block.
// Supports blockDim.x up to 1024.
__device__ __forceinline__ float blockReduceSum(float val) {
__shared__ float shared[32]; // max 32 warps per block
const int lane = threadIdx.x & 31;
const int wid = threadIdx.x >> 5;
const unsigned mask = __activemask();
val = warpReduceSum(val, mask);
if (lane == 0) {
shared[wid] = val;
}
__syncthreads();
const int num_warps = (blockDim.x + 31) >> 5;
float sum = 0.0f;
if (wid == 0) {
sum = (lane < num_warps) ? shared[lane] : 0.0f;
const unsigned mask0 = (num_warps >= 32) ? 0xffffffffu : ((1u << num_warps) - 1u);
sum = warpReduceSum(sum, mask0);
if (lane == 0) {
shared[0] = sum;
}
}
__syncthreads();
return shared[0];
}
template <typename Tdata, typename Tcompute>
__global__ void pagedAttentionPrefillKernel(
Tdata *out_, const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_,
const int64_t *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cum_seq_lens_q_,
const float *alibi_slopes_,
const size_t num_heads, const size_t num_kv_heads, const float scale,
const size_t max_num_blocks_per_seq, const size_t block_size,
const ptrdiff_t kv_block_stride, const ptrdiff_t kv_head_stride,
const ptrdiff_t q_stride, const ptrdiff_t q_head_stride,
const size_t head_size,
const size_t num_seqs) {
// Grid : x -> token, y -> head
const size_t global_token_idx = blockIdx.x;
const size_t head_idx = blockIdx.y;
const size_t dim_idx = threadIdx.x;
if (dim_idx >= head_size) {
return;
}
__shared__ size_t sh_seq_idx;
__shared__ size_t sh_causal_limit;
__shared__ size_t sh_kv_head_idx;
__shared__ float sh_scale_acc;
__shared__ float sh_w;
__shared__ float sh_inv_l;
if (dim_idx == 0) {
sh_seq_idx = find_seq_id(global_token_idx, cum_seq_lens_q_, num_seqs);
const size_t q_token_idx = global_token_idx - static_cast<size_t>(cum_seq_lens_q_[sh_seq_idx]);
const size_t total_kv_len = static_cast<size_t>(total_kv_lens_[sh_seq_idx]);
const size_t q_len = static_cast<size_t>(cum_seq_lens_q_[sh_seq_idx + 1] - cum_seq_lens_q_[sh_seq_idx]);
const size_t history_len = total_kv_len - q_len;
sh_causal_limit = history_len + q_token_idx;
const size_t num_queries_per_kv = num_heads / num_kv_heads;
sh_kv_head_idx = head_idx / num_queries_per_kv;
}
__syncthreads();
const size_t seq_idx = sh_seq_idx;
const size_t causal_limit = sh_causal_limit;
const size_t kv_head_idx = sh_kv_head_idx;
const Tdata *q_vec = q_ + global_token_idx * q_stride + head_idx * q_head_stride;
Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size;
const int64_t *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;
const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
const float qv = static_cast<float>(q_vec[dim_idx]);
Tcompute acc = 0.0f;
float m = -FLT_MAX;
float l = 0.0f;
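// Online (streaming) softmax over the causal KV range, one key t at a time:
//   s_t   = scale * (q . k_t) [+ alibi]
//   m_new = max(m, s_t)
//   l     = l * exp(m - m_new) + exp(s_t - m_new)
//   acc   = acc * exp(m - m_new) + exp(s_t - m_new) * v_t
// so out = acc / l reproduces softmax(s) . V without materializing the full score row.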
for (size_t t = 0; t <= causal_limit; ++t) {
const size_t b_idx = t / block_size;
const size_t t_off = t % block_size;
const ptrdiff_t physical_block_id = block_table[b_idx];
const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
const float dot = blockReduceSum(qv * static_cast<float>(k_vec[dim_idx]));
if (dim_idx == 0) {
float score = dot * static_cast<float>(scale);
if (alibi_slope != 0.0f) {
// (causal_limit - t) is non-negative here; subtract to apply the ALiBi penalty
// without the size_t underflow that (t - causal_limit) would produce.
score -= alibi_slope * static_cast<float>(causal_limit - t);
}
const float m_new = fmaxf(m, score);
const float scale_acc = expf(m - m_new);
const float w = expf(score - m_new);
l = l * scale_acc + w;
m = m_new;
sh_scale_acc = scale_acc;
sh_w = w;
}
__syncthreads();
const float scale_acc = sh_scale_acc;
const float w = sh_w;
const Tdata *v_vec = v_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
acc = acc * static_cast<Tcompute>(scale_acc) + static_cast<Tcompute>(w) * static_cast<Tcompute>(v_vec[dim_idx]);
__syncthreads();
}
if (dim_idx == 0) {
sh_inv_l = 1.0f / (l + 1e-6f);
}
__syncthreads();
out_ptr[dim_idx] = static_cast<Tdata>(acc * static_cast<Tcompute>(sh_inv_l));
}
} // namespace op::paged_attention_prefill::cuda
#endif
#ifndef __PAGED_ATTENTION_PREFILL_MOORE_H__
#define __PAGED_ATTENTION_PREFILL_MOORE_H__
#include "../paged_attention_prefill.h"
DESCRIPTOR(moore)
#endif // __PAGED_ATTENTION_PREFILL_MOORE_H__
#include <musa_fp16.h>
#include <float.h>
#include <math.h>
#include <stdint.h>
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_kernel_common.h"
#include "paged_attention_prefill_kernel.h"
#include "paged_attention_prefill_moore.h"
template <typename Tdata, typename Tcompute>
infiniStatus_t launchPagedAttentionPrefill(
Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
const int64_t *block_tables,
const int64_t *seq_lens,
const int64_t *cum_seq_lens_q,
const float *alibi_slopes,
const size_t num_heads,
const size_t num_seqs,
const size_t num_kv_heads,
const float scale,
const size_t max_num_blocks_per_seq,
const size_t page_block_size,
const size_t total_q_tokens,
const size_t head_size,
const ptrdiff_t k_batch_stride,
const ptrdiff_t k_head_stride,
const ptrdiff_t q_stride,
const ptrdiff_t q_head_stride,
musaStream_t stream) {
if (total_q_tokens == 0 || num_heads == 0) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
dim3 grid(total_q_tokens, num_heads);
dim3 block(head_size);
op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tdata, Tcompute>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache,
block_tables, seq_lens, cum_seq_lens_q, alibi_slopes,
num_heads, num_kv_heads, scale,
max_num_blocks_per_seq, page_block_size,
k_batch_stride, k_head_stride,
q_stride, q_head_stride,
head_size,
num_seqs);
return INFINI_STATUS_SUCCESS;
}
namespace op::paged_attention_prefill::moore {
struct Descriptor::Opaque {
std::shared_ptr<device::moore::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t cum_seq_lens_q_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto info = PagedAttentionPrefillInfo::create(
out_desc, q_desc, k_cache_desc, v_cache_desc,
block_tables_desc, seq_lens_desc,
cum_seq_lens_q_desc,
alibi_slopes_desc, scale);
CHECK_RESULT(info);
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables,
const void *seq_lens,
const void *cum_seq_lens_q,
const void *alibi_slopes,
void *stream_) const {
musaStream_t stream = (musaStream_t)stream_;
#define LAUNCH_KERNEL(Tdata, Tcompute) \
launchPagedAttentionPrefill<Tdata, Tcompute>( \
(Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q, \
(const float *)alibi_slopes, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, \
_info.scale, _info.max_num_blocks_per_seq, \
_info.page_block_size, _info.total_q_tokens, \
_info.head_size, \
_info.k_batch_stride, _info.k_head_stride, \
_info.q_stride, _info.q_head_stride, \
stream)
if (_info.dtype == INFINI_DTYPE_F16) {
return LAUNCH_KERNEL(half, float);
} else if (_info.dtype == INFINI_DTYPE_BF16) {
return LAUNCH_KERNEL(__mt_bfloat16, float);
} else if (_info.dtype == INFINI_DTYPE_F32) {
return LAUNCH_KERNEL(float, float);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::paged_attention_prefill::moore
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
// #include "paged_attention_prefill_fa2.cuh"
#include "paged_attention_prefill_nvidia.cuh"
#include "../cuda/kernel_v2.cuh"
namespace op::paged_attention_prefill::nvidia {
namespace {
constexpr size_t ceilDiv(size_t a, size_t b) {
return (a + b - 1) / b;
}
inline const char *default_prefill_kernel(const PagedAttentionPrefillInfo &info) {
// Heuristic auto-dispatch (v0.4):
// - Prefer the pipelined + tile-wise softmax kernel on FA2-compatible block_size=256.
// - Keep a conservative fallback for other shapes / older GPUs (cp.async is a no-op below SM80).
//
// Users can always override via INFINIOP_FLASH_PREFILL_KERNEL.
if (info.page_block_size == 256 && (info.dtype == INFINI_DTYPE_F16 || info.dtype == INFINI_DTYPE_BF16)) {
if (info.head_size == 128) {
return "warpcta8pipe";
}
// For head_size=64 we keep the previous default until we have broader perf coverage.
}
return "warpcta8";
}
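// Example overrides (assumption: shell environment; the binary name is hypothetical):
//   INFINIOP_FLASH_PREFILL_KERNEL=ref ./bench          # force the reference kernel
//   INFINIOP_DEBUG_PREFILL_DISPATCH=1 ./bench          # log the chosen kernel once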
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128Warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel).
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel<Tindex, Tdata, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64Warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel).
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel<Tindex, Tdata, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 4 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 4, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 4 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 4, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 8, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8N128(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token, tile_n=128 for fewer K stages.
// Note: we keep K in shared memory but load V from global to stay within the per-block shared limit.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelKOnly<Tindex, Tdata, 128, 8, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 8, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8Pipe(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token, with cp.async pipelining.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined<Tindex, Tdata, 128, 8, 32, 2>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8Mma(
half *out,
const half *q,
const half *k_cache,
const half *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCta8MmaHd128Kernel<Tindex>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8Pipe(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token, with cp.async pipelining.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined<Tindex, Tdata, 64, 8, 32, 2>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8PipeSplitKv(
float *partial_acc,
float *partial_m,
float *partial_l,
int num_splits,
size_t total_q_tokens,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride) {
// Encode (split_idx, m_block) into blockIdx.z to allow a single kernel launch:
// blockIdx.z in [0, num_splits * num_m_blocks).
const int num_m_blocks = static_cast<int>((total_q_tokens + 8 - 1) / 8);
const int bz = static_cast<int>(blockIdx.z);
const int split_idx = bz / num_m_blocks;
const int m_block = bz - split_idx * num_m_blocks;
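// Worked example (hypothetical numbers): num_splits = 4, num_m_blocks = 3,
// blockIdx.z = 7 -> split_idx = 7 / 3 = 2, m_block = 7 - 2 * 3 = 1.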
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv<Tindex, Tdata, 128, 8, 32, 2>(
partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8PipeSplitKv(
float *partial_acc,
float *partial_m,
float *partial_l,
int num_splits,
size_t total_q_tokens,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride) {
const int num_m_blocks = static_cast<int>((total_q_tokens + 8 - 1) / 8);
const int bz = static_cast<int>(blockIdx.z);
const int split_idx = bz / num_m_blocks;
const int m_block = bz - split_idx * num_m_blocks;
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv<Tindex, Tdata, 64, 8, 32, 2>(
partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
}
template <typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128SplitKvCombine(
Tdata *out,
const float *partial_acc,
const float *partial_m,
const float *partial_l,
int num_splits,
size_t total_q_tokens,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel<Tdata, 128>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
}
template <typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64SplitKvCombine(
Tdata *out,
const float *partial_acc,
const float *partial_m,
const float *partial_l,
int num_splits,
size_t total_q_tokens,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel<Tdata, 64>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 16 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 16, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 16 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 16, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata, typename Tcompute>
infiniStatus_t launch_prefill_ref(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
const dim3 grid(static_cast<uint32_t>(total_q_tokens), static_cast<uint32_t>(num_heads), 1);
const dim3 block(static_cast<uint32_t>(head_size), 1, 1);
if (head_size == 64) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel<Tindex, Tdata, Tcompute, 64>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride, q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, num_seqs);
return INFINI_STATUS_SUCCESS;
}
if (head_size == 128) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel<Tindex, Tdata, Tcompute, 128>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride, q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, num_seqs);
return INFINI_STATUS_SUCCESS;
}
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
const dim3 block(32, 1, 1);
// Global-token launch:
// - dramatically reduces grid size vs the legacy (num_seqs * total_q_tokens) launch
// - better matches the PagedAttention varlen (cu_seqlens) mental model
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(total_q_tokens),
1);
switch (head_size) {
case 64:
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel<Tindex, Tdata, 64>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, scale, max_num_blocks_per_seq,
page_block_size, block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel<Tindex, Tdata, 128>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, scale, max_num_blocks_per_seq,
page_block_size, block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
constexpr int kWarps = 4;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
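// Grid-shape sketch with illustrative numbers: num_heads = 32, num_seqs = 2,
// total_q_tokens = 1000, kWarps = 4 -> grid = (32, 2, 250) CTAs of 128 threads each.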
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta8<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta8<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8pipe(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta8Pipe<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta8Pipe<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8mma(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
// Current WMMA kernel only supports fp16 + head_dim=128.
if constexpr (!std::is_same_v<Tdata, half>) {
return launch_prefill_warpcta8pipe<Tindex, Tdata>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, stream);
}
if (head_size != 128) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// Guardrail: the current WMMA-score kernel is correctness-first and can be extremely slow on long prompts.
// Allow power users to force it via INFINIOP_FLASH_PREFILL_MMA_FORCE=1.
const char *force_env = std::getenv("INFINIOP_FLASH_PREFILL_MMA_FORCE");
const bool force_mma = (force_env != nullptr) && (std::strcmp(force_env, "1") == 0);
const size_t seqlen_k_est = max_num_blocks_per_seq * page_block_size;
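// e.g. max_num_blocks_per_seq = 32, page_block_size = 256 -> seqlen_k_est = 8192,
// which trips the 4096-token guardrail below (illustrative numbers).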
if (!force_mma && seqlen_k_est > 4096) {
static bool warned = false;
if (!warned) {
std::fprintf(stderr,
"[infiniop][paged_attention_prefill] warpcta8mma is experimental and very slow for long seqlen_k (est=%zu). "
"Falling back to warpcta8pipe. Set INFINIOP_FLASH_PREFILL_MMA_FORCE=1 to override.\n",
seqlen_k_est);
warned = true;
}
return launch_prefill_warpcta8pipe<Tindex, Tdata>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, stream);
}
// WMMA requires SM70+. If not supported (or if we can't query), fall back to the pipelined SIMT kernel.
int device = 0;
cudaDeviceProp prop{};
if (cudaGetDevice(&device) == cudaSuccess && cudaGetDeviceProperties(&prop, device) == cudaSuccess) {
if (prop.major < 7) {
return launch_prefill_warpcta8pipe<Tindex, Tdata>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, stream);
}
}
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(16))));
PagedAttentionPrefillHd128WarpCta8Mma<Tindex>
<<<grid, block, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8pipe_splitkv(
float *partial_acc,
float *partial_m,
float *partial_l,
int num_splits,
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
constexpr int kMaxSplits = 8;
if (num_splits < 1) {
num_splits = 1;
}
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const size_t num_m_blocks = ceilDiv(total_q_tokens, static_cast<size_t>(kWarps));
// Single kernel launch with split_idx encoded in grid.z:
// blockIdx.z in [0, num_splits * num_m_blocks).
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(num_m_blocks * static_cast<size_t>(num_splits)));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta8PipeSplitKv<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
partial_acc, partial_m, partial_l, num_splits, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
break;
case 128:
PagedAttentionPrefillHd128WarpCta8PipeSplitKv<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
partial_acc, partial_m, partial_l, num_splits, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
break;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// Combine: one warp per (token, head).
const dim3 block2(32);
const dim3 grid2(static_cast<uint32_t>(num_heads), static_cast<uint32_t>(total_q_tokens), 1);
switch (head_size) {
case 64:
PagedAttentionPrefillHd64SplitKvCombine<Tdata>
<<<grid2, block2, 0, stream>>>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128SplitKvCombine<Tdata>
<<<grid2, block2, 0, stream>>>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8n128(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
// Only meaningful for head_dim=128.
if (head_size != 128) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
PagedAttentionPrefillHd128WarpCta8N128<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
constexpr int kWarps = 16;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta16<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta16<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
} // namespace
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t total_kv_lens_desc,
infiniopTensorDescriptor_t cum_seqlens_q_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto info = PagedAttentionPrefillInfo::create(
out_desc, q_desc, k_cache_desc, v_cache_desc,
block_tables_desc, total_kv_lens_desc, cum_seqlens_q_desc,
alibi_slopes_desc, scale);
CHECK_RESULT(info);
// Optional split-kv prefill requires workspace for partial (m, l, acc).
// IMPORTANT: Unlike decode, prefill's total_q_tokens can be very large, so we must NOT reserve
// a huge workspace unless the user explicitly enables split-kv.
bool use_splitkv = false;
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
}
int num_splits = 1;
if (use_splitkv) {
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) {
const int v = std::atoi(env);
if (v > 0) {
num_splits = v;
}
} else {
num_splits = 4;
}
constexpr int kMaxSplits = 8;
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
}
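// Example (assumption: shell environment; binary name hypothetical): enable split-kv with
//   INFINIOP_FLASH_PREFILL_SPLITKV=1 INFINIOP_FLASH_PREFILL_NUM_SPLITS=4 ./app
// With SPLITKV unset, workspace_bytes stays 0 and the single-pass kernels are used.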
const size_t n = info->total_q_tokens * info->num_heads;
const size_t splitkv_workspace_bytes = use_splitkv ? (static_cast<size_t>(num_splits) * n * (info->head_size + 2) * sizeof(float)) : 0;
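// Workspace layout (floats), matching the pointer arithmetic in calculate():
//   [partial_acc : num_splits*n*head_size][partial_m : num_splits*n][partial_l : num_splits*n]
// Illustrative size: num_splits = 4, total_q_tokens = 1024, num_heads = 32, head_size = 128
//   -> 4 * 32768 * 130 * 4 B = 65 MiB.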
const size_t workspace_bytes = splitkv_workspace_bytes;
// const size_t workspace_bytes = splitkv_workspace_bytes + fa2_workspace_bytes;
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
info.take(), workspace_bytes, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables,
const void *total_kv_lens,
const void *cum_seqlens_q,
const void *alibi_slopes,
void *stream_) const {
auto stream = static_cast<cudaStream_t>(stream_);
const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast<const float *>(alibi_slopes);
const auto *total_kv_lens_i64 = static_cast<const int64_t *>(total_kv_lens);
const auto *cu_seqlens_q_i64 = static_cast<const int64_t *>(cum_seqlens_q);
bool use_splitkv = false;
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
}
int num_splits = 1;
if (use_splitkv) {
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) {
const int v = std::atoi(env);
if (v > 0) {
num_splits = v;
}
} else {
// Conservative default; users can override.
num_splits = 4;
}
constexpr int kMaxSplits = 8;
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
const size_t n = _info.total_q_tokens * _info.num_heads;
const size_t required = static_cast<size_t>(num_splits) * n * (_info.head_size + 2) * sizeof(float);
if (workspace_size < required) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
}
if (use_splitkv) {
const size_t n = _info.total_q_tokens * _info.num_heads;
float *partial_acc = static_cast<float *>(workspace);
float *partial_m = partial_acc + static_cast<size_t>(num_splits) * n * _info.head_size;
float *partial_l = partial_m + static_cast<size_t>(num_splits) * n;
// Dispatch by (Tdata, Tindex). total_kv_lens + cu_seqlens_q are currently always int64.
#define DISPATCH_SPLITKV(Tindex, Tdata, BT_PTR) \
return launch_prefill_warpcta8pipe_splitkv<Tindex, Tdata>( \
partial_acc, partial_m, partial_l, num_splits, \
static_cast<Tdata *>(out), \
static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), \
static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(BT_PTR), \
total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream)
if (_info.dtype == INFINI_DTYPE_F16) {
if (_info.index_dtype == INFINI_DTYPE_I64) {
DISPATCH_SPLITKV(int64_t, half, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_I32) {
DISPATCH_SPLITKV(int32_t, half, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_U32) {
DISPATCH_SPLITKV(uint32_t, half, block_tables);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (_info.dtype == INFINI_DTYPE_BF16) {
if (_info.index_dtype == INFINI_DTYPE_I64) {
DISPATCH_SPLITKV(int64_t, __nv_bfloat16, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_I32) {
DISPATCH_SPLITKV(int32_t, __nv_bfloat16, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_U32) {
DISPATCH_SPLITKV(uint32_t, __nv_bfloat16, block_tables);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
#undef DISPATCH_SPLITKV
}
// Default to the fastest validated kernel for supported shapes.
// "ref" is still available for debugging/correctness bisecting.
#define DISPATCH_KERNEL(Tindex, Tdata, Tcompute) \
do { \
const char *k_env = std::getenv("INFINIOP_FLASH_PREFILL_KERNEL"); \
const char *k = (k_env == nullptr) ? default_prefill_kernel(_info) : k_env; \
if (k_env != nullptr) { \
const bool known = (std::strcmp(k, "warp") == 0) || (std::strcmp(k, "warpcta") == 0) \
    || (std::strcmp(k, "warpcta8") == 0) || (std::strcmp(k, "warpcta8pipe") == 0) \
    || (std::strcmp(k, "warpcta8mma") == 0) || (std::strcmp(k, "warpcta8n128") == 0) \
    || (std::strcmp(k, "warpcta16") == 0) || (std::strcmp(k, "ref") == 0); \
if (!known) { \
const char *fallback = default_prefill_kernel(_info); \
std::fprintf(stderr, \
"[infiniop][paged_attention_prefill] WARNING: unknown kernel '%s', falling back to '%s'\n", \
k, fallback); \
k = fallback; \
} \
} \
const char *dbg = std::getenv("INFINIOP_DEBUG_PREFILL_DISPATCH"); \
static bool printed_dispatch = false; \
if (!printed_dispatch && dbg != nullptr && std::strcmp(dbg, "1") == 0) { \
std::fprintf(stderr, \
"[infiniop][paged_attention_prefill] kernel=%s (override=%s head_size=%zu block=%zu dtype=%zu)\n", \
k, \
(k_env == nullptr ? "auto" : "env"), \
static_cast<size_t>(_info.head_size), \
static_cast<size_t>(_info.page_block_size), \
static_cast<size_t>(_info.dtype)); \
printed_dispatch = true; \
} \
if (std::strcmp(k, "warp") == 0) { \
return launch_prefill_warp<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta") == 0) { \
return launch_prefill<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta8") == 0) { \
return launch_prefill_warpcta8<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta8pipe") == 0) { \
return launch_prefill_warpcta8pipe<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if constexpr (std::is_same_v<Tdata, half>) { \
if (std::strcmp(k, "warpcta8mma") == 0) { \
return launch_prefill_warpcta8mma<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
} \
if (std::strcmp(k, "warpcta8n128") == 0) { \
return launch_prefill_warpcta8n128<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta16") == 0) { \
return launch_prefill_warpcta16<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "ref") == 0) { \
return launch_prefill_ref<Tindex, Tdata, Tcompute>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
return INFINI_STATUS_BAD_PARAM; \
} while (false)
#define DISPATCH_INDEX(Tindex) \
do { \
if (_info.dtype == INFINI_DTYPE_F16) { \
DISPATCH_KERNEL(Tindex, half, float); \
} \
if (_info.dtype == INFINI_DTYPE_BF16) { \
DISPATCH_KERNEL(Tindex, __nv_bfloat16, float); \
} \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
} while (false)
if (_info.index_dtype == INFINI_DTYPE_I64) {
DISPATCH_INDEX(int64_t);
} else if (_info.index_dtype == INFINI_DTYPE_I32) {
DISPATCH_INDEX(int32_t);
} else if (_info.index_dtype == INFINI_DTYPE_U32) {
DISPATCH_INDEX(uint32_t);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::paged_attention_prefill::nvidia
#ifndef __PAGED_ATTENTION_PREFILL_NVIDIA_H__
#define __PAGED_ATTENTION_PREFILL_NVIDIA_H__
#include "../paged_attention_prefill.h"
DESCRIPTOR(nvidia)
#endif // __PAGED_ATTENTION_PREFILL_NVIDIA_H__
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/paged_attention_prefill.h"
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
#include "nvidia/paged_attention_prefill_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/paged_attention_prefill_metax.h"
#endif
#ifdef ENABLE_MOORE_API
#include "moore/paged_attention_prefill_moore.h"
#endif
__C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
infiniopHandle_t handle,
infiniopPagedAttentionPrefillDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t cum_seq_lens_q_desc,
infiniopTensorDescriptor_t alibi_slopes_desc,
float scale) {
    std::optional<infiniopTensorDescriptor_t> alibi_opt = std::nullopt;
    if (alibi_slopes_desc != nullptr) {
        alibi_opt = alibi_slopes_desc;
    }
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::paged_attention_prefill::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor **>(desc_ptr), \
out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, \
seq_lens_desc, cum_seq_lens_q_desc, alibi_opt, scale);
switch (handle->device) {
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
}
__C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
infiniopPagedAttentionPrefillDescriptor_t desc,
size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
}
__C infiniStatus_t infiniopPagedAttentionPrefill(
infiniopPagedAttentionPrefillDescriptor_t desc,
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables,
const void *seq_lens,
const void *cum_seq_lens_q,
const void *alibi_slopes,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size, out, q, k_cache, v_cache, block_tables, \
seq_lens, cum_seq_lens_q, alibi_slopes, stream);
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
}
__C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
infiniopPagedAttentionPrefillDescriptor_t desc) {
#define DESTROY(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
DESTROY(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
DESTROY(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
}
#ifndef PAGED_ATTENTION_PREFILL_H
#define PAGED_ATTENTION_PREFILL_H
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::paged_attention_prefill::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
PagedAttentionPrefillInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
PagedAttentionPrefillInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t out_desc, \
infiniopTensorDescriptor_t q_desc, \
infiniopTensorDescriptor_t k_cache_desc, \
infiniopTensorDescriptor_t v_cache_desc, \
infiniopTensorDescriptor_t block_tables_desc, \
infiniopTensorDescriptor_t seq_lens_desc, \
infiniopTensorDescriptor_t cum_seq_lens_q_desc, \
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc, \
float scale); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *out, const void *q, const void *k_cache, const void *v_cache, \
const void *block_tables, \
const void *seq_lens, \
const void *cum_seq_lens_q, \
const void *alibi_slopes, \
void *stream) const; \
}; \
}
#endif // PAGED_ATTENTION_PREFILL_H
#ifndef __PAGED_CACHING_KERNEL_CUH__
#define __PAGED_CACHING_KERNEL_CUH__
//================================================================================
// Paged Caching Operator CUDA Kernel
//
// This kernel implements the "paged_caching" operation, which copies Key and Value
// vectors from a contiguous source tensor into a paged, non-contiguous KV Cache.
//
// Design Principles:
// 1. Token-Centric Parallelism: A 2D grid of (num_kv_heads, num_tokens) blocks is
//    launched; each CUDA block is responsible for caching one KV head of one token.
// 2. Coalesced Memory Access: This grid strategy ensures that threads within a
// block read a large, contiguous chunk of memory from the source tensors,
// maximizing memory bandwidth utilization.
// 3. Vectorization: The copy can additionally be vectorized to process multiple data
//    elements per instruction; the current implementation uses a safe element-wise loop.
//================================================================================
namespace op::paged_caching::cuda {
template <
typename Tdata, // Data type of the tensors (e.g., half, __nv_bfloat16)
int NUM_THREADS // Number of threads per block, configured at launch time
>
__device__ void pagedCachingKernel(
// ----- Output Tensors -----
Tdata *k_cache_ptr, // Pointer to the destination K cache pool [num_blocks, nkvh, block_size, dh]
Tdata *v_cache_ptr, // Pointer to the destination V cache pool [num_blocks, nkvh, block_size, dh]
// ----- Input Tensors -----
const Tdata *k_ptr, // Pointer to the source Keys, shape [ntok, nkvh, dh]
const Tdata *v_ptr, // Pointer to the source Values, shape [ntok, nkvh, dh]
const int64_t *slot_mapping_ptr, // Pointer to the slot mapping, shape [ntok]
// ----- Metadata -----
const size_t head_size, // Dimension of each head (dh)
const size_t block_size, // Number of tokens per block in the KV cache
// ----- Stride Information -----
const ptrdiff_t k_src_stride, // Stride between tokens in the source K tensor
const ptrdiff_t v_src_stride, // Stride between tokens in the source V tensor
const ptrdiff_t k_cache_block_stride, // Stride between blocks in the K cache pool
const ptrdiff_t v_cache_block_stride // Stride between blocks in the V cache pool
) {
//================================================================================
// 1. Identify Work Unit & Calculate Addresses
//================================================================================
    // Each block processes one KV head of one token.
    const int token_idx = blockIdx.y;
    const int head_idx = blockIdx.x;
    // const int num_kv_heads = gridDim.x;
// Retrieve the destination slot for the current token.
const int64_t slot_idx = slot_mapping_ptr[token_idx];
// Handle padding: if slot_idx is negative, this token is padding and should be ignored.
if (slot_idx < 0) {
return;
}
// Calculate the physical block index and the offset within that block.
const int64_t physical_block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
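    // Worked example (illustrative values): with block_size = 16 and slot_idx = 37,
    // physical_block_idx = 37 / 16 = 2 and block_offset = 37 % 16 = 5, i.e. this token
    // occupies row 5 of physical block 2 in the paged cache.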
// Calculate base pointers for source and destination for this specific token.
const Tdata *k_src_head_ptr = k_ptr + token_idx * k_src_stride + head_idx * head_size;
const Tdata *v_src_head_ptr = v_ptr + token_idx * v_src_stride + head_idx * head_size;
    // Destination pointer calculation assumes a [num_blocks, num_heads, block_size, head_size] layout.
// We point to the beginning of the memory region for this token's slot.
const ptrdiff_t cache_head_stride = block_size * head_size;
Tdata *k_cache_block_base_ptr = k_cache_ptr + physical_block_idx * k_cache_block_stride;
Tdata *k_dst_head_ptr = k_cache_block_base_ptr + head_idx * cache_head_stride + block_offset * head_size;
Tdata *v_cache_block_base_ptr = v_cache_ptr + physical_block_idx * v_cache_block_stride;
Tdata *v_dst_head_ptr = v_cache_block_base_ptr + head_idx * cache_head_stride + block_offset * head_size;
//================================================================================
// 2. Perform Element-wise Data Copy (Safe, Non-Vectorized)
//================================================================================
for (int i = threadIdx.x; i < head_size; i += NUM_THREADS) {
k_dst_head_ptr[i] = k_src_head_ptr[i];
v_dst_head_ptr[i] = v_src_head_ptr[i];
}
}
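// Hedged sketch (not wired into the operator): design principle 3 can be realized by
// widening the copy to 16-byte float4 transactions when head_size * sizeof(Tdata) is a
// multiple of 16 and the per-head pointers are 16-byte aligned. Names here are illustrative.
template <typename Tdata, int NUM_THREADS>
__device__ void pagedCachingCopyVec4(Tdata *dst, const Tdata *src, size_t head_size) {
    // Reinterpret the per-head rows as float4 vectors and copy one vector per iteration.
    const int num_vecs = static_cast<int>(head_size * sizeof(Tdata) / sizeof(float4));
    const float4 *src_v = reinterpret_cast<const float4 *>(src);
    float4 *dst_v = reinterpret_cast<float4 *>(dst);
    for (int i = threadIdx.x; i < num_vecs; i += NUM_THREADS) {
        dst_v[i] = src_v[i];
    }
}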
} // namespace op::paged_caching::cuda
#endif // __PAGED_CACHING_KERNEL_CUH__
#ifndef __PAGED_CACHING_INFO_H__
#define __PAGED_CACHING_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <optional>
#include <vector>
namespace op::paged_caching {
class PagedCachingInfo {
PagedCachingInfo() = default;
public:
// --- Data Type ---
infiniDtype_t dtype;
// --- Shape Dimensions ---
size_t num_tokens;
size_t num_kv_heads;
size_t head_size;
size_t block_size;
// --- Strides for Memory Layout ---
ptrdiff_t k_src_stride;
ptrdiff_t v_src_stride;
ptrdiff_t k_cache_block_stride;
ptrdiff_t v_cache_block_stride;
static utils::Result<PagedCachingInfo> create(
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t slot_mapping_desc) {
auto dtype = k_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
if (v_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (slot_mapping_desc->dtype() != INFINI_DTYPE_I64) {
printf("slot_mapping must be int64_t.\n");
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (k_desc->ndim() != 3 || v_desc->ndim() != 3 || k_cache_desc->ndim() < 4 || v_cache_desc->ndim() < 4 || slot_mapping_desc->ndim() != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// PagedCachingInfo info;
// --- Extract shape dimensions ---
auto k_shape = k_desc->shape();
auto k_cache_shape = k_cache_desc->shape();
size_t num_tokens = slot_mapping_desc->shape()[0];
size_t num_kv_heads = k_shape[1];
size_t head_size = k_shape[2];
size_t block_size = k_cache_shape[2]; // Assuming [num_blocks, num_heads, block_size, head_size] layout
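        // e.g. (illustrative): k of shape [ntok, nkvh, dh] = [32, 8, 128] with a k_cache of
        // shape [num_blocks, 8, 16, 128] gives num_kv_heads = 8, head_size = 128, block_size = 16.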
// --- Extract strides for memory access ---
ptrdiff_t k_src_stride = k_desc->stride(0);
ptrdiff_t v_src_stride = v_desc->stride(0);
ptrdiff_t k_cache_block_stride = k_cache_desc->stride(0);
ptrdiff_t v_cache_block_stride = v_cache_desc->stride(0);
return utils::Result<PagedCachingInfo>(PagedCachingInfo{
dtype,
num_tokens,
num_kv_heads,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride});
}
};
} // namespace op::paged_caching
#endif // __PAGED_CACHING_INFO_H__
#ifndef __PAGED_CACHING_METAX_H__
#define __PAGED_CACHING_METAX_H__
#include "../paged_caching.h"
DESCRIPTOR(metax)
#endif // __PAGED_CACHING_METAX_H__
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "paged_caching_metax.h"
template <typename Tdata, int NUM_THREADS>
INFINIOP_METAX_KERNEL pagedCaching(
Tdata *k_cache, Tdata *v_cache,
const Tdata *k, const Tdata *v,
const int64_t *slot_mapping,
const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
}
namespace op::paged_caching::metax {
// PIMPL struct definition
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
// Destructor implementation
Descriptor::~Descriptor() {
delete _opaque;
}
// Static factory method implementation
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t slot_mapping_desc) {
auto info = PagedCachingInfo::create(k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc);
CHECK_RESULT(info);
// Create and return the Descriptor instance.
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// The launchKernel function is a templated helper to encapsulate the kernel launch.
// It sets up grid/block dimensions and calls the device-side kernel.
template <int NUM_THREADS>
infiniStatus_t launchKernel(const PagedCachingInfo &info,
void *k_cache, void *v_cache,
infiniDtype_t dtype,
const void *k, const void *v,
const void *slot_mapping,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
hcStream_t stream) {
    // Grid is 2D: one block per (kv head, token) pair.
    dim3 grid(uint64_t(num_kv_heads), uint64_t(num_tokens), 1);
// Block dimension is 1D, using the number of threads specified at compile time.
dim3 block(NUM_THREADS);
// This kernel does not require dynamic shared memory.
size_t shared_mem_size = 0;
// Launch the device-side kernel.
if (dtype == INFINI_DTYPE_F16) {
pagedCaching<half, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(half *)k_cache,
(half *)v_cache,
(const half *)k,
(const half *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<cuda_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(cuda_bfloat16 *)k_cache,
(cuda_bfloat16 *)v_cache,
(const cuda_bfloat16 *)k,
(const cuda_bfloat16 *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(float *)k_cache,
(float *)v_cache,
(const float *)k,
(const float *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
// Execution method implementation
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *k_cache, void *v_cache,
const void *k, const void *v,
const void *slot_mapping,
void *stream_) const {
hcStream_t stream = (hcStream_t)stream_;
// Dispatch logic based on the device's maximum threads per block.
// This allows selecting the largest, most efficient block size the hardware supports.
int max_threads = _opaque->internal->maxThreadsPerBlock();
if (max_threads >= METAX_BLOCK_SIZE_1024) {
// Dispatch based on data type for a 1024-thread block.
launchKernel<METAX_BLOCK_SIZE_1024>(
_info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
stream);
} else if (max_threads >= METAX_BLOCK_SIZE_512) {
launchKernel<METAX_BLOCK_SIZE_512>(
_info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
stream);
} else {
// If the device supports fewer threads, return an error.
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::paged_caching::metax
#ifndef __PAGED_CACHING_MOORE_H__
#define __PAGED_CACHING_MOORE_H__
#include "../paged_caching.h"
DESCRIPTOR(moore)
#endif // __PAGED_CACHING_MOORE_H__
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "paged_caching_moore.h"
template <typename Tdata, int NUM_THREADS>
INFINIOP_MOORE_KERNEL pagedCaching(
Tdata *k_cache, Tdata *v_cache,
const Tdata *k, const Tdata *v,
const int64_t *slot_mapping,
const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
}
namespace op::paged_caching::moore {
// PIMPL struct definition
struct Descriptor::Opaque {
std::shared_ptr<device::moore::Handle::Internal> internal;
};
// Destructor implementation
Descriptor::~Descriptor() {
delete _opaque;
}
// Static factory method implementation
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t slot_mapping_desc) {
auto info = PagedCachingInfo::create(k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc);
CHECK_RESULT(info);
// Create and return the Descriptor instance.
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// The launchKernel function is a templated helper to encapsulate the MUSA kernel launch.
// It sets up grid/block dimensions and calls the device-side kernel.
template <int NUM_THREADS>
infiniStatus_t launchKernel(const PagedCachingInfo &info,
void *k_cache, void *v_cache,
infiniDtype_t dtype,
const void *k, const void *v,
const void *slot_mapping,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
musaStream_t stream) {
    // Grid is 2D: one block per (kv head, token) pair.
    dim3 grid(uint64_t(num_kv_heads), uint64_t(num_tokens), 1);
// Block dimension is 1D, using the number of threads specified at compile time.
dim3 block(NUM_THREADS);
// This kernel does not require dynamic shared memory.
size_t shared_mem_size = 0;
// Launch the device-side MUSA kernel.
if (dtype == INFINI_DTYPE_F16) {
pagedCaching<half, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(half *)k_cache,
(half *)v_cache,
(const half *)k,
(const half *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<__mt_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(__mt_bfloat16 *)k_cache,
(__mt_bfloat16 *)v_cache,
(const __mt_bfloat16 *)k,
(const __mt_bfloat16 *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(float *)k_cache,
(float *)v_cache,
(const float *)k,
(const float *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
// Execution method implementation
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *k_cache, void *v_cache,
const void *k, const void *v,
const void *slot_mapping,
void *stream_) const {
musaStream_t stream = (musaStream_t)stream_;
// Dispatch logic based on the GPU's maximum threads per block.
// This allows selecting the largest, most efficient block size the hardware supports.
if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_1024) {
// Dispatch based on data type for a 1024-thread block.
launchKernel<MOORE_BLOCK_SIZE_1024>(
_info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
stream);
} else if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_512) {
launchKernel<MOORE_BLOCK_SIZE_512>(
_info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
stream);
} else {
// If the GPU is older and supports fewer threads, return an error.
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::paged_caching::moore
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../cuda/kernel.cuh"
#include "paged_caching_nvidia.cuh"
template <typename Tdata, int NUM_THREADS>
INFINIOP_CUDA_KERNEL pagedCaching(
Tdata *k_cache, Tdata *v_cache,
const Tdata *k, const Tdata *v,
const int64_t *slot_mapping,
const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
}
namespace op::paged_caching::nvidia {
// PIMPL struct definition
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
// Destructor implementation
Descriptor::~Descriptor() {
delete _opaque;
}
// Static factory method implementation
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t slot_mapping_desc) {
auto info = PagedCachingInfo::create(k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc);
CHECK_RESULT(info);
// Create and return the Descriptor instance.
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// The launchKernel function is a templated helper to encapsulate the CUDA kernel launch.
// It sets up grid/block dimensions and calls the device-side kernel.
template <int NUM_THREADS>
infiniStatus_t launchKernel(const PagedCachingInfo &info,
void *k_cache, void *v_cache,
infiniDtype_t dtype,
const void *k, const void *v,
const void *slot_mapping,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
cudaStream_t stream) {
    // Grid is 2D: one block per (kv head, token) pair.
    dim3 grid(uint64_t(num_kv_heads), uint64_t(num_tokens), 1);
// Block dimension is 1D, using the number of threads specified at compile time.
dim3 block(NUM_THREADS);
// This kernel does not require dynamic shared memory.
size_t shared_mem_size = 0;
// Launch the device-side CUDA kernel.
if (dtype == INFINI_DTYPE_F16) {
pagedCaching<half, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(half *)k_cache,
(half *)v_cache,
(const half *)k,
(const half *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<__nv_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(__nv_bfloat16 *)k_cache,
(__nv_bfloat16 *)v_cache,
(const __nv_bfloat16 *)k,
(const __nv_bfloat16 *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(float *)k_cache,
(float *)v_cache,
(const float *)k,
(const float *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
// Execution method implementation
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *k_cache, void *v_cache,
const void *k, const void *v,
const void *slot_mapping,
void *stream_) const {
cudaStream_t stream = (cudaStream_t)stream_;
// Dispatch logic based on the GPU's maximum threads per block.
// This allows selecting the largest, most efficient block size the hardware supports.
if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_1024) {
// Dispatch based on data type for a 1024-thread block.
launchKernel<CUDA_BLOCK_SIZE_1024>(
_info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
stream);
} else if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_512) {
launchKernel<CUDA_BLOCK_SIZE_512>(
_info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
stream);
} else {
// If the GPU is older and supports fewer threads, return an error.
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::paged_caching::nvidia
#ifndef __PAGED_CACHING_NVIDIA_H__
#define __PAGED_CACHING_NVIDIA_H__
#include "../paged_caching.h"
DESCRIPTOR(nvidia)
#endif // __PAGED_CACHING_NVIDIA_H__
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/paged_caching.h"
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
#include "nvidia/paged_caching_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/paged_caching_metax.h"
#endif
#ifdef ENABLE_MOORE_API
#include "moore/paged_caching_moore.h"
#endif
__C infiniStatus_t infiniopCreatePagedCachingDescriptor(
infiniopHandle_t handle,
infiniopPagedCachingDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t slot_mapping_desc) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::paged_caching::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor **>(desc_ptr), \
k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc);
switch (handle->device) {
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
}
__C infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
infiniopPagedCachingDescriptor_t desc,
size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
}
__C infiniStatus_t infiniopPagedCaching(
infiniopPagedCachingDescriptor_t desc,
void *workspace, size_t workspace_size,
void *k_cache, void *v_cache,
const void *k, const void *v,
const void *slot_mapping,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size, k_cache, v_cache, k, v, slot_mapping, stream);
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
}
__C infiniStatus_t infiniopDestroyPagedCachingDescriptor(
infiniopPagedCachingDescriptor_t desc) {
#define DESTROY(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
DESTROY(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
DESTROY(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
}
#ifndef PAGED_CACHING_H
#define PAGED_CACHING_H
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::paged_caching::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
PagedCachingInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
PagedCachingInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t k_cache_desc, \
infiniopTensorDescriptor_t v_cache_desc, \
infiniopTensorDescriptor_t k_desc, \
infiniopTensorDescriptor_t v_desc, \
infiniopTensorDescriptor_t slot_mapping_desc); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *k_cache, void *v_cache, \
const void *k, const void *v, \
const void *slot_mapping, \
void *stream) const; \
}; \
}
#endif // PAGED_CACHING_H
#ifndef __PERCHANNEL_QUANTINT8_KERNEL_CUH__
#define __PERCHANNEL_QUANTINT8_KERNEL_CUH__
#include <cub/block/block_reduce.cuh>
__device__ inline int round_half_away_from_zero(float x) {
float ax = fabsf(x);
float r = floorf(ax + 0.5f);
return (x >= 0.0f) ? (int)r : -(int)r;
}
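// Examples (illustrative): round_half_away_from_zero(2.5f) == 3 and
// round_half_away_from_zero(-2.5f) == -3, whereas rintf() in the default
// round-to-nearest-even mode would return 2 and -2 for the same inputs.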
template <typename Tdata, unsigned int BLOCK_SIZE>
__device__ void blockPerChannelQuantI8Kernel(
int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x,
int M, int K) {
int row = blockIdx.x;
int tid = row * K;
// ---- 1. reduce max ----
float local_max = op::common_cuda::reduce_op::max<BLOCK_SIZE, Tdata>(
x + tid, K);
__shared__ float global_max_f;
if (threadIdx.x == 0) {
global_max_f = local_max;
}
__syncthreads();
typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
// ---- 2. reduce min ----
float thread_min = __FLT_MAX__;
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
thread_min = fminf(thread_min, (float)x[tid + ind]);
}
#if CUDART_VERSION >= 12090
float local_min = BlockReduce(temp_storage).Reduce(thread_min, ::cuda::minimum());
#else
float local_min = BlockReduce(temp_storage).Reduce(thread_min, cub::Min());
#endif
__shared__ float global_min_f;
if (threadIdx.x == 0) {
global_min_f = local_min;
}
__syncthreads();
float global_max = global_max_f;
float global_min = global_min_f;
float scale = (global_max - global_min) / 255.0f;
if (scale < 1e-8f) {
scale = 1e-8f;
}
float inv_scale = 1.0f / scale;
float zero = -global_min * inv_scale - 128.0f;
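    // Worked example (illustrative): for a row with global_min = -1 and global_max = 3,
    // scale = 4 / 255 and zero = 255/4 - 128 = -64.25, so x = -1 maps to qf = -128 and
    // x = 3 maps to qf = 127, covering the full signed int8 range.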
    x_scale[row] = scale;
    x_zero[row] = zero;
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
float v = (float)x[tid + ind];
float qf = v * inv_scale + zero;
int q = round_half_away_from_zero(qf);
if (q > 127) {
q = 127;
}
if (q < -128) {
q = -128;
}
x_packed[tid + ind] = (int8_t)q;
}
}
template <typename Tdata, unsigned int BLOCK_SIZE>
__device__ void blockPerChannelQuantI8SymKernel(
int8_t *x_packed, float *x_scale, const Tdata *x,
int M, int K) {
int row = blockIdx.x;
int tid = row * K;
typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
    // ---- reduce max of |x| ----
float thread_max = -__FLT_MAX__;
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
thread_max = fmaxf(thread_max, fabs((float)x[tid + ind]));
}
#if CUDART_VERSION >= 12090
float local_max = BlockReduce(temp_storage).Reduce(thread_max, ::cuda::maximum());
#else
float local_max = BlockReduce(temp_storage).Reduce(thread_max, cub::Max());
#endif
__shared__ float global_max_f;
if (threadIdx.x == 0) {
global_max_f = local_max;
}
__syncthreads();
float global_max = global_max_f;
float scale = global_max / 127.0f;
if (scale < 1e-8f) {
scale = 1e-8f;
}
float inv_scale = 1.0f / scale;
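    // e.g. (illustrative): if max|x| over the row is 2.0, scale = 2/127, so x = 2.0
    // quantizes to 127 and x = -2.0 quantizes to -127.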
    x_scale[row] = scale;
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
float v = (float)x[tid + ind];
float qf = v * inv_scale;
int q = round_half_away_from_zero(qf);
if (q > 127) {
q = 127;
}
if (q < -127) {
q = -127;
}
x_packed[tid + ind] = (int8_t)q;
}
}
template <typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
return max(a, b);
}
};
template <typename T>
struct MinOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
return min(a, b);
}
};
template <template <typename> class ReductionOp, typename T,
int thread_group_width>
__inline__ __device__ T WarpAllReduce(T val) {
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
}
return val;
}
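// Usage sketch (illustrative): a full-warp max reduction over per-thread partials:
//   float warp_max = WarpAllReduce<MaxOp, float, 32>(thread_local_max);
// The xor-shuffle butterfly (masks 16, 8, 4, 2, 1) folds all 32 lanes together, so
// every lane ends up holding the same reduced value.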
template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
__device__ void warpPerChannelQuantI8Kernel(
int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x,
int M, int K) {
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
int tid = otherIdx * K;
if (otherIdx < M) {
__shared__ float max_total[BLOCK_SIZE_y];
__shared__ float min_total[BLOCK_SIZE_y];
float max_data = -__FLT_MAX__;
float min_data = __FLT_MAX__;
// ---- reduce max/min ----
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
float v = (float)x[tid + ind];
max_data = fmaxf(max_data, v);
min_data = fminf(min_data, v);
}
max_data = WarpAllReduce<MaxOp, float, BLOCK_SIZE_x>(max_data);
min_data = WarpAllReduce<MinOp, float, BLOCK_SIZE_x>(min_data);
if (threadIdx.x == 0) {
max_total[threadIdx.y] = max_data;
min_total[threadIdx.y] = min_data;
}
__syncthreads();
float max_f = max_total[threadIdx.y];
float min_f = min_total[threadIdx.y];
float scale = (max_f - min_f) / 255.0f;
if (scale < 1e-8f) {
scale = 1e-8f;
}
float inv_scale = 1.0f / scale;
float zero = -min_f * inv_scale - 128.0f;
x_scale[otherIdx] = scale;
x_zero[otherIdx] = zero;
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
float v = (float)x[tid + ind];
float qf = v * inv_scale + zero;
int q = round_half_away_from_zero(qf);
if (q > 127) {
q = 127;
}
if (q < -128) {
q = -128;
}
x_packed[tid + ind] = (int8_t)q;
}
}
}
template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
__device__ void warpPerChannelQuantI8SymKernel(
int8_t *x_packed, float *x_scale, const Tdata *x,
int M, int K) {
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
int tid = otherIdx * K;
if (otherIdx < M) {
__shared__ float max_total[BLOCK_SIZE_y];
float max_data = -__FLT_MAX__;
        // ---- reduce max of |x| ----
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
float v = fabs((float)x[tid + ind]);
max_data = fmaxf(max_data, v);
}
max_data = WarpAllReduce<MaxOp, float, BLOCK_SIZE_x>(max_data);
if (threadIdx.x == 0) {
max_total[threadIdx.y] = max_data;
}
__syncthreads();
float max_f = max_total[threadIdx.y];
float scale = max_f / 127.0f;
if (scale < 1e-8f) {
scale = 1e-8f;
}
float inv_scale = 1.0f / scale;
x_scale[otherIdx] = scale;
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
float v = (float)x[tid + ind];
float qf = v * inv_scale;
int q = round_half_away_from_zero(qf);
if (q > 127) {
q = 127;
}
if (q < -127) {
q = -127;
}
x_packed[tid + ind] = (int8_t)q;
}
}
}
#endif // __PERCHANNEL_QUANTINT8_KERNEL_CUH__
#ifndef __PER_CHANNEL_QUANT_INT8_INFO_H__
#define __PER_CHANNEL_QUANT_INT8_INFO_H__
#include "../../../../utils.h"
#include "../../../operator.h"
#include "../../../tensor.h"
namespace op::per_channel_quant_int8 {
class PerChannelQuantI8Info {
private:
PerChannelQuantI8Info() = default;
public:
infiniDtype_t dtype, packed_type;
size_t M, K;
static utils::Result<PerChannelQuantI8Info> createPerChannelQuantI8Info(
infiniopTensorDescriptor_t x_packed_desc,
infiniopTensorDescriptor_t x_scale_desc,
infiniopTensorDescriptor_t x_zero_desc,
infiniopTensorDescriptor_t x_desc) {
CHECK_OR_RETURN(
x_packed_desc != nullptr && x_scale_desc != nullptr && x_desc != nullptr,
INFINI_STATUS_NULL_POINTER);
const infiniDtype_t dtype = x_desc->dtype();
const infiniDtype_t packed_type = x_packed_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
CHECK_DTYPE(packed_type, INFINI_DTYPE_I8);
CHECK_OR_RETURN(x_desc->ndim() == 2
&& x_packed_desc->ndim() == 2
&& x_scale_desc->ndim() == 2,
INFINI_STATUS_BAD_TENSOR_SHAPE);
size_t M = x_desc->dim(0);
size_t K = x_desc->dim(1);
        CHECK_OR_RETURN(M == x_packed_desc->dim(0)
                            && K == x_packed_desc->dim(1)
                            && M == x_scale_desc->dim(0)
                            && 1 == x_scale_desc->dim(1),
                        INFINI_STATUS_BAD_TENSOR_SHAPE);
return utils::Result<PerChannelQuantI8Info>(PerChannelQuantI8Info{
dtype,
packed_type,
M,
K,
});
}
};
} // namespace op::per_channel_quant_int8
#endif // __PER_CHANNEL_QUANT_INT8_INFO_H__