jerrrrry / infinicore · Commit 8d09630a (Unverified)
Authored Feb 11, 2026 by gongchensu; committed by GitHub on Feb 11, 2026
Merge branch 'demo131' into Issue/862
Parents: ab52dead, 012df56c
Changes: 387
Showing 20 changed files with 4782 additions and 0 deletions (+4782, -0)
src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.maca (+1554, -0)
src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_kernel.h (+132, -0)
src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_moore.h (+8, -0)
src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_moore.mu (+126, -0)
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu (+1550, -0)
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cuh (+8, -0)
src/infiniop/ops/paged_attention_prefill/operator.cc (+153, -0)
src/infiniop/ops/paged_attention_prefill/paged_attention_prefill.h (+56, -0)
src/infiniop/ops/paged_caching/cuda/kernel.cuh (+88, -0)
src/infiniop/ops/paged_caching/info.h (+82, -0)
src/infiniop/ops/paged_caching/metax/paged_caching_metax.h (+8, -0)
src/infiniop/ops/paged_caching/metax/paged_caching_metax.maca (+157, -0)
src/infiniop/ops/paged_caching/moore/paged_caching_moore.h (+8, -0)
src/infiniop/ops/paged_caching/moore/paged_caching_moore.mu (+156, -0)
src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu (+163, -0)
src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cuh (+8, -0)
src/infiniop/ops/paged_caching/operator.cc (+143, -0)
src/infiniop/ops/paged_caching/paged_caching.h (+50, -0)
src/infiniop/ops/quant/per_channel_quant_int8/cuda/kernel.cuh (+273, -0)
src/infiniop/ops/quant/per_channel_quant_int8/info.h (+59, -0)
src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.maca (new file, 0 → 100644)
#ifdef ENABLE_METAX_MC_API
#include <mc_runtime.h>
#else
#include <hc_runtime.h>
#endif
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
// #include "paged_attention_prefill_fa2.cuh"
#include "paged_attention_prefill_metax.h"
#include "../cuda/kernel_v2.cuh"
namespace op::paged_attention_prefill::metax {
namespace {
constexpr size_t ceilDiv(size_t a, size_t b) {
return (a + b - 1) / b;
}
inline const char *default_prefill_kernel(const PagedAttentionPrefillInfo &info) {
// Heuristic auto-dispatch (v0.4):
// - Prefer the pipelined + tile-wise softmax kernel on FA2-compatible block_size=256.
// - Keep a conservative fallback for other shapes / older GPUs (cp.async is a no-op below SM80).
//
// Users can always override via INFINIOP_FLASH_PREFILL_KERNEL.
if (info.page_block_size == 256 && (info.dtype == INFINI_DTYPE_F16 || info.dtype == INFINI_DTYPE_BF16)) {
if (info.head_size == 128) {
return "warpcta8pipe";
}
// For head_size=64 we keep the previous default until we have broader perf coverage.
}
return "warpcta8";
}
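// Illustrative outcomes of the heuristic above (derived from the branches, not an exhaustive table):
//   page_block_size=256, dtype=F16/BF16, head_size=128 -> "warpcta8pipe"
//   page_block_size=256, dtype=F16/BF16, head_size=64  -> "warpcta8" (previous default kept)
//   everything else                                    -> "warpcta8"
// A hypothetical shell override: INFINIOP_FLASH_PREFILL_KERNEL=ref <your_binary>
// forces the reference kernel regardless of shape (useful for correctness bisecting).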
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128Warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel).
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel<Tindex, Tdata, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64Warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel).
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel<Tindex, Tdata, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 4 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 4, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 4 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 4, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 8, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8N128(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token, tile_n=128 for fewer K stages.
// Note: we keep K in shared memory but load V from global to stay within the per-block shared limit.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelKOnly<Tindex, Tdata, 128, 8, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
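// Rough shared-memory budget for the K-only variant above (a sketch assuming fp16 tiles and a
// 64 KB per-block shared-memory limit, which may differ on this hardware):
//   K tile : tile_n(128) * head_dim(128) * 2 B = 32 KB
//   K + V  : 2 * 32 KB = 64 KB, leaving no headroom for other shared buffers,
// which is why V is streamed from global memory instead of being staged in shared memory.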
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 8, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8Pipe(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token, with cp.async pipelining.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined<Tindex, Tdata, 128, 8, 32, 2>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8Mma(
half *out,
const half *q,
const half *k_cache,
const half *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCta8MmaHd128Kernel<Tindex>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8Pipe(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token, with cp.async pipelining.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined<Tindex, Tdata, 64, 8, 32, 2>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta8PipeSplitKv(
float *partial_acc,
float *partial_m,
float *partial_l,
int num_splits,
size_t total_q_tokens,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride) {
// Encode (split_idx, m_block) into blockIdx.z to allow a single kernel launch:
// blockIdx.z in [0, num_splits * num_m_blocks).
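// Worked example: num_splits = 2, num_m_blocks = 3 =>
//   blockIdx.z = 0..2 decodes to (split 0, m_block 0..2),
//   blockIdx.z = 3..5 decodes to (split 1, m_block 0..2).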
const int num_m_blocks = static_cast<int>((total_q_tokens + 8 - 1) / 8);
const int bz = static_cast<int>(blockIdx.z);
const int split_idx = bz / num_m_blocks;
const int m_block = bz - split_idx * num_m_blocks;
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv<Tindex, Tdata, 128, 8, 32, 2>(
partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta8PipeSplitKv(
float *partial_acc,
float *partial_m,
float *partial_l,
int num_splits,
size_t total_q_tokens,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride) {
const int num_m_blocks = static_cast<int>((total_q_tokens + 8 - 1) / 8);
const int bz = static_cast<int>(blockIdx.z);
const int split_idx = bz / num_m_blocks;
const int m_block = bz - split_idx * num_m_blocks;
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv<Tindex, Tdata, 64, 8, 32, 2>(
partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
}
template <typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128SplitKvCombine(
Tdata *out,
const float *partial_acc,
const float *partial_m,
const float *partial_l,
int num_splits,
size_t total_q_tokens,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel<Tdata, 128>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
}
template <typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64SplitKvCombine(
Tdata *out,
const float *partial_acc,
const float *partial_m,
const float *partial_l,
int num_splits,
size_t total_q_tokens,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel<Tdata, 64>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd128WarpCta16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 16 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 16, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_METAX_KERNEL PagedAttentionPrefillHd64WarpCta16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 16 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 16, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata, typename Tcompute>
infiniStatus_t launch_prefill_ref(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
const dim3 grid(static_cast<uint32_t>(total_q_tokens), static_cast<uint32_t>(num_heads), 1);
const dim3 block(static_cast<uint32_t>(head_size), 1, 1);
if (head_size == 64) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel<Tindex, Tdata, Tcompute, 64>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride, q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, num_seqs);
return INFINI_STATUS_SUCCESS;
}
if (head_size == 128) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel<Tindex, Tdata, Tcompute, 128>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride, q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, num_seqs);
return INFINI_STATUS_SUCCESS;
}
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
const dim3 block(32, 1, 1);
// Global-token launch:
// - dramatically reduces grid size vs the legacy (num_seqs * total_q_tokens) launch
// - matches PagedAttention varlen (cu_seqlens) mental model better
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(total_q_tokens),
1);
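// Rough illustration with a hypothetical shape (num_seqs = 4, total_q_tokens = 1024):
//   legacy per-seq launch : ~num_seqs * total_q_tokens = 4096 blocks per head, most doing no work
//   global-token launch   : total_q_tokens             = 1024 blocks per head, one per real token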
switch (head_size) {
case 64:
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel<Tindex, Tdata, 64>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, scale, max_num_blocks_per_seq,
page_block_size, block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel<Tindex, Tdata, 128>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, scale, max_num_blocks_per_seq,
page_block_size, block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
constexpr int kWarps = 4;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta8<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta8<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8pipe(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta8Pipe<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta8Pipe<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8mma(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
// Current WMMA kernel only supports fp16 + head_dim=128.
if constexpr (!std::is_same_v<Tdata, half>) {
return launch_prefill_warpcta8pipe<Tindex, Tdata>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, stream);
}
if (head_size != 128) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// Guardrail: the current WMMA-score kernel is correctness-first and can be extremely slow on long prompts.
// Allow power users to force it via INFINIOP_FLASH_PREFILL_MMA_FORCE=1.
const char *force_env = std::getenv("INFINIOP_FLASH_PREFILL_MMA_FORCE");
const bool force_mma = (force_env != nullptr) && (std::strcmp(force_env, "1") == 0);
const size_t seqlen_k_est = max_num_blocks_per_seq * page_block_size;
if (!force_mma && seqlen_k_est > 4096) {
static bool warned = false;
if (!warned) {
std::fprintf(stderr,
"[infiniop][paged_attention_prefill] warpcta8mma is experimental and very slow for long seqlen_k (est=%zu). "
"Falling back to warpcta8pipe. Set INFINIOP_FLASH_PREFILL_MMA_FORCE=1 to override.\n",
seqlen_k_est);
warned = true;
}
return launch_prefill_warpcta8pipe<Tindex, Tdata>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, stream);
}
// WMMA requires SM70+. If not supported (or if we can't query), fall back to the pipelined SIMT kernel.
int device = 0;
hcDeviceProp_t prop{};
if (hcGetDevice(&device) == hcSuccess && hcGetDeviceProperties(&prop, device) == hcSuccess) {
if (prop.major < 7) {
return launch_prefill_warpcta8pipe<Tindex, Tdata>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, stream);
}
}
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(16))));
PagedAttentionPrefillHd128WarpCta8Mma<Tindex>
<<<grid, block, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8pipe_splitkv(
float *partial_acc,
float *partial_m,
float *partial_l,
int num_splits,
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
constexpr int kMaxSplits = 8;
if (num_splits < 1) {
num_splits = 1;
}
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const size_t num_m_blocks = ceilDiv(total_q_tokens, static_cast<size_t>(kWarps));
// Single kernel launch with split_idx encoded in grid.z:
// blockIdx.z in [0, num_splits * num_m_blocks).
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(num_m_blocks * static_cast<size_t>(num_splits)));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta8PipeSplitKv<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
partial_acc, partial_m, partial_l, num_splits, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
break;
case 128:
PagedAttentionPrefillHd128WarpCta8PipeSplitKv<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
partial_acc, partial_m, partial_l, num_splits, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
break;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// Combine: one warp per (token, head).
const dim3 block2(32);
const dim3 grid2(static_cast<uint32_t>(num_heads), static_cast<uint32_t>(total_q_tokens), 1);
switch (head_size) {
case 64:
PagedAttentionPrefillHd64SplitKvCombine<Tdata>
<<<grid2, block2, 0, stream>>>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128SplitKvCombine<Tdata>
<<<grid2, block2, 0, stream>>>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8n128(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
// Only meaningful for head_dim=128.
if (head_size != 128) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
PagedAttentionPrefillHd128WarpCta8N128<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
hcStream_t stream) {
constexpr int kWarps = 16;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta16<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta16<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
} // namespace
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t total_kv_lens_desc,
infiniopTensorDescriptor_t cum_seqlens_q_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto info = PagedAttentionPrefillInfo::create(
out_desc, q_desc, k_cache_desc, v_cache_desc,
block_tables_desc, total_kv_lens_desc, cum_seqlens_q_desc,
alibi_slopes_desc, scale);
CHECK_RESULT(info);
// Optional split-kv prefill requires workspace for partial (m, l, acc).
// IMPORTANT: Unlike decode, prefill's total_q_tokens can be very large, so we must NOT reserve
// a huge workspace unless the user explicitly enables split-kv.
bool use_splitkv = false;
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
}
int num_splits = 1;
if (use_splitkv) {
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) {
const int v = std::atoi(env);
if (v > 0) {
num_splits = v;
}
} else {
num_splits = 4;
}
constexpr int kMaxSplits = 8;
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
}
const size_t n = info->total_q_tokens * info->num_heads;
const size_t splitkv_workspace_bytes = use_splitkv ? (static_cast<size_t>(num_splits) * n * (info->head_size + 2) * sizeof(float)) : 0;
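// Workspace layout when split-kv is enabled (n = total_q_tokens * num_heads):
//   partial_acc : num_splits * n * head_size floats
//   partial_m   : num_splits * n floats
//   partial_l   : num_splits * n floats
// Hypothetical example: total_q_tokens=1024, num_heads=32, head_size=128, num_splits=4
//   => 4 * 32768 * 130 * 4 B ~= 65 MiB, which is why this workspace is only reserved
//      when INFINIOP_FLASH_PREFILL_SPLITKV is explicitly enabled.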
const size_t workspace_bytes = splitkv_workspace_bytes;
// const size_t workspace_bytes = splitkv_workspace_bytes + fa2_workspace_bytes;
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
info.take(), workspace_bytes, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables,
const void *total_kv_lens,
const void *cum_seqlens_q,
const void *alibi_slopes,
void *stream_) const {
auto stream = static_cast<hcStream_t>(stream_);
const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast<const float *>(alibi_slopes);
const auto *total_kv_lens_i64 = static_cast<const int64_t *>(total_kv_lens);
const auto *cu_seqlens_q_i64 = static_cast<const int64_t *>(cum_seqlens_q);
bool use_splitkv = false;
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
}
int num_splits = 1;
if (use_splitkv) {
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) {
const int v = std::atoi(env);
if (v > 0) {
num_splits = v;
}
} else {
// Conservative default; users can override.
num_splits = 4;
}
constexpr int kMaxSplits = 8;
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
const size_t n = _info.total_q_tokens * _info.num_heads;
const size_t required = static_cast<size_t>(num_splits) * n * (_info.head_size + 2) * sizeof(float);
if (workspace_size < required) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
}
if (use_splitkv) {
const size_t n = _info.total_q_tokens * _info.num_heads;
float *partial_acc = static_cast<float *>(workspace);
float *partial_m = partial_acc + static_cast<size_t>(num_splits) * n * _info.head_size;
float *partial_l = partial_m + static_cast<size_t>(num_splits) * n;
// Dispatch by (Tdata, Tindex). total_kv_lens + cu_seqlens_q are currently always int64.
#define DISPATCH_SPLITKV(Tindex, Tdata, BT_PTR) \
return launch_prefill_warpcta8pipe_splitkv<Tindex, Tdata>( \
partial_acc, partial_m, partial_l, num_splits, \
static_cast<Tdata *>(out), \
static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), \
static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(BT_PTR), \
total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream)
if (_info.dtype == INFINI_DTYPE_F16) {
if (_info.index_dtype == INFINI_DTYPE_I64) {
DISPATCH_SPLITKV(int64_t, half, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_I32) {
DISPATCH_SPLITKV(int32_t, half, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_U32) {
DISPATCH_SPLITKV(uint32_t, half, block_tables);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (_info.dtype == INFINI_DTYPE_BF16) {
if (_info.index_dtype == INFINI_DTYPE_I64) {
DISPATCH_SPLITKV(int64_t, __nv_bfloat16, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_I32) {
DISPATCH_SPLITKV(int32_t, __nv_bfloat16, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_U32) {
DISPATCH_SPLITKV(uint32_t, __nv_bfloat16, block_tables);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
#undef DISPATCH_SPLITKV
}
// Default to the fastest validated kernel for supported shapes.
// "ref" is still available for debugging/correctness bisecting.
#define DISPATCH_KERNEL(Tindex, Tdata, Tcompute) \
do { \
const char *k_env = std::getenv("INFINIOP_FLASH_PREFILL_KERNEL"); \
const char *k = (k_env == nullptr) ? default_prefill_kernel(_info) : k_env; \
if (k_env != nullptr) { \
const bool known = (std::strcmp(k, "warp") == 0) || (std::strcmp(k, "warpcta") == 0) || (std::strcmp(k, "warpcta8") == 0) || (std::strcmp(k, "warpcta8pipe") == 0) || (std::strcmp(k, "warpcta8mma") == 0) || (std::strcmp(k, "warpcta8n128") == 0) || (std::strcmp(k, "warpcta16") == 0) || (std::strcmp(k, "ref") == 0); \
if (!known) { \
const char *fallback = default_prefill_kernel(_info); \
std::fprintf(stderr, \
"[infiniop][paged_attention_prefill] WARNING: unknown kernel '%s', falling back to '%s'\n", \
k, fallback); \
k = fallback; \
} \
} \
const char *dbg = std::getenv("INFINIOP_DEBUG_PREFILL_DISPATCH"); \
static bool printed_dispatch = false; \
if (!printed_dispatch && dbg != nullptr && std::strcmp(dbg, "1") == 0) { \
std::fprintf(stderr, \
"[infiniop][paged_attention_prefill] kernel=%s (override=%s head_size=%zu block=%zu dtype=%zu)\n", \
k, \
(k_env == nullptr ? "auto" : "env"), \
static_cast<size_t>(_info.head_size), \
static_cast<size_t>(_info.page_block_size), \
static_cast<size_t>(_info.dtype)); \
printed_dispatch = true; \
} \
if (std::strcmp(k, "warp") == 0) { \
return launch_prefill_warp<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta") == 0) { \
return launch_prefill<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta8") == 0) { \
return launch_prefill_warpcta8<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta8pipe") == 0) { \
return launch_prefill_warpcta8pipe<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if constexpr (std::is_same_v<Tdata, half>) { \
if (std::strcmp(k, "warpcta8mma") == 0) { \
return launch_prefill_warpcta8mma<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
} \
if (std::strcmp(k, "warpcta8n128") == 0) { \
return launch_prefill_warpcta8n128<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta16") == 0) { \
return launch_prefill_warpcta16<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "ref") == 0) { \
return launch_prefill_ref<Tindex, Tdata, Tcompute>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
return INFINI_STATUS_BAD_PARAM; \
} while (false)
#define DISPATCH_INDEX(Tindex) \
do { \
if (_info.dtype == INFINI_DTYPE_F16) { \
DISPATCH_KERNEL(Tindex, half, float); \
} \
if (_info.dtype == INFINI_DTYPE_BF16) { \
DISPATCH_KERNEL(Tindex, __nv_bfloat16, float); \
} \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
} while (false)
if (_info.index_dtype == INFINI_DTYPE_I64) {
DISPATCH_INDEX(int64_t);
} else if (_info.index_dtype == INFINI_DTYPE_I32) {
DISPATCH_INDEX(int32_t);
} else if (_info.index_dtype == INFINI_DTYPE_U32) {
DISPATCH_INDEX(uint32_t);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::paged_attention_prefill::metax
src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_kernel.h (new file, 0 → 100644)
#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
#define __PAGED_ATTENTION_PREFILL_KERNEL_CUH__

#include <cfloat>  // FLT_MAX
#include <cstddef> // size_t, ptrdiff_t
#include <cstdint> // int64_t

namespace op::paged_attention_prefill::cuda {
__device__ __forceinline__ size_t find_seq_id(
    size_t token_idx,
    const int64_t *cum_seq_lens_q,
    size_t num_seqs) {
    size_t low = 0, high = num_seqs - 1;
    while (low <= high) {
        size_t mid = (low + high) >> 1;
        if (token_idx >= (size_t)cum_seq_lens_q[mid] && token_idx < (size_t)cum_seq_lens_q[mid + 1]) {
            return mid;
        } else if (token_idx < (size_t)cum_seq_lens_q[mid]) {
            high = mid - 1;
        } else {
            low = mid + 1;
        }
    }
    return 0;
}
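// Example: for two sequences with query lengths 3 and 5, cum_seq_lens_q = [0, 3, 8];
// find_seq_id(4, cum_seq_lens_q, 2) returns 1, since token 4 lies in
// [cum_seq_lens_q[1], cum_seq_lens_q[2]).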
// Warp-level sum reduction with an explicit active mask (safe for partial warps).
__device__ __forceinline__ float warpReduceSum(float val, unsigned mask) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        val += __shfl_down_sync(mask, val, offset);
    }
    return val;
}
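// Example: with mask = 0xffffffffu and every lane contributing val = 1.0f, the offsets
// 16, 8, 4, 2, 1 fold the 32 partial sums down so that lane 0 ends up holding 32.0f
// (with __shfl_down_sync, only lane 0 is guaranteed to hold the complete sum).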
// Block-level sum reduction. Returns the sum to all threads in the block.
// Supports blockDim.x up to 1024.
__device__ __forceinline__ float blockReduceSum(float val) {
    __shared__ float shared[32]; // max 32 warps per block
    const int lane = threadIdx.x & 31;
    const int wid = threadIdx.x >> 5;
    const unsigned mask = __activemask();
    val = warpReduceSum(val, mask);
    if (lane == 0) {
        shared[wid] = val;
    }
    __syncthreads();
    const int num_warps = (blockDim.x + 31) >> 5;
    float sum = 0.0f;
    if (wid == 0) {
        sum = (lane < num_warps) ? shared[lane] : 0.0f;
        const unsigned mask0 = (num_warps >= 32) ? 0xffffffffu : ((1u << num_warps) - 1u);
        sum = warpReduceSum(sum, mask0);
        if (lane == 0) {
            shared[0] = sum;
        }
    }
    __syncthreads();
    return shared[0];
}
template
<
typename
Tdata
,
typename
Tcompute
>
__global__
void
pagedAttentionPrefillKernel
(
Tdata
*
out_
,
const
Tdata
*
q_
,
const
Tdata
*
k_cache_
,
const
Tdata
*
v_cache_
,
const
int64_t
*
block_tables_
,
const
int64_t
*
total_kv_lens_
,
const
int64_t
*
cum_seq_lens_q_
,
const
float
*
alibi_slopes_
,
const
size_t
num_heads
,
const
size_t
num_kv_heads
,
const
float
scale
,
const
size_t
max_num_blocks_per_seq
,
const
size_t
block_size
,
const
ptrdiff_t
kv_block_stride
,
const
ptrdiff_t
kv_head_stride
,
const
ptrdiff_t
q_stride
,
const
ptrdiff_t
q_head_stride
,
const
size_t
head_size
,
const
size_t
num_seqs
)
{
// Grid : x -> token, y -> head
const
size_t
global_token_idx
=
blockIdx
.
x
;
const
size_t
head_idx
=
blockIdx
.
y
;
const
size_t
dim_idx
=
threadIdx
.
x
;
if
(
dim_idx
>=
head_size
)
{
return
;
}
__shared__
size_t
sh_seq_idx
;
__shared__
size_t
sh_causal_limit
;
__shared__
size_t
sh_kv_head_idx
;
__shared__
float
sh_scale_acc
;
__shared__
float
sh_w
;
__shared__
float
sh_inv_l
;
if
(
dim_idx
==
0
)
{
sh_seq_idx
=
find_seq_id
(
global_token_idx
,
cum_seq_lens_q_
,
num_seqs
);
const
size_t
q_token_idx
=
global_token_idx
-
static_cast
<
size_t
>
(
cum_seq_lens_q_
[
sh_seq_idx
]);
const
size_t
total_kv_len
=
static_cast
<
size_t
>
(
total_kv_lens_
[
sh_seq_idx
]);
const
size_t
q_len
=
static_cast
<
size_t
>
(
cum_seq_lens_q_
[
sh_seq_idx
+
1
]
-
cum_seq_lens_q_
[
sh_seq_idx
]);
const
size_t
history_len
=
total_kv_len
-
q_len
;
sh_causal_limit
=
history_len
+
q_token_idx
;
const
size_t
num_queries_per_kv
=
num_heads
/
num_kv_heads
;
sh_kv_head_idx
=
head_idx
/
num_queries_per_kv
;
}
__syncthreads
();
const
size_t
seq_idx
=
sh_seq_idx
;
const
size_t
causal_limit
=
sh_causal_limit
;
const
size_t
kv_head_idx
=
sh_kv_head_idx
;
const
Tdata
*
q_vec
=
q_
+
global_token_idx
*
q_stride
+
head_idx
*
q_head_stride
;
Tdata
*
out_ptr
=
out_
+
global_token_idx
*
num_heads
*
head_size
+
head_idx
*
head_size
;
const
int64_t
*
block_table
=
block_tables_
+
seq_idx
*
max_num_blocks_per_seq
;
const
float
alibi_slope
=
(
alibi_slopes_
==
nullptr
)
?
0.0
f
:
alibi_slopes_
[
head_idx
];
const
float
qv
=
static_cast
<
float
>
(
q_vec
[
dim_idx
]);
Tcompute
acc
=
0.0
f
;
float
m
=
-
FLT_MAX
;
float
l
=
0.0
f
;
for
(
size_t
t
=
0
;
t
<=
causal_limit
;
++
t
)
{
const
size_t
b_idx
=
t
/
block_size
;
const
size_t
t_off
=
t
%
block_size
;
const
ptrdiff_t
physical_block_id
=
block_table
[
b_idx
];
const
Tdata
*
k_vec
=
k_cache_
+
physical_block_id
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
+
t_off
*
head_size
;
const
float
dot
=
blockReduceSum
(
qv
*
static_cast
<
float
>
(
k_vec
[
dim_idx
]));
if
(
dim_idx
==
0
)
{
float
score
=
dot
*
static_cast
<
float
>
(
scale
);
if
(
alibi_slope
!=
0.0
f
)
{
score
+=
alibi_slope
*
static_cast
<
float
>
(
t
-
causal_limit
);
}
const
float
m_new
=
fmaxf
(
m
,
score
);
const
float
scale_acc
=
expf
(
m
-
m_new
);
const
float
w
=
expf
(
score
-
m_new
);
l
=
l
*
scale_acc
+
w
;
m
=
m_new
;
sh_scale_acc
=
scale_acc
;
sh_w
=
w
;
}
__syncthreads
();
const
float
scale_acc
=
sh_scale_acc
;
const
float
w
=
sh_w
;
const
Tdata
*
v_vec
=
v_cache_
+
physical_block_id
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
+
t_off
*
head_size
;
acc
=
acc
*
static_cast
<
Tcompute
>
(
scale_acc
)
+
static_cast
<
Tcompute
>
(
w
)
*
static_cast
<
Tcompute
>
(
v_vec
[
dim_idx
]);
__syncthreads
();
}
if
(
dim_idx
==
0
)
{
sh_inv_l
=
1.0
f
/
(
l
+
1e-6
f
);
}
__syncthreads
();
out_ptr
[
dim_idx
]
=
static_cast
<
Tdata
>
(
acc
*
static_cast
<
Tcompute
>
(
sh_inv_l
));
}
}
// namespace op::paged_attention_prefill::cuda
#endif
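The kernel above leans entirely on two varlen arrays: cum_seq_lens_q (prefix sums of per-sequence query lengths), used to map a packed token index back to its sequence, and total_kv_lens, used to derive the causal limit history_len + q_token_idx. The following is a minimal host-side sketch of that bookkeeping; the struct and function names are illustrative only and not part of the operator API.

#include <cstddef>
#include <cstdint>
#include <vector>

// Host-side mirror of the device-side bookkeeping in pagedAttentionPrefillKernel.
// cum_seq_lens_q[i] <= token_idx < cum_seq_lens_q[i + 1] identifies sequence i.
struct TokenCoords {
    size_t seq_idx;      // which sequence the packed token belongs to
    size_t q_token_idx;  // position of the token inside that sequence's query
    size_t causal_limit; // last KV position this token may attend to (inclusive)
};

inline TokenCoords locate_token(size_t token_idx,
                                const std::vector<int64_t> &cum_seq_lens_q,
                                const std::vector<int64_t> &total_kv_lens) {
    size_t seq = 0;
    while (token_idx >= static_cast<size_t>(cum_seq_lens_q[seq + 1])) {
        ++seq; // linear scan for clarity; the kernel uses a binary search (find_seq_id)
    }
    const size_t q_token_idx = token_idx - static_cast<size_t>(cum_seq_lens_q[seq]);
    const size_t q_len = static_cast<size_t>(cum_seq_lens_q[seq + 1] - cum_seq_lens_q[seq]);
    const size_t history_len = static_cast<size_t>(total_kv_lens[seq]) - q_len;
    return {seq, q_token_idx, history_len + q_token_idx};
}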
src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_moore.h
0 → 100644
#ifndef __PAGED_ATTENTION_PREFILL_MOORE_H__
#define __PAGED_ATTENTION_PREFILL_MOORE_H__
#include "../paged_attention_prefill.h"
DESCRIPTOR(moore)
#endif // __PAGED_ATTENTION_PREFILL_MOORE_H__
src/infiniop/ops/paged_attention_prefill/moore/paged_attention_prefill_moore.mu
0 → 100644
#include <musa_fp16.h>
#include <float.h>
#include <math.h>
#include <stdint.h>
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_kernel_common.h"
#include "paged_attention_prefill_kernel.h"
#include "paged_attention_prefill_moore.h"
template <typename Tdata, typename Tcompute>
infiniStatus_t launchPagedAttentionPrefill(
Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
const int64_t *block_tables,
const int64_t *seq_lens,
const int64_t *cum_seq_lens_q,
const float *alibi_slopes,
const size_t num_heads,
const size_t num_seqs,
const size_t num_kv_heads,
const float scale,
const size_t max_num_blocks_per_seq,
const size_t page_block_size,
const size_t total_q_tokens,
const size_t head_size,
const ptrdiff_t k_batch_stride,
const ptrdiff_t k_head_stride,
const ptrdiff_t q_stride,
const ptrdiff_t q_head_stride,
musaStream_t stream) {
if (total_q_tokens == 0 || num_heads == 0) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
dim3 grid(total_q_tokens, num_heads);
dim3 block(head_size);
op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tdata, Tcompute>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache,
block_tables, seq_lens, cum_seq_lens_q, alibi_slopes,
num_heads, num_kv_heads, scale,
max_num_blocks_per_seq, page_block_size,
k_batch_stride, k_head_stride,
q_stride, q_head_stride,
head_size,
num_seqs);
return INFINI_STATUS_SUCCESS;
}
namespace op::paged_attention_prefill::moore {
struct Descriptor::Opaque {
std::shared_ptr<device::moore::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t cum_seq_lens_q_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto info = PagedAttentionPrefillInfo::create(
out_desc, q_desc, k_cache_desc, v_cache_desc,
block_tables_desc, seq_lens_desc,
cum_seq_lens_q_desc,
alibi_slopes_desc, scale);
CHECK_RESULT(info);
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables,
const void *seq_lens,
const void *cum_seq_lens_q,
const void *alibi_slopes,
void *stream_) const {
musaStream_t stream = (musaStream_t)stream_;
#define LAUNCH_KERNEL(Tdata, Tcompute) \
launchPagedAttentionPrefill<Tdata, Tcompute>( \
(Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q, \
(const float *)alibi_slopes, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, \
_info.scale, _info.max_num_blocks_per_seq, \
_info.page_block_size, _info.total_q_tokens, \
_info.head_size, \
_info.k_batch_stride, _info.k_head_stride, \
_info.q_stride, _info.q_head_stride, \
stream)
if (_info.dtype == INFINI_DTYPE_F16) {
return LAUNCH_KERNEL(half, float);
} else if (_info.dtype == INFINI_DTYPE_BF16) {
return LAUNCH_KERNEL(__mt_bfloat16, float);
} else if (_info.dtype == INFINI_DTYPE_F32) {
return LAUNCH_KERNEL(float, float);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::paged_attention_prefill::moore
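For a concrete feel of the launch geometry (one block per packed query token and head, one thread per head-dimension element) and the GQA head mapping used by launchPagedAttentionPrefill and pagedAttentionPrefillKernel, here is a small standalone example. The shapes are made up purely for illustration.

#include <cstddef>
#include <cstdio>

int main() {
    const size_t total_q_tokens = 7;  // packed query tokens across all sequences
    const size_t num_heads = 32;      // query heads
    const size_t num_kv_heads = 8;    // shared KV heads (GQA)
    const size_t head_size = 128;

    // grid = (total_q_tokens, num_heads), block = (head_size).
    const size_t blocks = total_q_tokens * num_heads;  // 224 thread blocks
    const size_t threads_per_block = head_size;        // 128 threads each

    // Each query head reads the KV head of its group: kv_head = head / (num_heads / num_kv_heads).
    const size_t num_queries_per_kv = num_heads / num_kv_heads; // 4
    const size_t kv_head_for_head_13 = 13 / num_queries_per_kv; // query head 13 -> KV head 3

    std::printf("blocks=%zu threads=%zu kv_head(13)=%zu\n", blocks, threads_per_block, kv_head_for_head_13);
    return 0;
}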
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu
0 → 100644
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
// #include "paged_attention_prefill_fa2.cuh"
#include "paged_attention_prefill_nvidia.cuh"
#include "../cuda/kernel_v2.cuh"
namespace op::paged_attention_prefill::nvidia {
namespace {

constexpr size_t ceilDiv(size_t a, size_t b) {
    return (a + b - 1) / b;
}

inline const char *default_prefill_kernel(const PagedAttentionPrefillInfo &info) {
    // Heuristic auto-dispatch (v0.4):
    // - Prefer the pipelined + tile-wise softmax kernel on FA2-compatible block_size=256.
    // - Keep a conservative fallback for other shapes / older GPUs (cp.async is a no-op below SM80).
    //
    // Users can always override via INFINIOP_FLASH_PREFILL_KERNEL.
    if (info.page_block_size == 256 && (info.dtype == INFINI_DTYPE_F16 || info.dtype == INFINI_DTYPE_BF16)) {
        if (info.head_size == 128) {
            return "warpcta8pipe";
        }
        // For head_size=64 we keep the previous default until we have broader perf coverage.
    }
    return "warpcta8";
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128Warp(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride, ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    // Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel).
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel<Tindex, Tdata, 128>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}

template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64Warp(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride, ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    // Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel).
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel<Tindex, Tdata, 64>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}

template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride, ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    // 4 warps per CTA, one warp per query token.
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 4, 64>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}

template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride, ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    // 4 warps per CTA, one warp per query token.
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 4, 128>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride, ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    // 8 warps per CTA, one warp per query token.
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 8, 64>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}

template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8N128(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride, ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    // 8 warps per CTA, one warp per query token, tile_n=128 for fewer K stages.
    // Note: we keep K in shared memory but load V from global to stay within the per-block shared limit.
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelKOnly<Tindex, Tdata, 128, 8, 128>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}

template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride, ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    // 8 warps per CTA, one warp per query token.
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 8, 128>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8Pipe(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride, ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    // 8 warps per CTA, one warp per query token, with cp.async pipelining.
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined<Tindex, Tdata, 128, 8, 32, 2>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}

template <typename Tindex>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8Mma(
    half *out, const half *q, const half *k_cache, const half *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride, ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCta8MmaHd128Kernel<Tindex>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8Pipe(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride, ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    // 8 warps per CTA, one warp per query token, with cp.async pipelining.
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined<Tindex, Tdata, 64, 8, 32, 2>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8PipeSplitKv(
    float *partial_acc, float *partial_m, float *partial_l, int num_splits, size_t total_q_tokens,
    const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride, ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride) {
    // Encode (split_idx, m_block) into blockIdx.z to allow a single kernel launch:
    // blockIdx.z in [0, num_splits * num_m_blocks).
    const int num_m_blocks = static_cast<int>((total_q_tokens + 8 - 1) / 8);
    const int bz = static_cast<int>(blockIdx.z);
    const int split_idx = bz / num_m_blocks;
    const int m_block = bz - split_idx * num_m_blocks;
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv<Tindex, Tdata, 128, 8, 32, 2>(
        partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens,
        q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride);
}

template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8PipeSplitKv(
    float *partial_acc, float *partial_m, float *partial_l, int num_splits, size_t total_q_tokens,
    const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride, ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride) {
    const int num_m_blocks = static_cast<int>((total_q_tokens + 8 - 1) / 8);
    const int bz = static_cast<int>(blockIdx.z);
    const int split_idx = bz / num_m_blocks;
    const int m_block = bz - split_idx * num_m_blocks;
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv<Tindex, Tdata, 64, 8, 32, 2>(
        partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens,
        q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride);
}
template <typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128SplitKvCombine(
    Tdata *out, const float *partial_acc, const float *partial_m, const float *partial_l,
    int num_splits, size_t total_q_tokens, ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel<Tdata, 128>(
        out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
}

template <typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64SplitKvCombine(
    Tdata *out, const float *partial_acc, const float *partial_m, const float *partial_l,
    int num_splits, size_t total_q_tokens, ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel<Tdata, 64>(
        out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta16(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride, ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    // 16 warps per CTA, one warp per query token.
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 16, 64>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}

template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta16(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_kv_heads, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride, ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride) {
    // 16 warps per CTA, one warp per query token.
    op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 16, 128>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata, typename Tcompute>
infiniStatus_t launch_prefill_ref(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_heads, size_t num_seqs, size_t num_kv_heads,
    size_t total_q_tokens, size_t head_size, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride, ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride, cudaStream_t stream) {
    const dim3 grid(static_cast<uint32_t>(total_q_tokens), static_cast<uint32_t>(num_heads), 1);
    const dim3 block(static_cast<uint32_t>(head_size), 1, 1);
    if (head_size == 64) {
        op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel<Tindex, Tdata, Tcompute, 64>
            <<<grid, block, 0, stream>>>(
                out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
                num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                block_table_batch_stride, q_stride, q_head_stride,
                k_batch_stride, k_row_stride, k_head_stride,
                v_batch_stride, v_row_stride, v_head_stride,
                o_stride, o_head_stride, num_seqs);
        return INFINI_STATUS_SUCCESS;
    }
    if (head_size == 128) {
        op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel<Tindex, Tdata, Tcompute, 128>
            <<<grid, block, 0, stream>>>(
                out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
                num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                block_table_batch_stride, q_stride, q_head_stride,
                k_batch_stride, k_row_stride, k_head_stride,
                v_batch_stride, v_row_stride, v_head_stride,
                o_stride, o_head_stride, num_seqs);
        return INFINI_STATUS_SUCCESS;
    }
    return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warp(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_heads, size_t num_seqs, size_t num_kv_heads,
    size_t total_q_tokens, size_t head_size, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride, ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride, cudaStream_t stream) {
    const dim3 block(32, 1, 1);
    // Global-token launch:
    // - dramatically reduces grid size vs the legacy (num_seqs * total_q_tokens) launch
    // - matches PagedAttention varlen (cu_seqlens) mental model better
    const dim3 grid(static_cast<uint32_t>(num_heads), static_cast<uint32_t>(total_q_tokens), 1);
    switch (head_size) {
    case 64:
        op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel<Tindex, Tdata, 64>
            <<<grid, block, 0, stream>>>(
                out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
                num_heads, num_seqs, num_kv_heads, total_q_tokens, scale,
                max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
                q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
                v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    case 128:
        op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel<Tindex, Tdata, 128>
            <<<grid, block, 0, stream>>>(
                out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
                num_heads, num_seqs, num_kv_heads, total_q_tokens, scale,
                max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
                q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
                v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    default:
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_heads, size_t num_seqs, size_t num_kv_heads,
    size_t total_q_tokens, size_t head_size, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride, ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride, cudaStream_t stream) {
    constexpr int kWarps = 4;
    constexpr int kThreads = kWarps * 32;
    const dim3 block(kThreads);
    const dim3 grid(static_cast<uint32_t>(num_heads), static_cast<uint32_t>(num_seqs),
                    static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
    switch (head_size) {
    case 64:
        PagedAttentionPrefillHd64WarpCta<Tindex, Tdata><<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    case 128:
        PagedAttentionPrefillHd128WarpCta<Tindex, Tdata><<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    default:
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_heads, size_t num_seqs, size_t num_kv_heads,
    size_t total_q_tokens, size_t head_size, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride, ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride, cudaStream_t stream) {
    constexpr int kWarps = 8;
    constexpr int kThreads = kWarps * 32;
    const dim3 block(kThreads);
    const dim3 grid(static_cast<uint32_t>(num_heads), static_cast<uint32_t>(num_seqs),
                    static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
    switch (head_size) {
    case 64:
        PagedAttentionPrefillHd64WarpCta8<Tindex, Tdata><<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    case 128:
        PagedAttentionPrefillHd128WarpCta8<Tindex, Tdata><<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    default:
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8pipe(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_heads, size_t num_seqs, size_t num_kv_heads,
    size_t total_q_tokens, size_t head_size, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride, ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride, cudaStream_t stream) {
    constexpr int kWarps = 8;
    constexpr int kThreads = kWarps * 32;
    const dim3 block(kThreads);
    const dim3 grid(static_cast<uint32_t>(num_heads), static_cast<uint32_t>(num_seqs),
                    static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
    switch (head_size) {
    case 64:
        PagedAttentionPrefillHd64WarpCta8Pipe<Tindex, Tdata><<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    case 128:
        PagedAttentionPrefillHd128WarpCta8Pipe<Tindex, Tdata><<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    default:
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8mma(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_heads, size_t num_seqs, size_t num_kv_heads,
    size_t total_q_tokens, size_t head_size, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride, ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride, cudaStream_t stream) {
    // Current WMMA kernel only supports fp16 + head_dim=128.
    if constexpr (!std::is_same_v<Tdata, half>) {
        return launch_prefill_warpcta8pipe<Tindex, Tdata>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
            max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride, stream);
    }
    if (head_size != 128) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
    // Guardrail: the current WMMA-score kernel is correctness-first and can be extremely slow on long prompts.
    // Allow power users to force it via INFINIOP_FLASH_PREFILL_MMA_FORCE=1.
    const char *force_env = std::getenv("INFINIOP_FLASH_PREFILL_MMA_FORCE");
    const bool force_mma = (force_env != nullptr) && (std::strcmp(force_env, "1") == 0);
    const size_t seqlen_k_est = max_num_blocks_per_seq * page_block_size;
    if (!force_mma && seqlen_k_est > 4096) {
        static bool warned = false;
        if (!warned) {
            std::fprintf(stderr,
                         "[infiniop][paged_attention_prefill] warpcta8mma is experimental and very slow for long seqlen_k (est=%zu). "
                         "Falling back to warpcta8pipe. Set INFINIOP_FLASH_PREFILL_MMA_FORCE=1 to override.\n",
                         seqlen_k_est);
            warned = true;
        }
        return launch_prefill_warpcta8pipe<Tindex, Tdata>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
            max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride, stream);
    }
    // WMMA requires SM70+. If not supported (or if we can't query), fall back to the pipelined SIMT kernel.
    int device = 0;
    cudaDeviceProp prop{};
    if (cudaGetDevice(&device) == cudaSuccess && cudaGetDeviceProperties(&prop, device) == cudaSuccess) {
        if (prop.major < 7) {
            return launch_prefill_warpcta8pipe<Tindex, Tdata>(
                out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
                num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
                max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
                q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
                v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride, stream);
        }
    }
    constexpr int kWarps = 8;
    constexpr int kThreads = kWarps * 32;
    const dim3 block(kThreads);
    const dim3 grid(static_cast<uint32_t>(num_heads), static_cast<uint32_t>(num_seqs),
                    static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(16))));
    PagedAttentionPrefillHd128WarpCta8Mma<Tindex><<<grid, block, 0, stream>>>(
        static_cast<half *>(out), static_cast<const half *>(q),
        static_cast<const half *>(k_cache), static_cast<const half *>(v_cache),
        block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
    return INFINI_STATUS_SUCCESS;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8pipe_splitkv(
    float *partial_acc, float *partial_m, float *partial_l, int num_splits,
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_heads, size_t num_seqs, size_t num_kv_heads,
    size_t total_q_tokens, size_t head_size, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride, ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride, cudaStream_t stream) {
    constexpr int kMaxSplits = 8;
    if (num_splits < 1) {
        num_splits = 1;
    }
    if (num_splits > kMaxSplits) {
        num_splits = kMaxSplits;
    }
    constexpr int kWarps = 8;
    constexpr int kThreads = kWarps * 32;
    const dim3 block(kThreads);
    const size_t num_m_blocks = ceilDiv(total_q_tokens, static_cast<size_t>(kWarps));
    // Single kernel launch with split_idx encoded in grid.z:
    // blockIdx.z in [0, num_splits * num_m_blocks).
    const dim3 grid(static_cast<uint32_t>(num_heads), static_cast<uint32_t>(num_seqs),
                    static_cast<uint32_t>(num_m_blocks * static_cast<size_t>(num_splits)));
    switch (head_size) {
    case 64:
        PagedAttentionPrefillHd64WarpCta8PipeSplitKv<Tindex, Tdata><<<grid, block, 0, stream>>>(
            partial_acc, partial_m, partial_l, num_splits, total_q_tokens,
            q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride);
        break;
    case 128:
        PagedAttentionPrefillHd128WarpCta8PipeSplitKv<Tindex, Tdata><<<grid, block, 0, stream>>>(
            partial_acc, partial_m, partial_l, num_splits, total_q_tokens,
            q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride);
        break;
    default:
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
    // Combine: one warp per (token, head).
    const dim3 block2(32);
    const dim3 grid2(static_cast<uint32_t>(num_heads), static_cast<uint32_t>(total_q_tokens), 1);
    switch (head_size) {
    case 64:
        PagedAttentionPrefillHd64SplitKvCombine<Tdata><<<grid2, block2, 0, stream>>>(
            out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    case 128:
        PagedAttentionPrefillHd128SplitKvCombine<Tdata><<<grid2, block2, 0, stream>>>(
            out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    default:
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8n128(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_heads, size_t num_seqs, size_t num_kv_heads,
    size_t total_q_tokens, size_t head_size, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride, ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride, cudaStream_t stream) {
    constexpr int kWarps = 8;
    constexpr int kThreads = kWarps * 32;
    const dim3 block(kThreads);
    const dim3 grid(static_cast<uint32_t>(num_heads), static_cast<uint32_t>(num_seqs),
                    static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
    // Only meaningful for head_dim=128.
    if (head_size != 128) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
    PagedAttentionPrefillHd128WarpCta8N128<Tindex, Tdata><<<grid, block, 0, stream>>>(
        out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
        num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
        q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
        v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
    return INFINI_STATUS_SUCCESS;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta16(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const Tindex *block_tables, const int64_t *total_kv_lens, const int64_t *cu_seqlens_q,
    const float *alibi_slopes, size_t num_heads, size_t num_seqs, size_t num_kv_heads,
    size_t total_q_tokens, size_t head_size, float scale,
    size_t max_num_blocks_per_seq, size_t page_block_size, ptrdiff_t block_table_batch_stride,
    ptrdiff_t q_stride, ptrdiff_t q_head_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride,
    ptrdiff_t k_head_stride, ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride,
    ptrdiff_t o_stride, ptrdiff_t o_head_stride, cudaStream_t stream) {
    constexpr int kWarps = 16;
    constexpr int kThreads = kWarps * 32;
    const dim3 block(kThreads);
    const dim3 grid(static_cast<uint32_t>(num_heads), static_cast<uint32_t>(num_seqs),
                    static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
    switch (head_size) {
    case 64:
        PagedAttentionPrefillHd64WarpCta16<Tindex, Tdata><<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    case 128:
        PagedAttentionPrefillHd128WarpCta16<Tindex, Tdata><<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
            num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
            q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
            v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
    default:
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
}

} // namespace
struct Descriptor::Opaque {
    std::shared_ptr<device::nvidia::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    delete _opaque;
}
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t block_tables_desc,
    infiniopTensorDescriptor_t total_kv_lens_desc,
    infiniopTensorDescriptor_t cum_seqlens_q_desc,
    const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
    float scale) {
    auto info = PagedAttentionPrefillInfo::create(
        out_desc, q_desc, k_cache_desc, v_cache_desc,
        block_tables_desc, total_kv_lens_desc, cum_seqlens_q_desc,
        alibi_slopes_desc, scale);
    CHECK_RESULT(info);
    // Optional split-kv prefill requires workspace for partial (m, l, acc).
    // IMPORTANT: Unlike decode, prefill's total_q_tokens can be very large, so we must NOT reserve
    // a huge workspace unless the user explicitly enables split-kv.
    bool use_splitkv = false;
    if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
        use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
    }
    int num_splits = 1;
    if (use_splitkv) {
        if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) {
            const int v = std::atoi(env);
            if (v > 0) {
                num_splits = v;
            }
        } else {
            num_splits = 4;
        }
        constexpr int kMaxSplits = 8;
        if (num_splits > kMaxSplits) {
            num_splits = kMaxSplits;
        }
    }
    const size_t n = info->total_q_tokens * info->num_heads;
    const size_t splitkv_workspace_bytes = use_splitkv
        ? (static_cast<size_t>(num_splits) * n * (info->head_size + 2) * sizeof(float))
        : 0;
    const size_t workspace_bytes = splitkv_workspace_bytes;
    // const size_t workspace_bytes = splitkv_workspace_bytes + fa2_workspace_bytes;
    *desc_ptr = new Descriptor(
        new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
        info.take(), workspace_bytes, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
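To make the split-kv workspace arithmetic concrete: create() reserves num_splits * total_q_tokens * num_heads * (head_size + 2) floats, which calculate() later partitions into partial_acc, partial_m, and partial_l. Below is a small worked example with made-up shapes, for illustration only.

#include <cstddef>
#include <cstdio>

int main() {
    // Made-up shapes, chosen only to illustrate the sizing formula above.
    const size_t num_splits = 4;
    const size_t total_q_tokens = 1024;
    const size_t num_heads = 32;
    const size_t head_size = 128;

    const size_t n = total_q_tokens * num_heads;            // 32768 (token, head) pairs
    const size_t floats = num_splits * n * (head_size + 2); // accumulator + running max + normalizer
    const size_t bytes = floats * sizeof(float);            // 68,157,440 B = 65 MiB

    // Partition offsets (in floats) mirror Descriptor::calculate:
    const size_t acc_floats = num_splits * n * head_size;   // partial_acc occupies [0, acc_floats)
    const size_t m_floats = num_splits * n;                 // partial_m follows partial_acc
    const size_t l_floats = num_splits * n;                 // partial_l follows partial_m

    std::printf("workspace: %zu bytes (%.1f MiB), acc=%zu m=%zu l=%zu floats\n",
                bytes, bytes / (1024.0 * 1024.0), acc_floats, m_floats, l_floats);
    return 0;
}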
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *out, const void *q, const void *k_cache, const void *v_cache,
    const void *block_tables,
    const void *total_kv_lens,
    const void *cum_seqlens_q,
    const void *alibi_slopes,
    void *stream_) const {
    auto stream = static_cast<cudaStream_t>(stream_);
    const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast<const float *>(alibi_slopes);
    const auto *total_kv_lens_i64 = static_cast<const int64_t *>(total_kv_lens);
    const auto *cu_seqlens_q_i64 = static_cast<const int64_t *>(cum_seqlens_q);
    bool use_splitkv = false;
    if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
        use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
    }
    int num_splits = 1;
    if (use_splitkv) {
        if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) {
            const int v = std::atoi(env);
            if (v > 0) {
                num_splits = v;
            }
        } else {
            // Conservative default; users can override.
            num_splits = 4;
        }
        constexpr int kMaxSplits = 8;
        if (num_splits > kMaxSplits) {
            num_splits = kMaxSplits;
        }
        const size_t n = _info.total_q_tokens * _info.num_heads;
        const size_t required = static_cast<size_t>(num_splits) * n * (_info.head_size + 2) * sizeof(float);
        if (workspace_size < required) {
            return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
        }
    }
    if (use_splitkv) {
        const size_t n = _info.total_q_tokens * _info.num_heads;
        float *partial_acc = static_cast<float *>(workspace);
        float *partial_m = partial_acc + static_cast<size_t>(num_splits) * n * _info.head_size;
        float *partial_l = partial_m + static_cast<size_t>(num_splits) * n;
// Dispatch by (Tdata, Tindex). total_kv_lens + cu_seqlens_q are currently always int64.
#define DISPATCH_SPLITKV(Tindex, Tdata, BT_PTR) \
return launch_prefill_warpcta8pipe_splitkv<Tindex, Tdata>( \
partial_acc, partial_m, partial_l, num_splits, \
static_cast<Tdata *>(out), \
static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), \
static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(BT_PTR), \
total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream)
        if (_info.dtype == INFINI_DTYPE_F16) {
            if (_info.index_dtype == INFINI_DTYPE_I64) {
                DISPATCH_SPLITKV(int64_t, half, block_tables);
            }
            if (_info.index_dtype == INFINI_DTYPE_I32) {
                DISPATCH_SPLITKV(int32_t, half, block_tables);
            }
            if (_info.index_dtype == INFINI_DTYPE_U32) {
                DISPATCH_SPLITKV(uint32_t, half, block_tables);
            }
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        if (_info.dtype == INFINI_DTYPE_BF16) {
            if (_info.index_dtype == INFINI_DTYPE_I64) {
                DISPATCH_SPLITKV(int64_t, __nv_bfloat16, block_tables);
            }
            if (_info.index_dtype == INFINI_DTYPE_I32) {
                DISPATCH_SPLITKV(int32_t, __nv_bfloat16, block_tables);
            }
            if (_info.index_dtype == INFINI_DTYPE_U32) {
                DISPATCH_SPLITKV(uint32_t, __nv_bfloat16, block_tables);
            }
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
#undef DISPATCH_SPLITKV
    }
    // Default to the fastest validated kernel for supported shapes.
    // "ref" is still available for debugging/correctness bisecting.
#define DISPATCH_KERNEL(Tindex, Tdata, Tcompute) \
do { \
const char *k_env = std::getenv("INFINIOP_FLASH_PREFILL_KERNEL"); \
const char *k = (k_env == nullptr) ? default_prefill_kernel(_info) : k_env; \
if (k_env != nullptr) { \
const bool known = (std::strcmp(k, "warp") == 0) || (std::strcmp(k, "warpcta") == 0) || (std::strcmp(k, "warpcta8") == 0) || (std::strcmp(k, "warpcta8pipe") == 0) || (std::strcmp(k, "warpcta8mma") == 0) || (std::strcmp(k, "warpcta8n128") == 0) || (std::strcmp(k, "warpcta16") == 0) || (std::strcmp(k, "ref") == 0); \
if (!known) { \
const char *fallback = default_prefill_kernel(_info); \
std::fprintf(stderr, \
"[infiniop][paged_attention_prefill] WARNING: unknown kernel '%s', falling back to '%s'\n", \
k, fallback); \
k = fallback; \
} \
} \
const char *dbg = std::getenv("INFINIOP_DEBUG_PREFILL_DISPATCH"); \
static bool printed_dispatch = false; \
if (!printed_dispatch && dbg != nullptr && std::strcmp(dbg, "1") == 0) { \
std::fprintf(stderr, \
"[infiniop][paged_attention_prefill] kernel=%s (override=%s head_size=%zu block=%zu dtype=%zu)\n", \
k, \
(k_env == nullptr ? "auto" : "env"), \
static_cast<size_t>(_info.head_size), \
static_cast<size_t>(_info.page_block_size), \
static_cast<size_t>(_info.dtype)); \
printed_dispatch = true; \
} \
if (std::strcmp(k, "warp") == 0) { \
return launch_prefill_warp<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta") == 0) { \
return launch_prefill<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta8") == 0) { \
return launch_prefill_warpcta8<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta8pipe") == 0) { \
return launch_prefill_warpcta8pipe<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if constexpr (std::is_same_v<Tdata, half>) { \
if (std::strcmp(k, "warpcta8mma") == 0) { \
return launch_prefill_warpcta8mma<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
} \
if (std::strcmp(k, "warpcta8n128") == 0) { \
return launch_prefill_warpcta8n128<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta16") == 0) { \
return launch_prefill_warpcta16<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "ref") == 0) { \
return launch_prefill_ref<Tindex, Tdata, Tcompute>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
return INFINI_STATUS_BAD_PARAM; \
} while (false)
#define DISPATCH_INDEX(Tindex) \
do { \
if (_info.dtype == INFINI_DTYPE_F16) { \
DISPATCH_KERNEL(Tindex, half, float); \
} \
if (_info.dtype == INFINI_DTYPE_BF16) { \
DISPATCH_KERNEL(Tindex, __nv_bfloat16, float); \
} \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
} while (false)
    if (_info.index_dtype == INFINI_DTYPE_I64) {
        DISPATCH_INDEX(int64_t);
    } else if (_info.index_dtype == INFINI_DTYPE_I32) {
        DISPATCH_INDEX(int32_t);
    } else if (_info.index_dtype == INFINI_DTYPE_U32) {
        DISPATCH_INDEX(uint32_t);
    }
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::paged_attention_prefill::nvidia
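Editor's note (not part of the commit): the DISPATCH_KERNEL macro above first consults the INFINIOP_FLASH_PREFILL_KERNEL environment variable, validates the value against the known kernel names, and otherwise falls back to default_prefill_kernel(_info). The following standalone C++ sketch mirrors that resolution logic; the kernel names and the environment variable come from the macro, while pick_default_kernel() is a hypothetical stand-in for default_prefill_kernel(_info).

#include <cstdio>
#include <cstdlib>
#include <cstring>

// Hypothetical stand-in for default_prefill_kernel(_info) used by the dispatcher above.
static const char *pick_default_kernel() { return "warpcta8pipe"; }

// Resolve the kernel name the way DISPATCH_KERNEL does: honor the environment
// override when it names a known kernel, otherwise warn and fall back.
static const char *resolve_prefill_kernel() {
    static const char *known[] = {"warp", "warpcta", "warpcta8", "warpcta8pipe",
                                  "warpcta8mma", "warpcta8n128", "warpcta16", "ref"};
    const char *k_env = std::getenv("INFINIOP_FLASH_PREFILL_KERNEL");
    if (k_env == nullptr) {
        return pick_default_kernel();
    }
    for (const char *name : known) {
        if (std::strcmp(k_env, name) == 0) {
            return k_env;
        }
    }
    std::fprintf(stderr, "unknown kernel '%s', falling back to '%s'\n", k_env, pick_default_kernel());
    return pick_default_kernel();
}

int main() {
    std::printf("selected prefill kernel: %s\n", resolve_prefill_kernel());
    return 0;
}

Setting INFINIOP_FLASH_PREFILL_KERNEL=ref keeps the reference kernel available for correctness bisecting, as the comment above the macro notes; INFINIOP_DEBUG_PREFILL_DISPATCH=1 additionally prints the chosen kernel once per process.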
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cuh (new file)
#ifndef __PAGED_ATTENTION_PREFILL_NVIDIA_H__
#define __PAGED_ATTENTION_PREFILL_NVIDIA_H__
#include "../paged_attention_prefill.h"
DESCRIPTOR(nvidia)
#endif // __PAGED_ATTENTION_PREFILL_NVIDIA_H__
src/infiniop/ops/paged_attention_prefill/operator.cc (new file)
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/paged_attention_prefill.h"
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
#include "nvidia/paged_attention_prefill_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/paged_attention_prefill_metax.h"
#endif
#ifdef ENABLE_MOORE_API
#include "moore/paged_attention_prefill_moore.h"
#endif
__C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
    infiniopHandle_t handle,
    infiniopPagedAttentionPrefillDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t block_tables_desc,
    infiniopTensorDescriptor_t seq_lens_desc,
    infiniopTensorDescriptor_t cum_seq_lens_q_desc,
    infiniopTensorDescriptor_t alibi_slopes_desc,
    float scale) {
    infiniopTensorDescriptor_t alibi_opt = (alibi_slopes_desc == nullptr) ? nullptr : alibi_slopes_desc;
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::paged_attention_prefill::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor **>(desc_ptr), \
out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, \
seq_lens_desc, cum_seq_lens_q_desc, alibi_opt, scale);
    switch (handle->device) {
#ifdef ENABLE_NVIDIA_API
        CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
        CREATE(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
        CREATE(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
        CREATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
        CREATE(INFINI_DEVICE_MOORE, moore)
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
}
__C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
    infiniopPagedAttentionPrefillDescriptor_t desc,
    size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS;
    switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
        GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
        GET(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
        GET(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
        GET(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
        GET(INFINI_DEVICE_MOORE, moore)
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
}
__C infiniStatus_t infiniopPagedAttentionPrefill(
    infiniopPagedAttentionPrefillDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k_cache,
    const void *v_cache,
    const void *block_tables,
    const void *seq_lens,
    const void *cum_seq_lens_q,
    const void *alibi_slopes,
    void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size, out, q, k_cache, v_cache, block_tables, \
seq_lens, cum_seq_lens_q, alibi_slopes, stream);
    switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
        CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
        CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
        CALCULATE(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
        CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
        CALCULATE(INFINI_DEVICE_MOORE, moore)
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
}
__C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
    infiniopPagedAttentionPrefillDescriptor_t desc) {
#define DESTROY(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
    switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
        DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
        DESTROY(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
        DESTROY(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
        DESTROY(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
        DESTROY(INFINI_DEVICE_MOORE, moore)
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
}
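Editor's note (not part of the commit): the four entry points above compose into the usual create, query workspace, execute, destroy sequence. The sketch below is illustrative only; the descriptors, device buffers, and stream are assumed to exist, error handling is reduced to early returns, and cudaMalloc/cudaFree stand in for whatever allocator the caller actually uses. Only the infiniop* names and argument order come from the code above.

// Illustrative host-side call sequence for the prefill operator.
infiniStatus_t prefill_once(infiniopHandle_t handle,
                            infiniopTensorDescriptor_t out_desc, infiniopTensorDescriptor_t q_desc,
                            infiniopTensorDescriptor_t k_cache_desc, infiniopTensorDescriptor_t v_cache_desc,
                            infiniopTensorDescriptor_t block_tables_desc, infiniopTensorDescriptor_t seq_lens_desc,
                            infiniopTensorDescriptor_t cum_seq_lens_q_desc,
                            void *out, const void *q, const void *k_cache, const void *v_cache,
                            const void *block_tables, const void *seq_lens, const void *cum_seq_lens_q,
                            float scale, cudaStream_t stream) {
    infiniopPagedAttentionPrefillDescriptor_t desc = nullptr;
    infiniStatus_t st = infiniopCreatePagedAttentionPrefillDescriptor(
        handle, &desc, out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc,
        seq_lens_desc, cum_seq_lens_q_desc, /*alibi_slopes_desc=*/nullptr, scale);
    if (st != INFINI_STATUS_SUCCESS) {
        return st;
    }
    size_t workspace_size = 0;
    st = infiniopGetPagedAttentionPrefillWorkspaceSize(desc, &workspace_size);
    void *workspace = nullptr;
    if (st == INFINI_STATUS_SUCCESS && workspace_size > 0) {
        cudaMalloc(&workspace, workspace_size);
    }
    if (st == INFINI_STATUS_SUCCESS) {
        st = infiniopPagedAttentionPrefill(desc, workspace, workspace_size,
                                           out, q, k_cache, v_cache, block_tables,
                                           seq_lens, cum_seq_lens_q,
                                           /*alibi_slopes=*/nullptr, stream);
    }
    if (workspace != nullptr) {
        cudaFree(workspace);
    }
    infiniopDestroyPagedAttentionPrefillDescriptor(desc);
    return st;
}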
src/infiniop/ops/paged_attention_prefill/paged_attention_prefill.h (new file)
#ifndef PAGED_ATTENTION_PREFILL_H
#define PAGED_ATTENTION_PREFILL_H
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::paged_attention_prefill::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
PagedAttentionPrefillInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
PagedAttentionPrefillInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t out_desc, \
infiniopTensorDescriptor_t q_desc, \
infiniopTensorDescriptor_t k_cache_desc, \
infiniopTensorDescriptor_t v_cache_desc, \
infiniopTensorDescriptor_t block_tables_desc, \
infiniopTensorDescriptor_t seq_lens_desc, \
infiniopTensorDescriptor_t cum_seq_lens_q_desc, \
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc, \
float scale); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *out, const void *q, const void *k_cache, const void *v_cache, \
const void *block_tables, \
const void *seq_lens, \
const void *cum_seq_lens_q, \
const void *alibi_slopes, \
void *stream) const; \
}; \
}
#endif // PAGED_ATTENTION_PREFILL_H
src/infiniop/ops/paged_caching/cuda/kernel.cuh (new file)
#ifndef __PAGED_CACHING_KERNEL_CUH__
#define __PAGED_CACHING_KERNEL_CUH__
//================================================================================
// Paged Caching Operator CUDA Kernel
//
// This kernel implements the "paged_caching" operation, which copies Key and Value
// vectors from a contiguous source tensor into a paged, non-contiguous KV Cache.
//
// Design Principles:
// 1. Token-Centric Parallelism: A 1D grid of `num_tokens` is launched. Each CUDA
// block is responsible for caching one full token (all its heads).
// 2. Coalesced Memory Access: This grid strategy ensures that threads within a
// block read a large, contiguous chunk of memory from the source tensors,
// maximizing memory bandwidth utilization.
// 3. Vectorization: The copy operation is vectorized to further enhance memory
// throughput, processing multiple data elements in a single instruction.
//================================================================================
namespace op::paged_caching::cuda {

template <
    typename Tdata,  // Data type of the tensors (e.g., half, __nv_bfloat16)
    int NUM_THREADS  // Number of threads per block, configured at launch time
    >
__device__ void pagedCachingKernel(
    // ----- Output Tensors -----
    Tdata *k_cache_ptr, // Pointer to the destination K cache pool [num_blocks, nkvh, block_size, dh]
    Tdata *v_cache_ptr, // Pointer to the destination V cache pool [num_blocks, nkvh, block_size, dh]
    // ----- Input Tensors -----
    const Tdata *k_ptr,              // Pointer to the source Keys, shape [ntok, nkvh, dh]
    const Tdata *v_ptr,              // Pointer to the source Values, shape [ntok, nkvh, dh]
    const int64_t *slot_mapping_ptr, // Pointer to the slot mapping, shape [ntok]
    // ----- Metadata -----
    const size_t head_size,  // Dimension of each head (dh)
    const size_t block_size, // Number of tokens per block in the KV cache
    // ----- Stride Information -----
    const ptrdiff_t k_src_stride,          // Stride between tokens in the source K tensor
    const ptrdiff_t v_src_stride,          // Stride between tokens in the source V tensor
    const ptrdiff_t k_cache_block_stride,  // Stride between blocks in the K cache pool
    const ptrdiff_t v_cache_block_stride   // Stride between blocks in the V cache pool
) {
    //================================================================================
    // 1. Identify Work Unit & Calculate Addresses
    //================================================================================
    // Each block processes one token.
    const int token_idx = blockIdx.y;
    const int head_idx = blockIdx.x;
    // const int num_kv_heads = gridDim.y;
    // Retrieve the destination slot for the current token.
    const int64_t slot_idx = slot_mapping_ptr[token_idx];
    // Handle padding: if slot_idx is negative, this token is padding and should be ignored.
    if (slot_idx < 0) {
        return;
    }
    // Calculate the physical block index and the offset within that block.
    const int64_t physical_block_idx = slot_idx / block_size;
    const int64_t block_offset = slot_idx % block_size;
    // Calculate base pointers for source and destination for this specific token.
    const Tdata *k_src_head_ptr = k_ptr + token_idx * k_src_stride + head_idx * head_size;
    const Tdata *v_src_head_ptr = v_ptr + token_idx * v_src_stride + head_idx * head_size;
    // Destination pointer calculation assumes a [num_blocks, num_heads, block_size, head_size] layout.
    // We point to the beginning of the memory region for this token's slot.
    const ptrdiff_t cache_head_stride = block_size * head_size;
    Tdata *k_cache_block_base_ptr = k_cache_ptr + physical_block_idx * k_cache_block_stride;
    Tdata *k_dst_head_ptr = k_cache_block_base_ptr + head_idx * cache_head_stride + block_offset * head_size;
    Tdata *v_cache_block_base_ptr = v_cache_ptr + physical_block_idx * v_cache_block_stride;
    Tdata *v_dst_head_ptr = v_cache_block_base_ptr + head_idx * cache_head_stride + block_offset * head_size;
    //================================================================================
    // 2. Perform Element-wise Data Copy (Safe, Non-Vectorized)
    //================================================================================
    for (int i = threadIdx.x; i < head_size; i += NUM_THREADS) {
        k_dst_head_ptr[i] = k_src_head_ptr[i];
        v_dst_head_ptr[i] = v_src_head_ptr[i];
    }
}

} // namespace op::paged_caching::cuda
#endif // __PAGED_CACHING_KERNEL_CUH__
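Editor's note (not part of the commit): to make the addressing in pagedCachingKernel concrete, here is a CPU-side rehearsal of the same arithmetic with made-up sizes. The numbers are purely illustrative; only the formulas (slot index to physical block plus in-block offset, and the [num_blocks, num_heads, block_size, head_size] destination indexing) mirror the kernel.

#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical cache geometry.
    const int64_t block_size = 16;  // tokens per physical block
    const int64_t num_heads = 8;    // KV heads
    const int64_t head_size = 128;  // elements per head

    // A token that the slot mapping assigned to global slot 37, written for head 3.
    const int64_t slot_idx = 37;
    const int64_t head_idx = 3;

    const int64_t physical_block_idx = slot_idx / block_size; // -> 2
    const int64_t block_offset = slot_idx % block_size;       // -> 5

    // Strides for a contiguous [num_blocks, num_heads, block_size, head_size] cache.
    const int64_t cache_head_stride = block_size * head_size;
    const int64_t cache_block_stride = num_heads * cache_head_stride;

    // Element offset of the first value this token writes into the K cache.
    const int64_t dst = physical_block_idx * cache_block_stride
                        + head_idx * cache_head_stride
                        + block_offset * head_size;

    std::printf("block=%lld offset=%lld dst_elem=%lld\n",
                (long long)physical_block_idx, (long long)block_offset, (long long)dst);
    return 0;
}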
src/infiniop/ops/paged_caching/info.h (new file)
#ifndef __PAGED_CACHING_INFO_H__
#define __PAGED_CACHING_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <optional>
#include <vector>
namespace op::paged_caching {

class PagedCachingInfo {
    PagedCachingInfo() = default;

public:
    // --- Data Type ---
    infiniDtype_t dtype;
    // --- Shape Dimensions ---
    size_t num_tokens;
    size_t num_kv_heads;
    size_t head_size;
    size_t block_size;
    // --- Strides for Memory Layout ---
    ptrdiff_t k_src_stride;
    ptrdiff_t v_src_stride;
    ptrdiff_t k_cache_block_stride;
    ptrdiff_t v_cache_block_stride;

    static utils::Result<PagedCachingInfo> create(
        infiniopTensorDescriptor_t k_cache_desc,
        infiniopTensorDescriptor_t v_cache_desc,
        infiniopTensorDescriptor_t k_desc,
        infiniopTensorDescriptor_t v_desc,
        infiniopTensorDescriptor_t slot_mapping_desc) {
        auto dtype = k_desc->dtype();
        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
        if (v_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        if (slot_mapping_desc->dtype() != INFINI_DTYPE_I64) {
            printf("slot_mapping must be int64_t.\n");
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        if (k_desc->ndim() != 3 || v_desc->ndim() != 3
            || k_cache_desc->ndim() < 4 || v_cache_desc->ndim() < 4
            || slot_mapping_desc->ndim() != 1) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        // PagedCachingInfo info;
        // --- Extract shape dimensions ---
        auto k_shape = k_desc->shape();
        auto k_cache_shape = k_cache_desc->shape();
        size_t num_tokens = slot_mapping_desc->shape()[0];
        size_t num_kv_heads = k_shape[1];
        size_t head_size = k_shape[2];
        size_t block_size = k_cache_shape[2]; // Assuming [num_blocks, num_heads, block_size, head_size] layout
        // --- Extract strides for memory access ---
        ptrdiff_t k_src_stride = k_desc->stride(0);
        ptrdiff_t v_src_stride = v_desc->stride(0);
        ptrdiff_t k_cache_block_stride = k_cache_desc->stride(0);
        ptrdiff_t v_cache_block_stride = v_cache_desc->stride(0);
        return utils::Result<PagedCachingInfo>(PagedCachingInfo{
            dtype,
            num_tokens,
            num_kv_heads,
            head_size,
            block_size,
            k_src_stride,
            v_src_stride,
            k_cache_block_stride,
            v_cache_block_stride});
    }
};

} // namespace op::paged_caching
#endif // __PAGED_CACHING_INFO_H__
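Editor's note (not part of the commit): create() records stride(0) of each tensor rather than recomputing it from shapes, so non-contiguous sources still work; for contiguous tensors the recorded values reduce to products of the trailing dimensions. A small illustration with hypothetical sizes:

#include <cstddef>
#include <cstdio>

int main() {
    // Hypothetical contiguous shapes matching the layouts create() expects:
    // source K/V: [ntok, nkvh, dh], KV cache: [num_blocks, nkvh, block_size, dh].
    const size_t nkvh = 8, dh = 128, block_size = 16;

    // stride(0) of a contiguous [ntok, nkvh, dh] tensor: elements between tokens.
    const ptrdiff_t k_src_stride = (ptrdiff_t)(nkvh * dh); // 1024

    // stride(0) of a contiguous [num_blocks, nkvh, block_size, dh] cache: elements between blocks.
    const ptrdiff_t k_cache_block_stride = (ptrdiff_t)(nkvh * block_size * dh); // 16384

    std::printf("k_src_stride=%td  k_cache_block_stride=%td\n", k_src_stride, k_cache_block_stride);
    return 0;
}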
src/infiniop/ops/paged_caching/metax/paged_caching_metax.h (new file)
#ifndef __PAGED_CACHING_METAX_H__
#define __PAGED_CACHING_METAX_H__
#include "../paged_caching.h"
DESCRIPTOR(metax)
#endif // __PAGED_CACHING_METAX_H__
src/infiniop/ops/paged_caching/metax/paged_caching_metax.maca (new file)
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "paged_caching_metax.h"
template <typename Tdata, int NUM_THREADS>
INFINIOP_METAX_KERNEL pagedCaching(
Tdata *k_cache, Tdata *v_cache,
const Tdata *k, const Tdata *v,
const int64_t *slot_mapping,
const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
}
namespace op::paged_caching::metax {
// PIMPL struct definition
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
// Destructor implementation
Descriptor::~Descriptor() {
delete _opaque;
}
// Static factory method implementation
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t slot_mapping_desc) {
auto info = PagedCachingInfo::create(k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc);
CHECK_RESULT(info);
// Create and return the Descriptor instance.
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// The launchKernel function is a templated helper to encapsulate the kernel launch.
// It sets up grid/block dimensions and calls the device-side kernel.
template <int NUM_THREADS>
infiniStatus_t launchKernel(const PagedCachingInfo &info,
void *k_cache, void *v_cache,
infiniDtype_t dtype,
const void *k, const void *v,
const void *slot_mapping,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
hcStream_t stream) {
// Grid dimension is 1D, with one block per token, as we decided.
dim3 grid(uint64_t(num_kv_heads), uint64_t(num_tokens), 1);
// Block dimension is 1D, using the number of threads specified at compile time.
dim3 block(NUM_THREADS);
// This kernel does not require dynamic shared memory.
size_t shared_mem_size = 0;
// Launch the device-side kernel.
if (dtype == INFINI_DTYPE_F16) {
pagedCaching<half, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(half *)k_cache,
(half *)v_cache,
(const half *)k,
(const half *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<cuda_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(cuda_bfloat16 *)k_cache,
(cuda_bfloat16 *)v_cache,
(const cuda_bfloat16 *)k,
(const cuda_bfloat16 *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(float *)k_cache,
(float *)v_cache,
(const float *)k,
(const float *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
// Execution method implementation
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *k_cache, void *v_cache,
const void *k, const void *v,
const void *slot_mapping,
void *stream_) const {
hcStream_t stream = (hcStream_t)stream_;
// Dispatch logic based on the device's maximum threads per block.
// This allows selecting the largest, most efficient block size the hardware supports.
int max_threads = _opaque->internal->maxThreadsPerBlock();
if (max_threads >= METAX_BLOCK_SIZE_1024) {
// Dispatch based on data type for a 1024-thread block.
launchKernel<METAX_BLOCK_SIZE_1024>(
_info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
stream);
} else if (max_threads >= METAX_BLOCK_SIZE_512) {
launchKernel<METAX_BLOCK_SIZE_512>(
_info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
stream);
} else {
// If the device supports fewer threads, return an error.
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::paged_caching::metax
src/infiniop/ops/paged_caching/moore/paged_caching_moore.h (new file)
#ifndef __PAGED_CACHING_MOORE_H__
#define __PAGED_CACHING_MOORE_H__
#include "../paged_caching.h"
DESCRIPTOR(moore)
#endif // __PAGED_CACHING_MOORE_H__
src/infiniop/ops/paged_caching/moore/paged_caching_moore.mu (new file)
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "paged_caching_moore.h"
template <typename Tdata, int NUM_THREADS>
INFINIOP_MOORE_KERNEL pagedCaching(
Tdata *k_cache, Tdata *v_cache,
const Tdata *k, const Tdata *v,
const int64_t *slot_mapping,
const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
}
namespace op::paged_caching::moore {
// PIMPL struct definition
struct Descriptor::Opaque {
std::shared_ptr<device::moore::Handle::Internal> internal;
};
// Destructor implementation
Descriptor::~Descriptor() {
delete _opaque;
}
// Static factory method implementation
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t slot_mapping_desc) {
auto info = PagedCachingInfo::create(k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc);
CHECK_RESULT(info);
// Create and return the Descriptor instance.
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// The launchKernel function is a templated helper to encapsulate the MUSA kernel launch.
// It sets up grid/block dimensions and calls the device-side kernel.
template <int NUM_THREADS>
infiniStatus_t launchKernel(const PagedCachingInfo &info,
void *k_cache, void *v_cache,
infiniDtype_t dtype,
const void *k, const void *v,
const void *slot_mapping,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
musaStream_t stream) {
// Grid dimension is 1D, with one block per token, as we decided.
dim3 grid(uint64_t(num_kv_heads), uint64_t(num_tokens), 1);
// Block dimension is 1D, using the number of threads specified at compile time.
dim3 block(NUM_THREADS);
// This kernel does not require dynamic shared memory.
size_t shared_mem_size = 0;
// Launch the device-side MUSA kernel.
if (dtype == INFINI_DTYPE_F16) {
pagedCaching<half, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(half *)k_cache,
(half *)v_cache,
(const half *)k,
(const half *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<__mt_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(__mt_bfloat16 *)k_cache,
(__mt_bfloat16 *)v_cache,
(const __mt_bfloat16 *)k,
(const __mt_bfloat16 *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(float *)k_cache,
(float *)v_cache,
(const float *)k,
(const float *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
// Execution method implementation
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *k_cache, void *v_cache,
const void *k, const void *v,
const void *slot_mapping,
void *stream_) const {
musaStream_t stream = (musaStream_t)stream_;
// Dispatch logic based on the GPU's maximum threads per block.
// This allows selecting the largest, most efficient block size the hardware supports.
if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_1024) {
// Dispatch based on data type for a 1024-thread block.
launchKernel<MOORE_BLOCK_SIZE_1024>(
_info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
stream);
} else if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_512) {
launchKernel<MOORE_BLOCK_SIZE_512>(
_info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
stream);
} else {
// If the GPU is older and supports fewer threads, return an error.
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::paged_caching::moore
src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu (new file)
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../cuda/kernel.cuh"
#include "paged_caching_nvidia.cuh"
template <typename Tdata, int NUM_THREADS>
INFINIOP_CUDA_KERNEL pagedCaching(
    Tdata *k_cache, Tdata *v_cache,
    const Tdata *k, const Tdata *v,
    const int64_t *slot_mapping,
    const size_t head_size, const size_t block_size,
    const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
    const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
    op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
        k_cache, v_cache, k, v, slot_mapping, head_size,
        block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
}
namespace op::paged_caching::nvidia {

// PIMPL struct definition
struct Descriptor::Opaque {
    std::shared_ptr<device::nvidia::Handle::Internal> internal;
};

// Destructor implementation
Descriptor::~Descriptor() {
    delete _opaque;
}

// Static factory method implementation
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc,
    infiniopTensorDescriptor_t slot_mapping_desc) {
    auto info = PagedCachingInfo::create(k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc);
    CHECK_RESULT(info);
    // Create and return the Descriptor instance.
    *desc_ptr = new Descriptor(
        new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
        info.take(), 0, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

// The launchKernel function is a templated helper to encapsulate the CUDA kernel launch.
// It sets up grid/block dimensions and calls the device-side kernel.
template <int NUM_THREADS>
infiniStatus_t launchKernel(const PagedCachingInfo &info,
                            void *k_cache, void *v_cache,
                            infiniDtype_t dtype,
                            const void *k, const void *v,
                            const void *slot_mapping,
                            size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
                            ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
                            ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
                            cudaStream_t stream) {
    // Grid dimension is 1D, with one block per token, as we decided.
    dim3 grid(uint64_t(num_kv_heads), uint64_t(num_tokens), 1);
    // Block dimension is 1D, using the number of threads specified at compile time.
    dim3 block(NUM_THREADS);
    // This kernel does not require dynamic shared memory.
    size_t shared_mem_size = 0;
    // Launch the device-side CUDA kernel.
    if (dtype == INFINI_DTYPE_F16) {
        pagedCaching<half, NUM_THREADS>
            <<<grid, block, shared_mem_size, stream>>>(
                (half *)k_cache,
                (half *)v_cache,
                (const half *)k,
                (const half *)v,
                (const int64_t *)slot_mapping,
                head_size,
                block_size,
                k_src_stride,
                v_src_stride,
                k_cache_block_stride,
                v_cache_block_stride);
    } else if (dtype == INFINI_DTYPE_BF16) {
        pagedCaching<__nv_bfloat16, NUM_THREADS>
            <<<grid, block, shared_mem_size, stream>>>(
                (__nv_bfloat16 *)k_cache,
                (__nv_bfloat16 *)v_cache,
                (const __nv_bfloat16 *)k,
                (const __nv_bfloat16 *)v,
                (const int64_t *)slot_mapping,
                head_size,
                block_size,
                k_src_stride,
                v_src_stride,
                k_cache_block_stride,
                v_cache_block_stride);
    } else if (dtype == INFINI_DTYPE_F32) {
        pagedCaching<float, NUM_THREADS>
            <<<grid, block, shared_mem_size, stream>>>(
                (float *)k_cache,
                (float *)v_cache,
                (const float *)k,
                (const float *)v,
                (const int64_t *)slot_mapping,
                head_size,
                block_size,
                k_src_stride,
                v_src_stride,
                k_cache_block_stride,
                v_cache_block_stride);
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    return INFINI_STATUS_SUCCESS;
}

// Execution method implementation
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *k_cache, void *v_cache,
    const void *k, const void *v,
    const void *slot_mapping,
    void *stream_) const {
    cudaStream_t stream = (cudaStream_t)stream_;
    // Dispatch logic based on the GPU's maximum threads per block.
    // This allows selecting the largest, most efficient block size the hardware supports.
    if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_1024) {
        // Dispatch based on data type for a 1024-thread block.
        launchKernel<CUDA_BLOCK_SIZE_1024>(
            _info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
            _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
            _info.k_src_stride, _info.v_src_stride,
            _info.k_cache_block_stride, _info.v_cache_block_stride,
            stream);
    } else if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_512) {
        launchKernel<CUDA_BLOCK_SIZE_512>(
            _info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
            _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
            _info.k_src_stride, _info.v_src_stride,
            _info.k_cache_block_stride, _info.v_cache_block_stride,
            stream);
    } else if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_4096) {
        launchKernel<CUDA_BLOCK_SIZE_4096>(
            _info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
            _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
            _info.k_src_stride, _info.v_src_stride,
            _info.k_cache_block_stride, _info.v_cache_block_stride,
            stream);
    } else {
        // If the GPU is older and supports fewer threads, return an error.
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }
    return INFINI_STATUS_SUCCESS;
}

} // namespace op::paged_caching::nvidia
src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cuh (new file)
#ifndef __PAGED_CACHING_NVIDIA_H__
#define __PAGED_CACHING_NVIDIA_H__
#include "../paged_caching.h"
DESCRIPTOR(nvidia)
#endif // __PAGED_CACHING_NVIDIA_H__
src/infiniop/ops/paged_caching/operator.cc (new file)
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/paged_caching.h"
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
#include "nvidia/paged_caching_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/paged_caching_metax.h"
#endif
#ifdef ENABLE_MOORE_API
#include "moore/paged_caching_moore.h"
#endif
__C infiniStatus_t infiniopCreatePagedCachingDescriptor(
    infiniopHandle_t handle,
    infiniopPagedCachingDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc,
    infiniopTensorDescriptor_t slot_mapping_desc) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::paged_caching::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor **>(desc_ptr), \
k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc);
    switch (handle->device) {
#ifdef ENABLE_NVIDIA_API
        CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
        CREATE(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
        CREATE(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
        CREATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
        CREATE(INFINI_DEVICE_MOORE, moore)
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
}
__C infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
    infiniopPagedCachingDescriptor_t desc,
    size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS;
    switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
        GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
        GET(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
        GET(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
        GET(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
        GET(INFINI_DEVICE_MOORE, moore)
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
}
__C infiniStatus_t infiniopPagedCaching(
    infiniopPagedCachingDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *k_cache,
    void *v_cache,
    const void *k,
    const void *v,
    const void *slot_mapping,
    void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size, k_cache, v_cache, k, v, slot_mapping, stream);
    switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
        CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
        CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
        CALCULATE(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
        CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
        CALCULATE(INFINI_DEVICE_MOORE, moore)
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
}
__C infiniStatus_t infiniopDestroyPagedCachingDescriptor(
    infiniopPagedCachingDescriptor_t desc) {
#define DESTROY(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
    switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
        DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
        DESTROY(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_ALI_API
        DESTROY(INFINI_DEVICE_ALI, nvidia)
#endif
#ifdef ENABLE_ILUVATAR_API
        DESTROY(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_MOORE_API
        DESTROY(INFINI_DEVICE_MOORE, moore)
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
}
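Editor's note (not part of the commit): as with the prefill operator, these entry points form a create, query, execute, destroy sequence. The sketch below is illustrative only; the handle, descriptors, device pointers, and stream are assumed to exist, and the zero-workspace call follows from the backends above, which construct their descriptors with a workspace size of 0.

// Illustrative host-side call sequence for the paged caching operator.
static infiniStatus_t cache_tokens(infiniopHandle_t handle,
                                   infiniopTensorDescriptor_t k_cache_desc,
                                   infiniopTensorDescriptor_t v_cache_desc,
                                   infiniopTensorDescriptor_t k_desc,
                                   infiniopTensorDescriptor_t v_desc,
                                   infiniopTensorDescriptor_t slot_mapping_desc,
                                   void *k_cache, void *v_cache,
                                   const void *k, const void *v,
                                   const void *slot_mapping, void *stream) {
    infiniopPagedCachingDescriptor_t desc = nullptr;
    infiniStatus_t st = infiniopCreatePagedCachingDescriptor(
        handle, &desc, k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc);
    if (st != INFINI_STATUS_SUCCESS) {
        return st;
    }
    size_t workspace_size = 0;
    st = infiniopGetPagedCachingWorkspaceSize(desc, &workspace_size); // 0 for these backends
    if (st == INFINI_STATUS_SUCCESS) {
        st = infiniopPagedCaching(desc, /*workspace=*/nullptr, workspace_size,
                                  k_cache, v_cache, k, v, slot_mapping, stream);
    }
    infiniopDestroyPagedCachingDescriptor(desc);
    return st;
}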
src/infiniop/ops/paged_caching/paged_caching.h (new file)
#ifndef PAGED_CACHING_H
#define PAGED_CACHING_H
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::paged_caching::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
PagedCachingInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
PagedCachingInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t k_cache_desc, \
infiniopTensorDescriptor_t v_cache_desc, \
infiniopTensorDescriptor_t k_desc, \
infiniopTensorDescriptor_t v_desc, \
infiniopTensorDescriptor_t slot_mapping_desc); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *k_cache, void *v_cache, \
const void *k, const void *v, \
const void *slot_mapping, \
void *stream) const; \
}; \
}
#endif // PAGED_CACHING_H
src/infiniop/ops/quant/per_channel_quant_int8/cuda/kernel.cuh (new file)
#ifndef __PERCHANNEL_QUANTINT8_KERNEL_CUH__
#define __PERCHANNEL_QUANTINT8_KERNEL_CUH__
#include <cub/block/block_reduce.cuh>
__device__ inline int round_half_away_from_zero(float x) {
    float ax = fabsf(x);
    float r = floorf(ax + 0.5f);
    return (x >= 0.0f) ? (int)r : -(int)r;
}

template <typename Tdata, unsigned int BLOCK_SIZE>
__device__ void blockPerChannelQuantI8Kernel(
    int8_t *x_packed, float *x_scale, float *x_zero,
    const Tdata *x, int M, int K) {
    int row = blockIdx.x;
    int tid = row * K;
    // ---- 1. reduce max ----
    float local_max = op::common_cuda::reduce_op::max<BLOCK_SIZE, Tdata>(x + tid, K);
    __shared__ float global_max_f;
    if (threadIdx.x == 0) {
        global_max_f = local_max;
    }
    __syncthreads();
    typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    // ---- 2. reduce min ----
    float thread_min = __FLT_MAX__;
    for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
        thread_min = fminf(thread_min, (float)x[tid + ind]);
    }
#if CUDART_VERSION >= 12090
    float local_min = BlockReduce(temp_storage).Reduce(thread_min, ::cuda::minimum());
#else
    float local_min = BlockReduce(temp_storage).Reduce(thread_min, cub::Min());
#endif
    __shared__ float global_min_f;
    if (threadIdx.x == 0) {
        global_min_f = local_min;
    }
    __syncthreads();
    float global_max = global_max_f;
    float global_min = global_min_f;
    float scale = (global_max - global_min) / 255.0f;
    if (scale < 1e-8f) {
        scale = 1e-8f;
    }
    float inv_scale = 1.0f / scale;
    float zero = -global_min * inv_scale - 128.0f;
    x_scale[row] = (Tdata)scale;
    x_zero[row] = (Tdata)zero;
    for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
        float v = (float)x[tid + ind];
        float qf = v * inv_scale + zero;
        int q = round_half_away_from_zero(qf);
        if (q > 127) {
            q = 127;
        }
        if (q < -128) {
            q = -128;
        }
        x_packed[tid + ind] = (int8_t)q;
    }
}

template <typename Tdata, unsigned int BLOCK_SIZE>
__device__ void blockPerChannelQuantI8SymKernel(
    int8_t *x_packed, float *x_scale,
    const Tdata *x, int M, int K) {
    int row = blockIdx.x;
    int tid = row * K;
    typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    // ---- 2. reduce min ----
    float thread_max = -__FLT_MAX__;
    for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
        thread_max = fmaxf(thread_max, fabs((float)x[tid + ind]));
    }
#if CUDART_VERSION >= 12090
    float local_max = BlockReduce(temp_storage).Reduce(thread_max, ::cuda::maximum());
#else
    float local_max = BlockReduce(temp_storage).Reduce(thread_max, cub::Max());
#endif
    __shared__ float global_max_f;
    if (threadIdx.x == 0) {
        global_max_f = local_max;
    }
    __syncthreads();
    float global_max = global_max_f;
    float scale = global_max / 127.0f;
    if (scale < 1e-8f) {
        scale = 1e-8f;
    }
    float inv_scale = 1.0f / scale;
    x_scale[row] = (Tdata)scale;
    for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
        float v = (float)x[tid + ind];
        float qf = v * inv_scale;
        int q = round_half_away_from_zero(qf);
        if (q > 127) {
            q = 127;
        }
        if (q < -127) {
            q = -127;
        }
        x_packed[tid + ind] = (int8_t)q;
    }
}

template <typename T>
struct MaxOp {
    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
        return max(a, b);
    }
};

template <typename T>
struct MinOp {
    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
        return min(a, b);
    }
};

template <template <typename> class ReductionOp, typename T, int thread_group_width>
__inline__ __device__ T WarpAllReduce(T val) {
    for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
        val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
    }
    return val;
}

template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
__device__ void warpPerChannelQuantI8Kernel(
    int8_t *x_packed, float *x_scale, float *x_zero,
    const Tdata *x, int M, int K) {
    int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
    int tid = otherIdx * K;
    if (otherIdx < M) {
        __shared__ float max_total[BLOCK_SIZE_y];
        __shared__ float min_total[BLOCK_SIZE_y];
        float max_data = -__FLT_MAX__;
        float min_data = __FLT_MAX__;
        // ---- reduce max/min ----
        for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
            float v = (float)x[tid + ind];
            max_data = fmaxf(max_data, v);
            min_data = fminf(min_data, v);
        }
        max_data = WarpAllReduce<MaxOp, float, BLOCK_SIZE_x>(max_data);
        min_data = WarpAllReduce<MinOp, float, BLOCK_SIZE_x>(min_data);
        if (threadIdx.x == 0) {
            max_total[threadIdx.y] = max_data;
            min_total[threadIdx.y] = min_data;
        }
        __syncthreads();
        float max_f = max_total[threadIdx.y];
        float min_f = min_total[threadIdx.y];
        float scale = (max_f - min_f) / 255.0f;
        if (scale < 1e-8f) {
            scale = 1e-8f;
        }
        float inv_scale = 1.0f / scale;
        float zero = -min_f * inv_scale - 128.0f;
        x_scale[otherIdx] = scale;
        x_zero[otherIdx] = zero;
        for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
            float v = (float)x[tid + ind];
            float qf = v * inv_scale + zero;
            int q = round_half_away_from_zero(qf);
            if (q > 127) {
                q = 127;
            }
            if (q < -128) {
                q = -128;
            }
            x_packed[tid + ind] = (int8_t)q;
        }
    }
}

template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
__device__ void warpPerChannelQuantI8SymKernel(
    int8_t *x_packed, float *x_scale,
    const Tdata *x, int M, int K) {
    int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
    int tid = otherIdx * K;
    if (otherIdx < M) {
        __shared__ float max_total[BLOCK_SIZE_y];
        float max_data = -__FLT_MAX__;
        // ---- reduce max/min ----
        for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
            float v = fabs((float)x[tid + ind]);
            max_data = fmaxf(max_data, v);
        }
        max_data = WarpAllReduce<MaxOp, float, BLOCK_SIZE_x>(max_data);
        if (threadIdx.x == 0) {
            max_total[threadIdx.y] = max_data;
        }
        __syncthreads();
        float max_f = max_total[threadIdx.y];
        float scale = max_f / 127.0f;
        if (scale < 1e-8f) {
            scale = 1e-8f;
        }
        float inv_scale = 1.0f / scale;
        x_scale[otherIdx] = scale;
        for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
            float v = (float)x[tid + ind];
            float qf = v * inv_scale;
            int q = round_half_away_from_zero(qf);
            if (q > 127) {
                q = 127;
            }
            if (q < -127) {
                q = -127;
            }
            x_packed[tid + ind] = (int8_t)q;
        }
    }
}
#endif // __PERCHANNEL_QUANTINT8_KERNEL_CUH__
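Editor's note (not part of the commit): the asymmetric kernels above derive, per row, scale = (max - min) / 255 and zero = -min / scale - 128, so after round-half-away-from-zero and clamping the row minimum lands on -128 and the row maximum on +127; the symmetric variants use scale = max|x| / 127 with no zero point. A CPU-side rehearsal of the asymmetric math on one hypothetical row:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Round-half-away-from-zero, mirroring the device helper above.
static int round_half_away_from_zero(float x) {
    float r = std::floor(std::fabs(x) + 0.5f);
    return (x >= 0.0f) ? (int)r : -(int)r;
}

int main() {
    // One hypothetical row of activations.
    std::vector<float> row = {-2.0f, -0.5f, 0.0f, 1.25f, 3.0f};

    float mn = row[0], mx = row[0];
    for (float v : row) { mn = std::min(mn, v); mx = std::max(mx, v); }

    float scale = (mx - mn) / 255.0f;
    if (scale < 1e-8f) scale = 1e-8f;
    float inv_scale = 1.0f / scale;
    float zero = -mn * inv_scale - 128.0f; // shifts mn to -128 and mx to +127

    for (float v : row) {
        int q = round_half_away_from_zero(v * inv_scale + zero);
        q = std::max(-128, std::min(127, q));
        std::printf("%+.2f -> %d\n", v, q);
    }
    std::printf("scale=%.6f zero=%.2f\n", scale, zero);
    return 0;
}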
src/infiniop/ops/quant/per_channel_quant_int8/info.h (new file)
#ifndef __PER_CHANNEL_QUANT_INT8_INFO_H__
#define __PER_CHANNEL_QUANT_INT8_INFO_H__
#include "../../../../utils.h"
#include "../../../operator.h"
#include "../../../tensor.h"
namespace op::per_channel_quant_int8 {

class PerChannelQuantI8Info {
private:
    PerChannelQuantI8Info() = default;

public:
    infiniDtype_t dtype, packed_type;
    size_t M, K;

    static utils::Result<PerChannelQuantI8Info> createPerChannelQuantI8Info(
        infiniopTensorDescriptor_t x_packed_desc,
        infiniopTensorDescriptor_t x_scale_desc,
        infiniopTensorDescriptor_t x_zero_desc,
        infiniopTensorDescriptor_t x_desc) {
        CHECK_OR_RETURN(x_packed_desc != nullptr && x_scale_desc != nullptr && x_desc != nullptr,
                        INFINI_STATUS_NULL_POINTER);
        const infiniDtype_t dtype = x_desc->dtype();
        const infiniDtype_t packed_type = x_packed_desc->dtype();
        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
        CHECK_DTYPE(packed_type, INFINI_DTYPE_I8);
        CHECK_OR_RETURN(x_desc->ndim() == 2 && x_packed_desc->ndim() == 2 && x_scale_desc->ndim() == 2,
                        INFINI_STATUS_BAD_TENSOR_SHAPE);
        size_t M = x_desc->dim(0);
        size_t K = x_desc->dim(1);
        CHECK_OR_RETURN(M == x_packed_desc->dim(0) || K == x_packed_desc->dim(1)
                            || M == x_scale_desc->dim(0) || 1 == x_scale_desc->dim(1),
                        INFINI_STATUS_BAD_TENSOR_SHAPE);
        return utils::Result<PerChannelQuantI8Info>(PerChannelQuantI8Info{
            dtype,
            packed_type,
            M,
            K,
        });
    }
};

} // namespace op::per_channel_quant_int8
#endif // __PER_CHANNEL_QUANT_INT8_INFO_H__