"googlemock/test/gmock-spec-builders_test.cc" did not exist on "e5121b5a828c13588d7aa4fc328b348e92ee4abb"
Unverified Commit 784139b9 authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Merge pull request #990 from InfiniTensor/demo131

Demo-131 Cuda graph with optimized paged attention
parents 3c8fb3c0 1d6527cb
#ifdef ENABLE_METAX_MC_API
#include <mc_runtime.h>
#else
#include <hc_runtime.h>
#endif
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
// #include "paged_attention_prefill_fa2.cuh"
#include "paged_attention_prefill_metax.h"
#include "../cuda/kernel_v2.cuh"
namespace op::paged_attention_prefill::metax {
namespace {
constexpr size_t ceilDiv(size_t a, size_t b) {
return (a + b - 1) / b;
}
inline const char *default_prefill_kernel(const PagedAttentionPrefillInfo &info) {
// Heuristic auto-dispatch (v0.4):
// - Prefer the pipelined + tile-wise softmax kernel on FA2-compatible block_size=256.
// - Keep a conservative fallback for other shapes / older GPUs (cp.async is a no-op below SM80).
//
// Users can always override via INFINIOP_FLASH_PREFILL_KERNEL.
if (info.page_block_size == 256 && (info.dtype == INFINI_DTYPE_F16 || info.dtype == INFINI_DTYPE_BF16)) {
if (info.head_size == 128) {
return "warpcta8pipe";
}
// For head_size=64 we keep the previous default until we have broader perf coverage.
}
return "warpcta8";
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128Warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel).
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel<Tindex, Tdata, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64Warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel).
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel<Tindex, Tdata, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 4 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 4, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 4 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 4, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 8, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8N128(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token, tile_n=128 for fewer K stages.
// Note: we keep K in shared memory but load V from global to stay within the per-block shared limit.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelKOnly<Tindex, Tdata, 128, 8, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 8, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8Pipe(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token, with cp.async pipelining.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined<Tindex, Tdata, 128, 8, 32, 2>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8Mma(
half *out,
const half *q,
const half *k_cache,
const half *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCta8MmaHd128Kernel<Tindex>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8Pipe(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token, with cp.async pipelining.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined<Tindex, Tdata, 64, 8, 32, 2>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8PipeSplitKv(
float *partial_acc,
float *partial_m,
float *partial_l,
int num_splits,
size_t total_q_tokens,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride) {
// Encode (split_idx, m_block) into blockIdx.z to allow a single kernel launch:
// blockIdx.z in [0, num_splits * num_m_blocks).
const int num_m_blocks = static_cast<int>((total_q_tokens + 8 - 1) / 8);
const int bz = static_cast<int>(blockIdx.z);
const int split_idx = bz / num_m_blocks;
const int m_block = bz - split_idx * num_m_blocks;
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv<Tindex, Tdata, 128, 8, 32, 2>(
partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8PipeSplitKv(
float *partial_acc,
float *partial_m,
float *partial_l,
int num_splits,
size_t total_q_tokens,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride) {
const int num_m_blocks = static_cast<int>((total_q_tokens + 8 - 1) / 8);
const int bz = static_cast<int>(blockIdx.z);
const int split_idx = bz / num_m_blocks;
const int m_block = bz - split_idx * num_m_blocks;
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv<Tindex, Tdata, 64, 8, 32, 2>(
partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
}
template <typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128SplitKvCombine(
Tdata *out,
const float *partial_acc,
const float *partial_m,
const float *partial_l,
int num_splits,
size_t total_q_tokens,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel<Tdata, 128>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
}
template <typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64SplitKvCombine(
Tdata *out,
const float *partial_acc,
const float *partial_m,
const float *partial_l,
int num_splits,
size_t total_q_tokens,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel<Tdata, 64>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 16 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 16, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 16 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 16, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata, typename Tcompute>
infiniStatus_t launch_prefill_ref(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
const dim3 grid(static_cast<uint32_t>(total_q_tokens), static_cast<uint32_t>(num_heads), 1);
const dim3 block(static_cast<uint32_t>(head_size), 1, 1);
if (head_size == 64) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel<Tindex, Tdata, Tcompute, 64>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride, q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, num_seqs);
return INFINI_STATUS_SUCCESS;
}
if (head_size == 128) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel<Tindex, Tdata, Tcompute, 128>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride, q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, num_seqs);
return INFINI_STATUS_SUCCESS;
}
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
const dim3 block(32, 1, 1);
// Global-token launch:
// - dramatically reduces grid size vs the legacy (num_seqs * total_q_tokens) launch
// - matches PagedAttention varlen (cu_seqlens) mental model better
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(total_q_tokens),
1);
switch (head_size) {
case 64:
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel<Tindex, Tdata, 64>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, scale, max_num_blocks_per_seq,
page_block_size, block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel<Tindex, Tdata, 128>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, scale, max_num_blocks_per_seq,
page_block_size, block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
constexpr int kWarps = 4;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta8<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta8<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8pipe(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta8Pipe<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta8Pipe<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8mma(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
// Current WMMA kernel only supports fp16 + head_dim=128.
if constexpr (!std::is_same_v<Tdata, half>) {
return launch_prefill_warpcta8pipe<Tindex, Tdata>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, stream);
}
if (head_size != 128) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// Guardrail: the current WMMA-score kernel is correctness-first and can be extremely slow on long prompts.
// Allow power users to force it via INFINIOP_FLASH_PREFILL_MMA_FORCE=1.
const char *force_env = std::getenv("INFINIOP_FLASH_PREFILL_MMA_FORCE");
const bool force_mma = (force_env != nullptr) && (std::strcmp(force_env, "1") == 0);
const size_t seqlen_k_est = max_num_blocks_per_seq * page_block_size;
if (!force_mma && seqlen_k_est > 4096) {
static bool warned = false;
if (!warned) {
std::fprintf(stderr,
"[infiniop][paged_attention_prefill] warpcta8mma is experimental and very slow for long seqlen_k (est=%zu). "
"Falling back to warpcta8pipe. Set INFINIOP_FLASH_PREFILL_MMA_FORCE=1 to override.\n",
seqlen_k_est);
warned = true;
}
return launch_prefill_warpcta8pipe<Tindex, Tdata>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, stream);
}
// WMMA requires SM70+. If not supported (or if we can't query), fall back to the pipelined SIMT kernel.
int device = 0;
hcDeviceProp_t prop{};
if (hcGetDevice(&device) == hcSuccess && hcGetDeviceProperties(&prop, device) == hcSuccess) {
if (prop.major < 7) {
return launch_prefill_warpcta8pipe<Tindex, Tdata>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, stream);
}
}
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(16))));
PagedAttentionPrefillHd128WarpCta8Mma<Tindex>
<<<grid, block, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8pipe_splitkv(
float *partial_acc,
float *partial_m,
float *partial_l,
int num_splits,
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
constexpr int kMaxSplits = 8;
if (num_splits < 1) {
num_splits = 1;
}
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const size_t num_m_blocks = ceilDiv(total_q_tokens, static_cast<size_t>(kWarps));
// Single kernel launch with split_idx encoded in grid.z:
// blockIdx.z in [0, num_splits * num_m_blocks).
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(num_m_blocks * static_cast<size_t>(num_splits)));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta8PipeSplitKv<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
partial_acc, partial_m, partial_l, num_splits, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
break;
case 128:
PagedAttentionPrefillHd128WarpCta8PipeSplitKv<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
partial_acc, partial_m, partial_l, num_splits, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
break;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// Combine: one warp per (token, head).
const dim3 block2(32);
const dim3 grid2(static_cast<uint32_t>(num_heads), static_cast<uint32_t>(total_q_tokens), 1);
switch (head_size) {
case 64:
PagedAttentionPrefillHd64SplitKvCombine<Tdata>
<<<grid2, block2, 0, stream>>>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128SplitKvCombine<Tdata>
<<<grid2, block2, 0, stream>>>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8n128(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
// Only meaningful for head_dim=128.
if (head_size != 128) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
PagedAttentionPrefillHd128WarpCta8N128<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
constexpr int kWarps = 16;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta16<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta16<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
} // namespace
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t total_kv_lens_desc,
infiniopTensorDescriptor_t cum_seqlens_q_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto info = PagedAttentionPrefillInfo::create(
out_desc, q_desc, k_cache_desc, v_cache_desc,
block_tables_desc, total_kv_lens_desc, cum_seqlens_q_desc,
alibi_slopes_desc, scale);
CHECK_RESULT(info);
// Optional split-kv prefill requires workspace for partial (m, l, acc).
// IMPORTANT: Unlike decode, prefill's total_q_tokens can be very large, so we must NOT reserve
// a huge workspace unless the user explicitly enables split-kv.
bool use_splitkv = false;
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
}
int num_splits = 1;
if (use_splitkv) {
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) {
const int v = std::atoi(env);
if (v > 0) {
num_splits = v;
}
} else {
num_splits = 4;
}
constexpr int kMaxSplits = 8;
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
}
const size_t n = info->total_q_tokens * info->num_heads;
const size_t splitkv_workspace_bytes = use_splitkv ? (static_cast<size_t>(num_splits) * n * (info->head_size + 2) * sizeof(float)) : 0;
const size_t workspace_bytes = splitkv_workspace_bytes;
// const size_t workspace_bytes = splitkv_workspace_bytes + fa2_workspace_bytes;
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
info.take(), workspace_bytes, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables,
const void *total_kv_lens,
const void *cum_seqlens_q,
const void *alibi_slopes,
void *stream_) const {
auto stream = static_cast<hcStream_t>(stream_);
const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast<const float *>(alibi_slopes);
const auto *total_kv_lens_i64 = static_cast<const int64_t *>(total_kv_lens);
const auto *cu_seqlens_q_i64 = static_cast<const int64_t *>(cum_seqlens_q);
bool use_splitkv = false;
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
}
int num_splits = 1;
if (use_splitkv) {
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) {
const int v = std::atoi(env);
if (v > 0) {
num_splits = v;
}
} else {
// Conservative default; users can override.
num_splits = 4;
}
constexpr int kMaxSplits = 8;
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
const size_t n = _info.total_q_tokens * _info.num_heads;
const size_t required = static_cast<size_t>(num_splits) * n * (_info.head_size + 2) * sizeof(float);
if (workspace_size < required) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
}
if (use_splitkv) {
const size_t n = _info.total_q_tokens * _info.num_heads;
float *partial_acc = static_cast<float *>(workspace);
float *partial_m = partial_acc + static_cast<size_t>(num_splits) * n * _info.head_size;
float *partial_l = partial_m + static_cast<size_t>(num_splits) * n;
// Dispatch by (Tdata, Tindex). total_kv_lens + cu_seqlens_q are currently always int64.
#define DISPATCH_SPLITKV(Tindex, Tdata, BT_PTR) \
return launch_prefill_warpcta8pipe_splitkv<Tindex, Tdata>( \
partial_acc, partial_m, partial_l, num_splits, \
static_cast<Tdata *>(out), \
static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), \
static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(BT_PTR), \
total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream)
if (_info.dtype == INFINI_DTYPE_F16) {
if (_info.index_dtype == INFINI_DTYPE_I64) {
DISPATCH_SPLITKV(int64_t, half, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_I32) {
DISPATCH_SPLITKV(int32_t, half, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_U32) {
DISPATCH_SPLITKV(uint32_t, half, block_tables);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (_info.dtype == INFINI_DTYPE_BF16) {
if (_info.index_dtype == INFINI_DTYPE_I64) {
DISPATCH_SPLITKV(int64_t, __nv_bfloat16, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_I32) {
DISPATCH_SPLITKV(int32_t, __nv_bfloat16, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_U32) {
DISPATCH_SPLITKV(uint32_t, __nv_bfloat16, block_tables);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
#undef DISPATCH_SPLITKV
}
// Default to the fastest validated kernel for supported shapes.
// "ref" is still available for debugging/correctness bisecting.
#define DISPATCH_KERNEL(Tindex, Tdata, Tcompute) \
do { \
const char *k_env = std::getenv("INFINIOP_FLASH_PREFILL_KERNEL"); \
const char *k = (k_env == nullptr) ? default_prefill_kernel(_info) : k_env; \
if (k_env != nullptr) { \
const bool known = (std::strcmp(k, "warp") == 0) || (std::strcmp(k, "warpcta") == 0) || (std::strcmp(k, "warpcta8") == 0) || (std::strcmp(k, "warpcta8pipe") == 0) || (std::strcmp(k, "warpcta8mma") == 0) || (std::strcmp(k, "warpcta8n128") == 0) || (std::strcmp(k, "warpcta16") == 0) || (std::strcmp(k, "ref") == 0); \
if (!known) { \
const char *fallback = default_prefill_kernel(_info); \
std::fprintf(stderr, \
"[infiniop][paged_attention_prefill] WARNING: unknown kernel '%s', falling back to '%s'\n", \
k, fallback); \
k = fallback; \
} \
} \
const char *dbg = std::getenv("INFINIOP_DEBUG_PREFILL_DISPATCH"); \
static bool printed_dispatch = false; \
if (!printed_dispatch && dbg != nullptr && std::strcmp(dbg, "1") == 0) { \
std::fprintf(stderr, \
"[infiniop][paged_attention_prefill] kernel=%s (override=%s head_size=%zu block=%zu dtype=%zu)\n", \
k, \
(k_env == nullptr ? "auto" : "env"), \
static_cast<size_t>(_info.head_size), \
static_cast<size_t>(_info.page_block_size), \
static_cast<size_t>(_info.dtype)); \
printed_dispatch = true; \
} \
if (std::strcmp(k, "warp") == 0) { \
return launch_prefill_warp<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta") == 0) { \
return launch_prefill<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta8") == 0) { \
return launch_prefill_warpcta8<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta8pipe") == 0) { \
return launch_prefill_warpcta8pipe<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if constexpr (std::is_same_v<Tdata, half>) { \
if (std::strcmp(k, "warpcta8mma") == 0) { \
return launch_prefill_warpcta8mma<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
} \
if (std::strcmp(k, "warpcta8n128") == 0) { \
return launch_prefill_warpcta8n128<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta16") == 0) { \
return launch_prefill_warpcta16<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "ref") == 0) { \
return launch_prefill_ref<Tindex, Tdata, Tcompute>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
return INFINI_STATUS_BAD_PARAM; \
} while (false)
#define DISPATCH_INDEX(Tindex) \
do { \
if (_info.dtype == INFINI_DTYPE_F16) { \
DISPATCH_KERNEL(Tindex, half, float); \
} \
if (_info.dtype == INFINI_DTYPE_BF16) { \
DISPATCH_KERNEL(Tindex, __nv_bfloat16, float); \
} \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
} while (false)
if (_info.index_dtype == INFINI_DTYPE_I64) {
DISPATCH_INDEX(int64_t);
} else if (_info.index_dtype == INFINI_DTYPE_I32) {
DISPATCH_INDEX(int32_t);
} else if (_info.index_dtype == INFINI_DTYPE_U32) {
DISPATCH_INDEX(uint32_t);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::paged_attention_prefill::nvidia
#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
#define __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
namespace op::paged_attention_prefill::cuda {
__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *cum_seq_lens_q, size_t num_seqs) {
size_t low = 0, high = num_seqs - 1;
while (low <= high) {
size_t mid = (low + high) >> 1;
if (token_idx >= (size_t)cum_seq_lens_q[mid] && token_idx < (size_t)cum_seq_lens_q[mid + 1]) {
return mid;
} else if (token_idx < (size_t)cum_seq_lens_q[mid]) {
high = mid - 1;
} else {
low = mid + 1;
}
}
return 0;
}
// Warp-level sum reduction with an explicit active mask (safe for partial warps).
__device__ __forceinline__ float warpReduceSum(float val, unsigned mask) {
for (int offset = 16; offset > 0; offset >>= 1) {
val += __shfl_down_sync(mask, val, offset);
}
return val;
}
// Block-level sum reduction. Returns the sum to all threads in the block.
// Supports blockDim.x up to 1024.
__device__ __forceinline__ float blockReduceSum(float val) {
__shared__ float shared[32]; // max 32 warps per block
const int lane = threadIdx.x & 31;
const int wid = threadIdx.x >> 5;
const unsigned mask = __activemask();
val = warpReduceSum(val, mask);
if (lane == 0) {
shared[wid] = val;
}
__syncthreads();
const int num_warps = (blockDim.x + 31) >> 5;
float sum = 0.0f;
if (wid == 0) {
sum = (lane < num_warps) ? shared[lane] : 0.0f;
const unsigned mask0 = (num_warps >= 32) ? 0xffffffffu : ((1u << num_warps) - 1u);
sum = warpReduceSum(sum, mask0);
if (lane == 0) {
shared[0] = sum;
}
}
__syncthreads();
return shared[0];
}
template <typename Tdata, typename Tcompute>
__global__ void pagedAttentionPrefillKernel(
Tdata *out_, const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_,
const int64_t *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cum_seq_lens_q_,
const float *alibi_slopes_,
const size_t num_heads, const size_t num_kv_heads, const float scale,
const size_t max_num_blocks_per_seq, const size_t block_size,
const ptrdiff_t kv_block_stride, const ptrdiff_t kv_head_stride,
const ptrdiff_t q_stride, const ptrdiff_t q_head_stride,
const size_t head_size,
const size_t num_seqs) {
// Grid : x -> token, y -> head
const size_t global_token_idx = blockIdx.x;
const size_t head_idx = blockIdx.y;
const size_t dim_idx = threadIdx.x;
if (dim_idx >= head_size) {
return;
}
__shared__ size_t sh_seq_idx;
__shared__ size_t sh_causal_limit;
__shared__ size_t sh_kv_head_idx;
__shared__ float sh_scale_acc;
__shared__ float sh_w;
__shared__ float sh_inv_l;
if (dim_idx == 0) {
sh_seq_idx = find_seq_id(global_token_idx, cum_seq_lens_q_, num_seqs);
const size_t q_token_idx = global_token_idx - static_cast<size_t>(cum_seq_lens_q_[sh_seq_idx]);
const size_t total_kv_len = static_cast<size_t>(total_kv_lens_[sh_seq_idx]);
const size_t q_len = static_cast<size_t>(cum_seq_lens_q_[sh_seq_idx + 1] - cum_seq_lens_q_[sh_seq_idx]);
const size_t history_len = total_kv_len - q_len;
sh_causal_limit = history_len + q_token_idx;
const size_t num_queries_per_kv = num_heads / num_kv_heads;
sh_kv_head_idx = head_idx / num_queries_per_kv;
}
__syncthreads();
const size_t seq_idx = sh_seq_idx;
const size_t causal_limit = sh_causal_limit;
const size_t kv_head_idx = sh_kv_head_idx;
const Tdata *q_vec = q_ + global_token_idx * q_stride + head_idx * q_head_stride;
Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size;
const int64_t *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;
const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
const float qv = static_cast<float>(q_vec[dim_idx]);
Tcompute acc = 0.0f;
float m = -FLT_MAX;
float l = 0.0f;
for (size_t t = 0; t <= causal_limit; ++t) {
const size_t b_idx = t / block_size;
const size_t t_off = t % block_size;
const ptrdiff_t physical_block_id = block_table[b_idx];
const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
const float dot = blockReduceSum(qv * static_cast<float>(k_vec[dim_idx]));
if (dim_idx == 0) {
float score = dot * static_cast<float>(scale);
if (alibi_slope != 0.0f) {
score += alibi_slope * static_cast<float>(t - causal_limit);
}
const float m_new = fmaxf(m, score);
const float scale_acc = expf(m - m_new);
const float w = expf(score - m_new);
l = l * scale_acc + w;
m = m_new;
sh_scale_acc = scale_acc;
sh_w = w;
}
__syncthreads();
const float scale_acc = sh_scale_acc;
const float w = sh_w;
const Tdata *v_vec = v_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
acc = acc * static_cast<Tcompute>(scale_acc) + static_cast<Tcompute>(w) * static_cast<Tcompute>(v_vec[dim_idx]);
__syncthreads();
}
if (dim_idx == 0) {
sh_inv_l = 1.0f / (l + 1e-6f);
}
__syncthreads();
out_ptr[dim_idx] = static_cast<Tdata>(acc * static_cast<Tcompute>(sh_inv_l));
}
} // namespace op::paged_attention_prefill::cuda
#endif
#ifndef __PAGED_ATTENTION_PREFILL_MOORE_H__
#define __PAGED_ATTENTION_PREFILL_MOORE_H__
#include "../paged_attention_prefill.h"
DESCRIPTOR(moore)
#endif // __PAGED_ATTENTION_PREFILL_MOORE_H__
#include <musa_fp16.h>
#include <float.h>
#include <math.h>
#include <stdint.h>
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_kernel_common.h"
#include "paged_attention_prefill_kernel.h"
#include "paged_attention_prefill_moore.h"
template <typename Tdata, typename Tcompute>
infiniStatus_t launchPagedAttentionPrefill(
Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
const int64_t *block_tables,
const int64_t *seq_lens,
const int64_t *cum_seq_lens_q,
const float *alibi_slopes,
const size_t num_heads,
const size_t num_seqs,
const size_t num_kv_heads,
const float scale,
const size_t max_num_blocks_per_seq,
const size_t page_block_size,
const size_t total_q_tokens,
const size_t head_size,
const ptrdiff_t k_batch_stride,
const ptrdiff_t k_head_stride,
const ptrdiff_t q_stride,
const ptrdiff_t q_head_stride,
musaStream_t stream) {
if (total_q_tokens == 0 || num_heads == 0) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
dim3 grid(total_q_tokens, num_heads);
dim3 block(head_size);
op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tdata, Tcompute>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache,
block_tables, seq_lens, cum_seq_lens_q, alibi_slopes,
num_heads, num_kv_heads, scale,
max_num_blocks_per_seq, page_block_size,
k_batch_stride, k_head_stride,
q_stride, q_head_stride,
head_size,
num_seqs);
return INFINI_STATUS_SUCCESS;
}
namespace op::paged_attention_prefill::moore {
struct Descriptor::Opaque {
std::shared_ptr<device::moore::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t cum_seq_lens_q_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto info = PagedAttentionPrefillInfo::create(
out_desc, q_desc, k_cache_desc, v_cache_desc,
block_tables_desc, seq_lens_desc,
cum_seq_lens_q_desc,
alibi_slopes_desc, scale);
CHECK_RESULT(info);
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables,
const void *seq_lens,
const void *cum_seq_lens_q,
const void *alibi_slopes,
void *stream_) const {
musaStream_t stream = (musaStream_t)stream_;
#define LAUNCH_KERNEL(Tdata, Tcompute) \
launchPagedAttentionPrefill<Tdata, Tcompute>( \
(Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q, \
(const float *)alibi_slopes, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, \
_info.scale, _info.max_num_blocks_per_seq, \
_info.page_block_size, _info.total_q_tokens, \
_info.head_size, \
_info.k_batch_stride, _info.k_head_stride, \
_info.q_stride, _info.q_head_stride, \
stream)
if (_info.dtype == INFINI_DTYPE_F16) {
return LAUNCH_KERNEL(half, float);
} else if (_info.dtype == INFINI_DTYPE_BF16) {
return LAUNCH_KERNEL(__mt_bfloat16, float);
} else if (_info.dtype == INFINI_DTYPE_F32) {
return LAUNCH_KERNEL(float, float);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::paged_attention_prefill::moore
#include <cuda_fp16.h>
#include <float.h>
#include <math.h>
#include <stdint.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../cuda/kernel.cuh"
// #include "paged_attention_prefill_fa2.cuh"
#include "paged_attention_prefill_nvidia.cuh"
template <typename Tdata, typename Tcompute>
infiniStatus_t launchPagedAttentionPrefill(
Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
const int64_t *block_tables,
const int64_t *seq_lens,
const int64_t *cum_seq_lens_q,
const float *alibi_slopes,
const size_t num_heads,
const size_t num_seqs,
const size_t num_kv_heads,
const float scale,
const size_t max_num_blocks_per_seq,
const size_t block_size,
const size_t total_q_tokens,
const size_t head_size,
const ptrdiff_t kv_block_stride,
const ptrdiff_t kv_head_stride,
const ptrdiff_t q_stride,
const ptrdiff_t q_head_stride,
#include "../cuda/kernel_v2.cuh"
namespace op::paged_attention_prefill::nvidia {
namespace {
constexpr size_t ceilDiv(size_t a, size_t b) {
return (a + b - 1) / b;
}
inline const char *default_prefill_kernel(const PagedAttentionPrefillInfo &info) {
// Iluvatar: use warp (stable). Users can override via INFINIOP_FLASH_PREFILL_KERNEL.
#ifdef ENABLE_ILUVATAR_API
(void)info;
return "warp";
#endif
// Heuristic auto-dispatch (v0.4):
// - Prefer the pipelined + tile-wise softmax kernel on FA2-compatible block_size=256.
// - Keep a conservative fallback for other shapes / older GPUs (cp.async is a no-op below SM80).
//
// Users can always override via INFINIOP_FLASH_PREFILL_KERNEL.
if (info.page_block_size == 256 && (info.dtype == INFINI_DTYPE_F16 || info.dtype == INFINI_DTYPE_BF16)) {
if (info.head_size == 128) {
return "warpcta8pipe";
}
// For head_size=64 we keep the previous default until we have broader perf coverage.
}
return "warpcta8";
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128Warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel).
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel<Tindex, Tdata, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64Warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel).
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel<Tindex, Tdata, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 4 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 4, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 4 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 4, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 8, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8N128(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token, tile_n=128 for fewer K stages.
// Note: we keep K in shared memory but load V from global to stay within the per-block shared limit.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelKOnly<Tindex, Tdata, 128, 8, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 8, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8Pipe(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token, with cp.async pipelining.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined<Tindex, Tdata, 128, 8, 32, 2>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8Mma(
half *out,
const half *q,
const half *k_cache,
const half *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCta8MmaHd128Kernel<Tindex>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8Pipe(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token, with cp.async pipelining.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined<Tindex, Tdata, 64, 8, 32, 2>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8PipeSplitKv(
float *partial_acc,
float *partial_m,
float *partial_l,
int num_splits,
size_t total_q_tokens,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride) {
// Encode (split_idx, m_block) into blockIdx.z to allow a single kernel launch:
// blockIdx.z in [0, num_splits * num_m_blocks).
const int num_m_blocks = static_cast<int>((total_q_tokens + 8 - 1) / 8);
const int bz = static_cast<int>(blockIdx.z);
const int split_idx = bz / num_m_blocks;
const int m_block = bz - split_idx * num_m_blocks;
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv<Tindex, Tdata, 128, 8, 32, 2>(
partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8PipeSplitKv(
float *partial_acc,
float *partial_m,
float *partial_l,
int num_splits,
size_t total_q_tokens,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride) {
const int num_m_blocks = static_cast<int>((total_q_tokens + 8 - 1) / 8);
const int bz = static_cast<int>(blockIdx.z);
const int split_idx = bz / num_m_blocks;
const int m_block = bz - split_idx * num_m_blocks;
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv<Tindex, Tdata, 64, 8, 32, 2>(
partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
}
template <typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128SplitKvCombine(
Tdata *out,
const float *partial_acc,
const float *partial_m,
const float *partial_l,
int num_splits,
size_t total_q_tokens,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel<Tdata, 128>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
}
template <typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64SplitKvCombine(
Tdata *out,
const float *partial_acc,
const float *partial_m,
const float *partial_l,
int num_splits,
size_t total_q_tokens,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel<Tdata, 64>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 16 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 16, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 16 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 16, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata, typename Tcompute>
infiniStatus_t launch_prefill_ref(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
const dim3 grid(static_cast<uint32_t>(total_q_tokens), static_cast<uint32_t>(num_heads), 1);
const dim3 block(static_cast<uint32_t>(head_size), 1, 1);
if (head_size == 64) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel<Tindex, Tdata, Tcompute, 64>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride, q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, num_seqs);
return INFINI_STATUS_SUCCESS;
}
if (head_size == 128) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel<Tindex, Tdata, Tcompute, 128>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride, q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, num_seqs);
return INFINI_STATUS_SUCCESS;
}
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
const dim3 block(32, 1, 1);
// Global-token launch:
// - dramatically reduces grid size vs the legacy (num_seqs * total_q_tokens) launch
// - matches PagedAttention varlen (cu_seqlens) mental model better
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(total_q_tokens),
1);
switch (head_size) {
case 64:
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel<Tindex, Tdata, 64>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, scale, max_num_blocks_per_seq,
page_block_size, block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel<Tindex, Tdata, 128>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, scale, max_num_blocks_per_seq,
page_block_size, block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
constexpr int kWarps = 4;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta8<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta8<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8pipe(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
if (total_q_tokens == 0 || num_heads == 0) {
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta8Pipe<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta8Pipe<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
dim3 grid(total_q_tokens, num_heads);
dim3 block(head_size);
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8mma(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
// Current WMMA kernel only supports fp16 + head_dim=128.
if constexpr (!std::is_same_v<Tdata, half>) {
return launch_prefill_warpcta8pipe<Tindex, Tdata>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, stream);
}
op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tdata, Tcompute>
if (head_size != 128) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// Guardrail: the current WMMA-score kernel is correctness-first and can be extremely slow on long prompts.
// Allow power users to force it via INFINIOP_FLASH_PREFILL_MMA_FORCE=1.
const char *force_env = std::getenv("INFINIOP_FLASH_PREFILL_MMA_FORCE");
const bool force_mma = (force_env != nullptr) && (std::strcmp(force_env, "1") == 0);
const size_t seqlen_k_est = max_num_blocks_per_seq * page_block_size;
if (!force_mma && seqlen_k_est > 4096) {
static bool warned = false;
if (!warned) {
std::fprintf(stderr,
"[infiniop][paged_attention_prefill] warpcta8mma is experimental and very slow for long seqlen_k (est=%zu). "
"Falling back to warpcta8pipe. Set INFINIOP_FLASH_PREFILL_MMA_FORCE=1 to override.\n",
seqlen_k_est);
warned = true;
}
return launch_prefill_warpcta8pipe<Tindex, Tdata>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, stream);
}
// WMMA requires SM70+. If not supported (or if we can't query), fall back to the pipelined SIMT kernel.
int device = 0;
cudaDeviceProp prop{};
if (cudaGetDevice(&device) == cudaSuccess && cudaGetDeviceProperties(&prop, device) == cudaSuccess) {
if (prop.major < 7) {
return launch_prefill_warpcta8pipe<Tindex, Tdata>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, stream);
}
}
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(16))));
PagedAttentionPrefillHd128WarpCta8Mma<Tindex>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache,
block_tables, seq_lens, cum_seq_lens_q, alibi_slopes,
num_heads, num_kv_heads, scale,
max_num_blocks_per_seq, block_size,
kv_block_stride, kv_head_stride,
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
head_size,
num_seqs);
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8pipe_splitkv(
float *partial_acc,
float *partial_m,
float *partial_l,
int num_splits,
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
constexpr int kMaxSplits = 8;
if (num_splits < 1) {
num_splits = 1;
}
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const size_t num_m_blocks = ceilDiv(total_q_tokens, static_cast<size_t>(kWarps));
// Single kernel launch with split_idx encoded in grid.z:
// blockIdx.z in [0, num_splits * num_m_blocks).
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(num_m_blocks * static_cast<size_t>(num_splits)));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta8PipeSplitKv<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
partial_acc, partial_m, partial_l, num_splits, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
break;
case 128:
PagedAttentionPrefillHd128WarpCta8PipeSplitKv<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
partial_acc, partial_m, partial_l, num_splits, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
break;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// Combine: one warp per (token, head).
const dim3 block2(32);
const dim3 grid2(static_cast<uint32_t>(num_heads), static_cast<uint32_t>(total_q_tokens), 1);
switch (head_size) {
case 64:
PagedAttentionPrefillHd64SplitKvCombine<Tdata>
<<<grid2, block2, 0, stream>>>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128SplitKvCombine<Tdata>
<<<grid2, block2, 0, stream>>>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8n128(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
// Only meaningful for head_dim=128.
if (head_size != 128) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
PagedAttentionPrefillHd128WarpCta8N128<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
}
namespace op::paged_attention_prefill::nvidia {
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
constexpr int kWarps = 16;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta16<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta16<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
} // namespace
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
......@@ -68,22 +1254,48 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t cum_seq_lens_q_desc,
infiniopTensorDescriptor_t total_kv_lens_desc,
infiniopTensorDescriptor_t cum_seqlens_q_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto info = PagedAttentionPrefillInfo::create(
out_desc, q_desc, k_cache_desc, v_cache_desc,
block_tables_desc, seq_lens_desc,
cum_seq_lens_q_desc,
block_tables_desc, total_kv_lens_desc, cum_seqlens_q_desc,
alibi_slopes_desc, scale);
CHECK_RESULT(info);
// Optional split-kv prefill requires workspace for partial (m, l, acc).
// IMPORTANT: Unlike decode, prefill's total_q_tokens can be very large, so we must NOT reserve
// a huge workspace unless the user explicitly enables split-kv.
bool use_splitkv = false;
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
}
int num_splits = 1;
if (use_splitkv) {
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) {
const int v = std::atoi(env);
if (v > 0) {
num_splits = v;
}
} else {
num_splits = 4;
}
constexpr int kMaxSplits = 8;
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
}
const size_t n = info->total_q_tokens * info->num_heads;
const size_t splitkv_workspace_bytes = use_splitkv ? (static_cast<size_t>(num_splits) * n * (info->head_size + 2) * sizeof(float)) : 0;
const size_t workspace_bytes = splitkv_workspace_bytes;
// const size_t workspace_bytes = splitkv_workspace_bytes + fa2_workspace_bytes;
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
info.take(), workspace_bytes, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
......@@ -92,32 +1304,249 @@ infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables,
const void *seq_lens,
const void *cum_seq_lens_q,
const void *total_kv_lens,
const void *cum_seqlens_q,
const void *alibi_slopes,
void *stream_) const {
auto stream = static_cast<cudaStream_t>(stream_);
const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast<const float *>(alibi_slopes);
const auto *total_kv_lens_i64 = static_cast<const int64_t *>(total_kv_lens);
const auto *cu_seqlens_q_i64 = static_cast<const int64_t *>(cum_seqlens_q);
bool use_splitkv = false;
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
}
int num_splits = 1;
if (use_splitkv) {
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) {
const int v = std::atoi(env);
if (v > 0) {
num_splits = v;
}
} else {
// Conservative default; users can override.
num_splits = 4;
}
constexpr int kMaxSplits = 8;
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
const size_t n = _info.total_q_tokens * _info.num_heads;
const size_t required = static_cast<size_t>(num_splits) * n * (_info.head_size + 2) * sizeof(float);
if (workspace_size < required) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
}
if (use_splitkv) {
const size_t n = _info.total_q_tokens * _info.num_heads;
float *partial_acc = static_cast<float *>(workspace);
float *partial_m = partial_acc + static_cast<size_t>(num_splits) * n * _info.head_size;
float *partial_l = partial_m + static_cast<size_t>(num_splits) * n;
// Dispatch by (Tdata, Tindex). total_kv_lens + cu_seqlens_q are currently always int64.
#define DISPATCH_SPLITKV(Tindex, Tdata, BT_PTR) \
return launch_prefill_warpcta8pipe_splitkv<Tindex, Tdata>( \
partial_acc, partial_m, partial_l, num_splits, \
static_cast<Tdata *>(out), \
static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), \
static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(BT_PTR), \
total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream)
if (_info.dtype == INFINI_DTYPE_F16) {
if (_info.index_dtype == INFINI_DTYPE_I64) {
DISPATCH_SPLITKV(int64_t, half, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_I32) {
DISPATCH_SPLITKV(int32_t, half, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_U32) {
DISPATCH_SPLITKV(uint32_t, half, block_tables);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (_info.dtype == INFINI_DTYPE_BF16) {
if (_info.index_dtype == INFINI_DTYPE_I64) {
DISPATCH_SPLITKV(int64_t, __nv_bfloat16, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_I32) {
DISPATCH_SPLITKV(int32_t, __nv_bfloat16, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_U32) {
DISPATCH_SPLITKV(uint32_t, __nv_bfloat16, block_tables);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
#undef DISPATCH_SPLITKV
}
// Default to the fastest validated kernel for supported shapes.
// "ref" is still available for debugging/correctness bisecting.
#define DISPATCH_KERNEL(Tindex, Tdata, Tcompute) \
do { \
const char *k_env = std::getenv("INFINIOP_FLASH_PREFILL_KERNEL"); \
const char *k = (k_env == nullptr) ? default_prefill_kernel(_info) : k_env; \
if (k_env != nullptr) { \
const bool known = (std::strcmp(k, "warp") == 0) || (std::strcmp(k, "warpcta") == 0) || (std::strcmp(k, "warpcta8") == 0) || (std::strcmp(k, "warpcta8pipe") == 0) || (std::strcmp(k, "warpcta8mma") == 0) || (std::strcmp(k, "warpcta8n128") == 0) || (std::strcmp(k, "warpcta16") == 0) || (std::strcmp(k, "ref") == 0); \
if (!known) { \
const char *fallback = default_prefill_kernel(_info); \
std::fprintf(stderr, \
"[infiniop][paged_attention_prefill] WARNING: unknown kernel '%s', falling back to '%s'\n", \
k, fallback); \
k = fallback; \
} \
} \
const char *dbg = std::getenv("INFINIOP_DEBUG_PREFILL_DISPATCH"); \
static bool printed_dispatch = false; \
if (!printed_dispatch && dbg != nullptr && std::strcmp(dbg, "1") == 0) { \
std::fprintf(stderr, \
"[infiniop][paged_attention_prefill] kernel=%s (override=%s head_size=%zu block=%zu dtype=%zu)\n", \
k, \
(k_env == nullptr ? "auto" : "env"), \
static_cast<size_t>(_info.head_size), \
static_cast<size_t>(_info.page_block_size), \
static_cast<size_t>(_info.dtype)); \
printed_dispatch = true; \
} \
if (std::strcmp(k, "warp") == 0) { \
return launch_prefill_warp<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta") == 0) { \
return launch_prefill<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta8") == 0) { \
return launch_prefill_warpcta8<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta8pipe") == 0) { \
return launch_prefill_warpcta8pipe<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if constexpr (std::is_same_v<Tdata, half>) { \
if (std::strcmp(k, "warpcta8mma") == 0) { \
return launch_prefill_warpcta8mma<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
} \
if (std::strcmp(k, "warpcta8n128") == 0) { \
return launch_prefill_warpcta8n128<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta16") == 0) { \
return launch_prefill_warpcta16<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "ref") == 0) { \
return launch_prefill_ref<Tindex, Tdata, Tcompute>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
return INFINI_STATUS_BAD_PARAM; \
} while (false)
#define DISPATCH_INDEX(Tindex) \
do { \
if (_info.dtype == INFINI_DTYPE_F16) { \
DISPATCH_KERNEL(Tindex, half, float); \
} \
if (_info.dtype == INFINI_DTYPE_BF16) { \
DISPATCH_KERNEL(Tindex, __nv_bfloat16, float); \
} \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
} while (false)
cudaStream_t stream = (cudaStream_t)stream_;
#define LAUNCH_KERNEL(Tdata, Tcompute) \
launchPagedAttentionPrefill<Tdata, Tcompute>( \
(Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q, \
(const float *)alibi_slopes, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, \
_info.scale, _info.max_num_blocks_per_seq, \
_info.block_size, _info.total_q_tokens, \
_info.head_size, \
_info.kv_block_stride, _info.kv_head_stride, \
_info.q_stride, _info.q_head_stride, \
stream)
if (_info.dtype == INFINI_DTYPE_F16) {
return LAUNCH_KERNEL(half, float);
} else if (_info.dtype == INFINI_DTYPE_BF16) {
return LAUNCH_KERNEL(__nv_bfloat16, float);
} else if (_info.dtype == INFINI_DTYPE_F32) {
return LAUNCH_KERNEL(float, float);
if (_info.index_dtype == INFINI_DTYPE_I64) {
DISPATCH_INDEX(int64_t);
} else if (_info.index_dtype == INFINI_DTYPE_I32) {
DISPATCH_INDEX(int32_t);
} else if (_info.index_dtype == INFINI_DTYPE_U32) {
DISPATCH_INDEX(uint32_t);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
......
......@@ -2,9 +2,15 @@
#include "../../handle.h"
#include "infiniop/ops/paged_attention_prefill.h"
#ifdef ENABLE_NVIDIA_API
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
#include "nvidia/paged_attention_prefill_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/paged_attention_prefill_metax.h"
#endif
#ifdef ENABLE_MOORE_API
#include "moore/paged_attention_prefill_moore.h"
#endif
__C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
infiniopHandle_t handle,
......@@ -32,6 +38,18 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
switch (handle->device) {
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -50,6 +68,18 @@ __C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -75,6 +105,18 @@ __C infiniStatus_t infiniopPagedAttentionPrefill(
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -92,6 +134,18 @@ __C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
DESTROY(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
DESTROY(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
#ifndef __PAGED_CACHING_METAX_H__
#define __PAGED_CACHING_METAX_H__
#include "../paged_caching.h"
DESCRIPTOR(metax)
#endif // __PAGED_CACHING_METAX_H__
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "paged_caching_metax.h"
template <typename Tdata, int NUM_THREADS>
INFINIOP_METAX_KERNEL pagedCaching(
Tdata *k_cache, Tdata *v_cache,
const Tdata *k, const Tdata *v,
const int64_t *slot_mapping,
const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
}
namespace op::paged_caching::metax {
// PIMPL struct definition
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
// Destructor implementation
Descriptor::~Descriptor() {
delete _opaque;
}
// Static factory method implementation
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t slot_mapping_desc) {
auto info = PagedCachingInfo::create(k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc);
CHECK_RESULT(info);
// Create and return the Descriptor instance.
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// The launchKernel function is a templated helper to encapsulate the kernel launch.
// It sets up grid/block dimensions and calls the device-side kernel.
template <int NUM_THREADS>
infiniStatus_t launchKernel(const PagedCachingInfo &info,
void *k_cache, void *v_cache,
infiniDtype_t dtype,
const void *k, const void *v,
const void *slot_mapping,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
hcStream_t stream) {
// Grid dimension is 1D, with one block per token, as we decided.
dim3 grid(uint64_t(num_kv_heads), uint64_t(num_tokens), 1);
// Block dimension is 1D, using the number of threads specified at compile time.
dim3 block(NUM_THREADS);
// This kernel does not require dynamic shared memory.
size_t shared_mem_size = 0;
// Launch the device-side kernel.
if (dtype == INFINI_DTYPE_F16) {
pagedCaching<half, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(half *)k_cache,
(half *)v_cache,
(const half *)k,
(const half *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<cuda_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(cuda_bfloat16 *)k_cache,
(cuda_bfloat16 *)v_cache,
(const cuda_bfloat16 *)k,
(const cuda_bfloat16 *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(float *)k_cache,
(float *)v_cache,
(const float *)k,
(const float *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
// Execution method implementation
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *k_cache, void *v_cache,
const void *k, const void *v,
const void *slot_mapping,
void *stream_) const {
hcStream_t stream = (hcStream_t)stream_;
// Dispatch logic based on the device's maximum threads per block.
// This allows selecting the largest, most efficient block size the hardware supports.
int max_threads = _opaque->internal->maxThreadsPerBlock();
if (max_threads >= METAX_BLOCK_SIZE_1024) {
// Dispatch based on data type for a 1024-thread block.
launchKernel<METAX_BLOCK_SIZE_1024>(
_info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
stream);
} else if (max_threads >= METAX_BLOCK_SIZE_512) {
launchKernel<METAX_BLOCK_SIZE_512>(
_info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
stream);
} else {
// If the device supports fewer threads, return an error.
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::paged_caching::metax
#ifndef __PAGED_CACHING_MOORE_H__
#define __PAGED_CACHING_MOORE_H__
#include "../paged_caching.h"
DESCRIPTOR(moore)
#endif // __PAGED_CACHING_MOORE_H__
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "paged_caching_moore.h"
template <typename Tdata, int NUM_THREADS>
INFINIOP_MOORE_KERNEL pagedCaching(
Tdata *k_cache, Tdata *v_cache,
const Tdata *k, const Tdata *v,
const int64_t *slot_mapping,
const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
}
namespace op::paged_caching::moore {
// PIMPL struct definition
struct Descriptor::Opaque {
std::shared_ptr<device::moore::Handle::Internal> internal;
};
// Destructor implementation
Descriptor::~Descriptor() {
delete _opaque;
}
// Static factory method implementation
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t slot_mapping_desc) {
auto info = PagedCachingInfo::create(k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc);
CHECK_RESULT(info);
// Create and return the Descriptor instance.
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// The launchKernel function is a templated helper to encapsulate the MUSA kernel launch.
// It sets up grid/block dimensions and calls the device-side kernel.
template <int NUM_THREADS>
infiniStatus_t launchKernel(const PagedCachingInfo &info,
void *k_cache, void *v_cache,
infiniDtype_t dtype,
const void *k, const void *v,
const void *slot_mapping,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
musaStream_t stream) {
// Grid dimension is 1D, with one block per token, as we decided.
dim3 grid(uint64_t(num_kv_heads), uint64_t(num_tokens), 1);
// Block dimension is 1D, using the number of threads specified at compile time.
dim3 block(NUM_THREADS);
// This kernel does not require dynamic shared memory.
size_t shared_mem_size = 0;
// Launch the device-side MUSA kernel.
if (dtype == INFINI_DTYPE_F16) {
pagedCaching<half, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(half *)k_cache,
(half *)v_cache,
(const half *)k,
(const half *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<__mt_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(__mt_bfloat16 *)k_cache,
(__mt_bfloat16 *)v_cache,
(const __mt_bfloat16 *)k,
(const __mt_bfloat16 *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(float *)k_cache,
(float *)v_cache,
(const float *)k,
(const float *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
// Execution method implementation
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *k_cache, void *v_cache,
const void *k, const void *v,
const void *slot_mapping,
void *stream_) const {
musaStream_t stream = (musaStream_t)stream_;
// Dispatch logic based on the GPU's maximum threads per block.
// This allows selecting the largest, most efficient block size the hardware supports.
if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_1024) {
// Dispatch based on data type for a 1024-thread block.
launchKernel<MOORE_BLOCK_SIZE_1024>(
_info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
stream);
} else if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_512) {
launchKernel<MOORE_BLOCK_SIZE_512>(
_info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
stream);
} else {
// If the GPU is older and supports fewer threads, return an error.
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::paged_caching::moore
......@@ -2,12 +2,15 @@
#include "../../handle.h"
#include "infiniop/ops/paged_caching.h"
#ifdef ENABLE_NVIDIA_API
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
#include "nvidia/paged_caching_nvidia.cuh"
#endif
// #ifdef ENABLE_METAX_API
// #include "metax/paged_caching_metax.h"
// #endif
#ifdef ENABLE_METAX_API
#include "metax/paged_caching_metax.h"
#endif
#ifdef ENABLE_MOORE_API
#include "moore/paged_caching_moore.h"
#endif
__C infiniStatus_t infiniopCreatePagedCachingDescriptor(
infiniopHandle_t handle,
......@@ -29,9 +32,18 @@ __C infiniStatus_t infiniopCreatePagedCachingDescriptor(
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// CREATE(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -50,9 +62,18 @@ __C infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// GET(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -75,9 +96,18 @@ __C infiniStatus_t infiniopPagedCaching(
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// CALCULATE(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -95,9 +125,18 @@ __C infiniStatus_t infiniopDestroyPagedCachingDescriptor(
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// DESTROY(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
DESTROY(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
DESTROY(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......
#ifndef __PERCHANNEL_QUANTINT8_KERNEL_CUH__
#define __PERCHANNEL_QUANTINT8_KERNEL_CUH__
#include <cub/block/block_reduce.cuh>
__device__ inline int round_half_away_from_zero(float x) {
float ax = fabsf(x);
float r = floorf(ax + 0.5f);
return (x >= 0.0f) ? (int)r : -(int)r;
}
template <typename Tdata, unsigned int BLOCK_SIZE>
__device__ void blockPerChannelQuantI8Kernel(
int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x,
int M, int K) {
int row = blockIdx.x;
int tid = row * K;
// ---- 1. reduce max ----
float local_max = op::common_cuda::reduce_op::max<BLOCK_SIZE, Tdata>(
x + tid, K);
__shared__ float global_max_f;
if (threadIdx.x == 0) {
global_max_f = local_max;
}
__syncthreads();
typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
// ---- 2. reduce min ----
float thread_min = __FLT_MAX__;
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
thread_min = fminf(thread_min, (float)x[tid + ind]);
}
#if CUDART_VERSION >= 12090
float local_min = BlockReduce(temp_storage).Reduce(thread_min, ::cuda::minimum());
#else
float local_min = BlockReduce(temp_storage).Reduce(thread_min, cub::Min());
#endif
__shared__ float global_min_f;
if (threadIdx.x == 0) {
global_min_f = local_min;
}
__syncthreads();
float global_max = global_max_f;
float global_min = global_min_f;
float scale = (global_max - global_min) / 255.0f;
if (scale < 1e-8f) {
scale = 1e-8f;
}
float inv_scale = 1.0f / scale;
float zero = -global_min * inv_scale - 128.0f;
x_scale[row] = (Tdata)scale;
x_zero[row] = (Tdata)zero;
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
float v = (float)x[tid + ind];
float qf = v * inv_scale + zero;
int q = round_half_away_from_zero(qf);
if (q > 127) {
q = 127;
}
if (q < -128) {
q = -128;
}
x_packed[tid + ind] = (int8_t)q;
}
}
template <typename Tdata, unsigned int BLOCK_SIZE>
__device__ void blockPerChannelQuantI8SymKernel(
int8_t *x_packed, float *x_scale, const Tdata *x,
int M, int K) {
int row = blockIdx.x;
int tid = row * K;
typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
// ---- 2. reduce min ----
float thread_max = -__FLT_MAX__;
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
thread_max = fmaxf(thread_max, fabs((float)x[tid + ind]));
}
#if CUDART_VERSION >= 12090
float local_max = BlockReduce(temp_storage).Reduce(thread_max, ::cuda::maximum());
#else
float local_max = BlockReduce(temp_storage).Reduce(thread_max, cub::Max());
#endif
__shared__ float global_max_f;
if (threadIdx.x == 0) {
global_max_f = local_max;
}
__syncthreads();
float global_max = global_max_f;
float scale = global_max / 127.0f;
if (scale < 1e-8f) {
scale = 1e-8f;
}
float inv_scale = 1.0f / scale;
x_scale[row] = (Tdata)scale;
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
float v = (float)x[tid + ind];
float qf = v * inv_scale;
int q = round_half_away_from_zero(qf);
if (q > 127) {
q = 127;
}
if (q < -127) {
q = -127;
}
x_packed[tid + ind] = (int8_t)q;
}
}
template <typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
return max(a, b);
}
};
template <typename T>
struct MinOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
return min(a, b);
}
};
template <template <typename> class ReductionOp, typename T,
int thread_group_width>
__inline__ __device__ T WarpAllReduce(T val) {
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
}
return val;
}
template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
__device__ void warpPerChannelQuantI8Kernel(
int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x,
int M, int K) {
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
int tid = otherIdx * K;
if (otherIdx < M) {
__shared__ float max_total[BLOCK_SIZE_y];
__shared__ float min_total[BLOCK_SIZE_y];
float max_data = -__FLT_MAX__;
float min_data = __FLT_MAX__;
// ---- reduce max/min ----
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
float v = (float)x[tid + ind];
max_data = fmaxf(max_data, v);
min_data = fminf(min_data, v);
}
max_data = WarpAllReduce<MaxOp, float, BLOCK_SIZE_x>(max_data);
min_data = WarpAllReduce<MinOp, float, BLOCK_SIZE_x>(min_data);
if (threadIdx.x == 0) {
max_total[threadIdx.y] = max_data;
min_total[threadIdx.y] = min_data;
}
__syncthreads();
float max_f = max_total[threadIdx.y];
float min_f = min_total[threadIdx.y];
float scale = (max_f - min_f) / 255.0f;
if (scale < 1e-8f) {
scale = 1e-8f;
}
float inv_scale = 1.0f / scale;
float zero = -min_f * inv_scale - 128.0f;
x_scale[otherIdx] = scale;
x_zero[otherIdx] = zero;
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
float v = (float)x[tid + ind];
float qf = v * inv_scale + zero;
int q = round_half_away_from_zero(qf);
if (q > 127) {
q = 127;
}
if (q < -128) {
q = -128;
}
x_packed[tid + ind] = (int8_t)q;
}
}
}
template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
__device__ void warpPerChannelQuantI8SymKernel(
int8_t *x_packed, float *x_scale, const Tdata *x,
int M, int K) {
int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
int tid = otherIdx * K;
if (otherIdx < M) {
__shared__ float max_total[BLOCK_SIZE_y];
float max_data = -__FLT_MAX__;
// ---- reduce max/min ----
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
float v = fabs((float)x[tid + ind]);
max_data = fmaxf(max_data, v);
}
max_data = WarpAllReduce<MaxOp, float, BLOCK_SIZE_x>(max_data);
if (threadIdx.x == 0) {
max_total[threadIdx.y] = max_data;
}
__syncthreads();
float max_f = max_total[threadIdx.y];
float scale = max_f / 127.0f;
if (scale < 1e-8f) {
scale = 1e-8f;
}
float inv_scale = 1.0f / scale;
x_scale[otherIdx] = scale;
for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
float v = (float)x[tid + ind];
float qf = v * inv_scale;
int q = round_half_away_from_zero(qf);
if (q > 127) {
q = 127;
}
if (q < -127) {
q = -127;
}
x_packed[tid + ind] = (int8_t)q;
}
}
}
#endif // __PERCHANNEL_QUANTINT8_KERNEL_CUH__
#ifndef __PER_CHANNEL_QUANT_INT8_INFO_H__
#define __PER_CHANNEL_QUANT_INT8_INFO_H__
#include "../../../../utils.h"
#include "../../../operator.h"
#include "../../../tensor.h"
namespace op::per_channel_quant_int8 {
class PerChannelQuantI8Info {
private:
PerChannelQuantI8Info() = default;
public:
infiniDtype_t dtype, packed_type;
size_t M, K;
static utils::Result<PerChannelQuantI8Info> createPerChannelQuantI8Info(
infiniopTensorDescriptor_t x_packed_desc,
infiniopTensorDescriptor_t x_scale_desc,
infiniopTensorDescriptor_t x_zero_desc,
infiniopTensorDescriptor_t x_desc) {
CHECK_OR_RETURN(
x_packed_desc != nullptr && x_scale_desc != nullptr && x_desc != nullptr,
INFINI_STATUS_NULL_POINTER);
const infiniDtype_t dtype = x_desc->dtype();
const infiniDtype_t packed_type = x_packed_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
CHECK_DTYPE(packed_type, INFINI_DTYPE_I8);
CHECK_OR_RETURN(x_desc->ndim() == 2
&& x_packed_desc->ndim() == 2
&& x_scale_desc->ndim() == 2,
INFINI_STATUS_BAD_TENSOR_SHAPE);
size_t M = x_desc->dim(0);
size_t K = x_desc->dim(1);
CHECK_OR_RETURN(M == x_packed_desc->dim(0)
|| K == x_packed_desc->dim(1)
|| M == x_scale_desc->dim(0)
|| 1 == x_scale_desc->dim(1),
INFINI_STATUS_BAD_TENSOR_SHAPE);
return utils::Result<PerChannelQuantI8Info>(PerChannelQuantI8Info{
dtype,
packed_type,
M,
K,
});
}
};
} // namespace op::per_channel_quant_int8
#endif // __PER_CHANNEL_QUANT_INT8_INFO_H__
#ifndef __PER_CHANNEL_QUANT_INT8_MOORE_API_H__
#define __PER_CHANNEL_QUANT_INT8_MOORE_API_H__
#include "../per_channel_quant_int8.h"
DESCRIPTOR(moore)
#endif // __PER_CHANNEL_QUANT_INT8_MOORE_API_H__
#include "../../../../devices/moore/moore_common.h"
#include "per_channel_quant_int8_moore.h"
#include "../../../../devices/moore/moore_kernel_common.h"
#include "../../../../reduce/cuda/reduce.cuh"
#include <cub/block/block_reduce.cuh>
#include "../cuda/kernel.cuh"
template <typename Tdata, unsigned int BLOCK_SIZE>
INFINIOP_MOORE_KERNEL blockPerChannelQuantI8(
int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x, int M, int K) {
blockPerChannelQuantI8Kernel<Tdata, BLOCK_SIZE>(x_packed, x_scale, x_zero, x, M, K);
}
template <typename Tdata, unsigned int BLOCK_SIZE>
INFINIOP_MOORE_KERNEL blockPerChannelQuantI8Sym(
int8_t *x_packed, float *x_scale, const Tdata *x, int M, int K) {
blockPerChannelQuantI8SymKernel<Tdata, BLOCK_SIZE>(x_packed, x_scale, x, M, K);
}
template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
INFINIOP_MOORE_KERNEL warpPerChannelQuantI8(
int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x, int M, int K) {
warpPerChannelQuantI8Kernel<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>(x_packed, x_scale, x_zero, x, M, K);
}
template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
INFINIOP_MOORE_KERNEL warpPerChannelQuantI8Sym(
int8_t *x_packed, float *x_scale, const Tdata *x, int M, int K) {
warpPerChannelQuantI8SymKernel<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>(x_packed, x_scale, x, M, K);
}
namespace op::per_channel_quant_int8::moore {
struct Descriptor::Opaque {
std::shared_ptr<device::moore::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle, Descriptor **desc_ptr,
infiniopTensorDescriptor_t x_packed_desc,
infiniopTensorDescriptor_t x_scale_desc,
infiniopTensorDescriptor_t x_zero_desc,
infiniopTensorDescriptor_t x_desc) {
auto info = PerChannelQuantI8Info::createPerChannelQuantI8Info(x_packed_desc, x_scale_desc, x_zero_desc, x_desc);
CHECK_RESULT(info);
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <unsigned int BLOCK_SIZE, typename Tdata>
infiniStatus_t per_channel_quant_int8Kernel(const PerChannelQuantI8Info &info, int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x, musaStream_t stream) {
int M = (int)info.M;
int K = (int)info.K;
if (K >= 1024) {
if (x_zero == nullptr) {
blockPerChannelQuantI8Sym<Tdata, BLOCK_SIZE>
<<<M, BLOCK_SIZE, 0, stream>>>(x_packed, x_scale, x, M, K);
} else {
blockPerChannelQuantI8<Tdata, BLOCK_SIZE>
<<<M, BLOCK_SIZE, 0, stream>>>(x_packed, x_scale, x_zero, x, M, K);
}
} else {
constexpr unsigned int BLOCK_SIZE_x = 32;
constexpr unsigned int BLOCK_SIZE_y = 32;
int num_block_x = (M + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y;
dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
if (x_zero == nullptr) {
warpPerChannelQuantI8Sym<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>
<<<grid_dim, block_dim, 0, stream>>>(x_packed, x_scale, x, M, K);
} else {
warpPerChannelQuantI8<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>
<<<grid_dim, block_dim, 0, stream>>>(x_packed, x_scale, x_zero, x, M, K);
}
}
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
void *x_packed, void *x_scale, void *x_zero, const void *x,
void *stream_) const {
musaStream_t stream = (musaStream_t)stream_;
#define QUANT(BLOCK_SIZE, TDATA) \
per_channel_quant_int8Kernel<BLOCK_SIZE, TDATA>(_info, (int8_t *)x_packed, (float *)x_scale, (float *)x_zero, (const TDATA *)x, stream)
#define QUANT_WITH_BLOCK_SIZE(BLOCK_SIZE) \
{ \
if (_info.dtype == INFINI_DTYPE_F16) \
return QUANT(BLOCK_SIZE, half); \
else if (_info.dtype == INFINI_DTYPE_F32) \
return QUANT(BLOCK_SIZE, float); \
else if (_info.dtype == INFINI_DTYPE_BF16) \
return QUANT(BLOCK_SIZE, __mt_bfloat16); \
else \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_1024) {
QUANT_WITH_BLOCK_SIZE(MOORE_BLOCK_SIZE_1024)
} else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_512) {
QUANT_WITH_BLOCK_SIZE(MOORE_BLOCK_SIZE_512)
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::per_channel_quant_int8::moore
#include "../../../../devices/nvidia/nvidia_common.cuh"
#include "per_channel_quant_int8_nvidia.cuh"
#include "../../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../../../../reduce/cuda/reduce.cuh"
#include <cub/block/block_reduce.cuh>
#include "../cuda/kernel.cuh"
template <typename Tdata, unsigned int BLOCK_SIZE>
INFINIOP_CUDA_KERNEL blockPerChannelQuantI8(
int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x, int M, int K) {
blockPerChannelQuantI8Kernel<Tdata, BLOCK_SIZE>(x_packed, x_scale, x_zero, x, M, K);
}
template <typename Tdata, unsigned int BLOCK_SIZE>
INFINIOP_CUDA_KERNEL blockPerChannelQuantI8Sym(
int8_t *x_packed, float *x_scale, const Tdata *x, int M, int K) {
blockPerChannelQuantI8SymKernel<Tdata, BLOCK_SIZE>(x_packed, x_scale, x, M, K);
}
template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
INFINIOP_CUDA_KERNEL warpPerChannelQuantI8(
int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x, int M, int K) {
warpPerChannelQuantI8Kernel<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>(x_packed, x_scale, x_zero, x, M, K);
}
template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
INFINIOP_CUDA_KERNEL warpPerChannelQuantI8Sym(
int8_t *x_packed, float *x_scale, const Tdata *x, int M, int K) {
warpPerChannelQuantI8SymKernel<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>(x_packed, x_scale, x, M, K);
}
namespace op::per_channel_quant_int8::nvidia {
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle, Descriptor **desc_ptr,
infiniopTensorDescriptor_t x_packed_desc,
infiniopTensorDescriptor_t x_scale_desc,
infiniopTensorDescriptor_t x_zero_desc,
infiniopTensorDescriptor_t x_desc) {
auto info = PerChannelQuantI8Info::createPerChannelQuantI8Info(x_packed_desc, x_scale_desc, x_zero_desc, x_desc);
CHECK_RESULT(info);
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <unsigned int BLOCK_SIZE, typename Tdata>
infiniStatus_t per_channel_quant_int8Kernel(const PerChannelQuantI8Info &info, int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x, cudaStream_t stream) {
int M = (int)info.M;
int K = (int)info.K;
if (K >= 1024) {
if (x_zero == nullptr) {
blockPerChannelQuantI8Sym<Tdata, BLOCK_SIZE>
<<<M, BLOCK_SIZE, 0, stream>>>(x_packed, x_scale, x, M, K);
} else {
blockPerChannelQuantI8<Tdata, BLOCK_SIZE>
<<<M, BLOCK_SIZE, 0, stream>>>(x_packed, x_scale, x_zero, x, M, K);
}
} else {
constexpr unsigned int BLOCK_SIZE_x = 32;
constexpr unsigned int BLOCK_SIZE_y = 32;
int num_block_x = (M + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y;
dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
if (x_zero == nullptr) {
warpPerChannelQuantI8Sym<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>
<<<grid_dim, block_dim, 0, stream>>>(x_packed, x_scale, x, M, K);
} else {
warpPerChannelQuantI8<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>
<<<grid_dim, block_dim, 0, stream>>>(x_packed, x_scale, x_zero, x, M, K);
}
}
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
void *x_packed, void *x_scale, void *x_zero, const void *x,
void *stream_) const {
cudaStream_t stream = (cudaStream_t)stream_;
#define QUANT(BLOCK_SIZE, TDATA) \
per_channel_quant_int8Kernel<BLOCK_SIZE, TDATA>(_info, (int8_t *)x_packed, (float *)x_scale, (float *)x_zero, (const TDATA *)x, stream)
#define QUANT_WITH_BLOCK_SIZE(BLOCK_SIZE) \
{ \
if (_info.dtype == INFINI_DTYPE_F16) \
return QUANT(BLOCK_SIZE, half); \
else if (_info.dtype == INFINI_DTYPE_F32) \
return QUANT(BLOCK_SIZE, float); \
else if (_info.dtype == INFINI_DTYPE_BF16) \
return QUANT(BLOCK_SIZE, __nv_bfloat16); \
else \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
QUANT_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_1024)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
QUANT_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_512)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
QUANT_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_4096)
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::per_channel_quant_int8::nvidia
#ifndef __PER_CHANNEL_QUANT_INT8_NVIDIA_API_H__
#define __PER_CHANNEL_QUANT_INT8_NVIDIA_API_H__
#include "../per_channel_quant_int8.h"
DESCRIPTOR(nvidia)
#endif // __PER_CHANNEL_QUANT_INT8_NVIDIA_API_H__
#include "../../../operator.h"
#include "../../../handle.h"
#include "infiniop/ops/quant/per_channel_quant_int8.h"
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#include "nvidia/per_channel_quant_int8_nvidia.cuh"
#endif
#if defined(ENABLE_MOORE_API)
#include "moore/per_channel_quant_int8_moore.h"
#endif
__C infiniStatus_t infiniopCreatePerChannelQuantI8Descriptor(infiniopHandle_t handle,
infiniopPerChannelQuantI8Descriptor_t *desc_ptr,
infiniopTensorDescriptor_t x_packed_desc,
infiniopTensorDescriptor_t x_scale_desc,
infiniopTensorDescriptor_t x_zero_desc,
infiniopTensorDescriptor_t x_desc) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::per_channel_quant_int8::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::per_channel_quant_int8::NAMESPACE::Descriptor **>(desc_ptr), \
x_packed_desc, \
x_scale_desc, \
x_zero_desc, \
x_desc);
switch (handle->device) {
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia)
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CREATE
}
__C infiniStatus_t infiniopGetPerChannelQuantI8WorkspaceSize(infiniopPerChannelQuantI8Descriptor_t desc, size_t *size) {
switch (desc->device_type) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::per_channel_quant_int8::NAMESPACE::Descriptor *>(desc)->minWorkspaceSize(); \
return INFINI_STATUS_SUCCESS;
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia)
#endif
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef GET
}
__C infiniStatus_t infiniopPerChannelQuantI8(infiniopPerChannelQuantI8Descriptor_t desc,
void *workspace,
size_t workspace_size,
void *x_packed,
void *x_scale,
void *x_zero,
const void *x,
void *stream) {
#define QUANT(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<op::per_channel_quant_int8::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size, x_packed, x_scale, x_zero, x, stream);
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
QUANT(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_QY_API
QUANT(INFINI_DEVICE_QY, nvidia)
#endif
#ifdef ENABLE_MOORE_API
QUANT(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef QUANT
}
__C infiniStatus_t infiniopDestroyPerChannelQuantI8Descriptor(infiniopPerChannelQuantI8Descriptor_t desc) {
#define DESTROY(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<op::per_channel_quant_int8::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_QY_API
DESTROY(INFINI_DEVICE_QY, nvidia)
#endif
#ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef DESTROY
}
#ifndef __QUANT_H__
#define __QUANT_H__
#include "../../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::per_channel_quant_int8::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
PerChannelQuantI8Info _info; \
size_t _workspace_size; \
\
Descriptor(Opaque *opaque, PerChannelQuantI8Info info, \
size_t workspace_size, \
infiniDevice_t device_type, int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), _info(info), _workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t minWorkspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, Descriptor **desc_ptr, \
infiniopTensorDescriptor_t x_packed_desc, \
infiniopTensorDescriptor_t x_scale_desc, \
infiniopTensorDescriptor_t x_zero_desc, \
infiniopTensorDescriptor_t x_desc); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *x_packed, void *x_scale, void *x_zero, const void *x, void *stream) const; \
}; \
}
#endif // __QUANT_H__
......@@ -534,13 +534,13 @@ struct Algo {
if constexpr (std::is_same<Tval_, float>::value) {
auto logits = reinterpret_cast<const float *>(probs);
argMax<<<dim, CNRT_FUNC_TYPE_BLOCK, queue>>>(logits, result, gdram_indices, voc);
argMax<<<dim, cnrtFuncTypeBlock, queue>>>(logits, result, gdram_indices, voc);
} else if constexpr (std::is_same<Tval_, CustomFloat16>::value) {
auto logits = reinterpret_cast<const half *>(probs);
argMax<<<dim, CNRT_FUNC_TYPE_BLOCK, queue>>>(logits, result, gdram_indices, voc);
argMax<<<dim, cnrtFuncTypeBlock, queue>>>(logits, result, gdram_indices, voc);
} else if constexpr (std::is_same<Tval_, CustomBFloat16>::value) {
auto logits = reinterpret_cast<const bfloat16_t *>(probs);
argMax<<<dim, CNRT_FUNC_TYPE_BLOCK, queue>>>(logits, result, gdram_indices, voc);
argMax<<<dim, cnrtFuncTypeBlock, queue>>>(logits, result, gdram_indices, voc);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......@@ -575,10 +575,10 @@ struct Algo {
const int max_num = SRC_MAX_SIZE / sizeof(float);
if (voc >= task_num * max_num) {
randomSampleKernelLarge<<<dim, CNRT_FUNC_TYPE_UNION1, queue>>>(
randomSampleKernelLarge<<<dim, cnrtFuncTypeUnion1, queue>>>(
logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature);
} else {
randomSampleKernel<<<dim, CNRT_FUNC_TYPE_UNION1, queue>>>(
randomSampleKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature);
}
} else if constexpr (std::is_same<Tval_, CustomFloat16>::value) {
......@@ -592,10 +592,10 @@ struct Algo {
const int max_num = SRC_MAX_SIZE / sizeof(half);
if (voc >= task_num * max_num) {
randomSampleKernelLarge<<<dim, CNRT_FUNC_TYPE_UNION1, queue>>>(
randomSampleKernelLarge<<<dim, cnrtFuncTypeUnion1, queue>>>(
logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature);
} else {
randomSampleKernel<<<dim, CNRT_FUNC_TYPE_UNION1, queue>>>(
randomSampleKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature);
}
} else if constexpr (std::is_same<Tval_, CustomBFloat16>::value) {
......@@ -609,10 +609,10 @@ struct Algo {
const int max_num = SRC_MAX_SIZE / sizeof(bfloat16_t);
if (voc >= task_num * max_num) {
randomSampleKernelLarge<<<dim, CNRT_FUNC_TYPE_UNION1, queue>>>(
randomSampleKernelLarge<<<dim, cnrtFuncTypeUnion1, queue>>>(
logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature);
} else {
randomSampleKernel<<<dim, CNRT_FUNC_TYPE_UNION1, queue>>>(
randomSampleKernel<<<dim, cnrtFuncTypeUnion1, queue>>>(
logits, result, gdram_indices, global_top_k, global_sum, voc, random_val, topp, topk, temperature);
}
} else {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment