Commit 1c18c046 authored by PanZezhong, committed by wooway777

issue/979 optimize paged attention

parent 97eced0e
#ifndef __PAGED_ATTENTION_KERNEL_V2_CUH__
#define __PAGED_ATTENTION_KERNEL_V2_CUH__
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cmath>
#include <cstddef>
#include <type_traits>
namespace op::paged_attention::cuda {
struct OnlineSoftmaxState {
float m = -INFINITY;
float l = 0.0f;
__device__ __forceinline__ void update(float x, float &alpha, float &beta) {
const float m_new = fmaxf(m, x);
alpha = expf(m - m_new);
beta = expf(x - m_new);
l = l * alpha + beta;
m = m_new;
}
};
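// The decode kernels below inline this same recurrence in base 2: with
// scale_log2 = scale * log2(e), exp(scale * qk) == exp2(scale_log2 * qk), so for a
// new score x they compute m' = max(m, x), alpha = 2^(m - m'), beta = 2^(x - m'),
// l' = l * alpha + beta, acc' = acc * alpha + beta * v, keeping l and acc finite
// regardless of score magnitude.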
__device__ __forceinline__ float warpReduceSum(float x) {
for (int offset = 16; offset > 0; offset >>= 1) {
x += __shfl_down_sync(0xffffffff, x, offset);
}
return x;
}
__device__ __forceinline__ float warpReduceMax(float x) {
for (int offset = 16; offset > 0; offset >>= 1) {
x = fmaxf(x, __shfl_down_sync(0xffffffff, x, offset));
}
return x;
}
__device__ __forceinline__ unsigned int cvtaToShared(const void *ptr) {
return static_cast<unsigned int>(__cvta_generic_to_shared(ptr));
}
__device__ __forceinline__ void cpAsyncCaSharedGlobal16(void *dst_shared, const void *src_global) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
const unsigned int dst = cvtaToShared(dst_shared);
asm volatile("cp.async.ca.shared.global [%0], [%1], 16;\n" ::"r"(dst), "l"(src_global));
#else
auto *dst = reinterpret_cast<uint4 *>(dst_shared);
const auto *src = reinterpret_cast<const uint4 *>(src_global);
*dst = *src;
#endif
}
__device__ __forceinline__ void cpAsyncCommit() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
asm volatile("cp.async.commit_group;\n" ::);
#endif
}
template <int N>
__device__ __forceinline__ void cpAsyncWaitGroup() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
#endif
}
// cp.async.wait_group requires a compile-time immediate, so for small fixed
// stage counts we provide a tiny runtime switch.
__device__ __forceinline__ void cpAsyncWaitGroupRt(int n) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
if (n <= 0) {
cpAsyncWaitGroup<0>();
} else if (n == 1) {
cpAsyncWaitGroup<1>();
} else {
// Clamp to 2 because v0.4 CTA kernel uses STAGES=3.
cpAsyncWaitGroup<2>();
}
#else
(void)n;
#endif
}
__device__ __forceinline__ void cpAsyncWaitAll() {
cpAsyncWaitGroup<0>();
}
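// Typical multi-stage usage (sketch): issue one cp.async group per tile with
// cpAsyncCommit(), keep at most STAGES - 1 groups in flight via
// cpAsyncWaitGroupRt(pending - 1), synchronize the CTA, consume the oldest tile,
// then issue the prefetch that reuses its buffer. The CTA tile kernels below
// follow this pattern with STAGES = 3.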
template <typename Tindex, typename Tdata, int HEAD_SIZE>
__device__ void flashAttentionDecodeWarpKernel(
Tdata *out_,
const Tdata *q_,
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const Tindex *cache_lens_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride) {
const int seq_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int lane = threadIdx.x;
constexpr int kWarpSize = 32;
static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4.");
static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32.");
constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize;
const int seq_len = static_cast<int>(cache_lens_[seq_idx]);
if (seq_len <= 0) {
return;
}
const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / static_cast<int>(num_kv_heads);
const int kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
constexpr float kLog2e = 1.4426950408889634f;
const float scale_log2 = scale * kLog2e;
const Tindex *block_table = block_tables_ + seq_idx * static_cast<int>(max_num_blocks_per_seq);
// q/out are [num_seqs, num_heads, head_size]
const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE;
Tdata *out_ptr = out_ + seq_idx * o_stride + head_idx * HEAD_SIZE;
float q_reg[DIMS_PER_THREAD];
float acc[DIMS_PER_THREAD];
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
q_reg[i] = static_cast<float>(q_ptr[dim]);
acc[i] = 0.0f;
}
#if defined(__CUDA_ARCH__)
float2 q_reg2[DIMS_PER_THREAD / 2];
if constexpr (std::is_same_v<Tdata, half>) {
const int dim_base = lane * DIMS_PER_THREAD;
const half2 *q2 = reinterpret_cast<const half2 *>(q_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
q_reg2[j] = __half22float2(q2[j]);
}
}
if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
const int dim_base = lane * DIMS_PER_THREAD;
const __nv_bfloat162 *q2 = reinterpret_cast<const __nv_bfloat162 *>(q_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
q_reg2[j] = __bfloat1622float2(q2[j]);
}
}
#endif
float m = -INFINITY;
float l = 0.0f;
const int pbs = static_cast<int>(page_block_size);
// Iterate by blocks to avoid per-token division/mod and redundant block_table loads.
// Note: Per-token cp.async prefetching is generally too fine-grained for decode and can regress.
// We keep the warp kernel simple and reserve cp.async pipelining for CTA tile kernels.
int t_base = 0;
for (int logical_block = 0; t_base < seq_len; ++logical_block, t_base += pbs) {
int physical_block = 0;
if (lane == 0) {
physical_block = static_cast<int>(block_table[logical_block]);
}
physical_block = __shfl_sync(0xffffffff, physical_block, 0);
const Tdata *k_base = k_cache_ + physical_block * k_batch_stride + kv_head_idx * k_head_stride;
const Tdata *v_base = v_cache_ + physical_block * v_batch_stride + kv_head_idx * v_head_stride;
const int token_end = min(pbs, seq_len - t_base);
for (int token_in_block = 0; token_in_block < token_end; ++token_in_block) {
const int t = t_base + token_in_block;
const Tdata *k_ptr = k_base + token_in_block * k_row_stride;
const Tdata *v_ptr = v_base + token_in_block * v_row_stride;
float qk = 0.0f;
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
const int dim_base = lane * DIMS_PER_THREAD;
const half2 *k2 = reinterpret_cast<const half2 *>(k_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 qf = q_reg2[j];
const float2 kf = __half22float2(k2[j]);
qk += qf.x * kf.x + qf.y * kf.y;
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
const int dim_base = lane * DIMS_PER_THREAD;
const __nv_bfloat162 *k2 = reinterpret_cast<const __nv_bfloat162 *>(k_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 qf = q_reg2[j];
const float2 kf = __bfloat1622float2(k2[j]);
qk += qf.x * kf.x + qf.y * kf.y;
}
} else
#endif
{
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
qk += q_reg[i] * static_cast<float>(k_ptr[dim]);
}
}
qk = warpReduceSum(qk);
float alpha = 1.0f;
float beta = 0.0f;
if (lane == 0) {
float score = qk * scale_log2;
if (alibi_slope != 0.0f) {
score += (alibi_slope * static_cast<float>(t - (seq_len - 1))) * kLog2e;
}
const float m_new = fmaxf(m, score);
alpha = exp2f(m - m_new);
beta = exp2f(score - m_new);
l = l * alpha + beta;
m = m_new;
}
alpha = __shfl_sync(0xffffffff, alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0);
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
const int dim_base = lane * DIMS_PER_THREAD;
const half2 *v2 = reinterpret_cast<const half2 *>(v_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 vf = __half22float2(v2[j]);
acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x;
acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y;
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
const int dim_base = lane * DIMS_PER_THREAD;
const __nv_bfloat162 *v2 = reinterpret_cast<const __nv_bfloat162 *>(v_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 vf = __bfloat1622float2(v2[j]);
acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x;
acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y;
}
} else
#endif
{
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
const float v_val = static_cast<float>(v_ptr[dim]);
acc[i] = acc[i] * alpha + beta * v_val;
}
}
}
}
float inv_l = 0.0f;
if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f);
}
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
const float o = acc[i] * inv_l;
if constexpr (std::is_same_v<Tdata, half>) {
out_ptr[dim] = __float2half_rn(o);
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
out_ptr[dim] = __float2bfloat16_rn(o);
} else {
out_ptr[dim] = static_cast<Tdata>(o);
}
}
}
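// Launch sketch for the warp kernel above (assuming a thin __global__ wrapper):
// one warp per (head, seq) pair, i.e. grid = dim3(num_heads, num_seqs),
// block = dim3(32); blockIdx.x is the query head, blockIdx.y the sequence,
// threadIdx.x the lane.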
// Split-KV decode (FA2-style): each split scans a shard of KV and writes partial (m, l, acc)
// to workspace, then a combine kernel merges splits into final out.
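// Combine math: each split keeps acc_s = sum_t 2^(score_t - m_s) * v_t and
// l_s = sum_t 2^(score_t - m_s). With m = max_s m_s, the exact result is
// out = (sum_s 2^(m_s - m) * acc_s) / (sum_s 2^(m_s - m) * l_s),
// which is what the combine kernel below evaluates.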
template <typename Tindex, typename Tdata, int HEAD_SIZE>
__device__ void flashAttentionDecodeSplitKvWarpKernel(
float *partial_acc, // [num_splits, num_seqs, num_heads, head_size]
float *partial_m, // [num_splits, num_seqs, num_heads]
float *partial_l, // [num_splits, num_seqs, num_heads]
const Tdata *q_,
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const Tindex *cache_lens_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
int num_splits) {
const int seq_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int split_idx = static_cast<int>(blockIdx.z);
const int lane = threadIdx.x;
constexpr int kWarpSize = 32;
static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4.");
static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32.");
constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize;
const int seq_len = static_cast<int>(cache_lens_[seq_idx]);
if (seq_len <= 0 || num_splits <= 0) {
return;
}
// Split the [0, seq_len) range into num_splits contiguous shards.
const int shard = (seq_len + num_splits - 1) / num_splits;
const int start = split_idx * shard;
const int end = min(seq_len, start + shard);
if (start >= end) {
// Empty shard => write neutral element.
const int n = gridDim.y * gridDim.x;
const int idx = (split_idx * n + seq_idx * gridDim.x + head_idx);
if (lane == 0) {
partial_m[idx] = -INFINITY;
partial_l[idx] = 0.0f;
}
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
partial_acc[idx * HEAD_SIZE + dim] = 0.0f;
}
return;
}
const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / static_cast<int>(num_kv_heads);
const int kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
constexpr float kLog2e = 1.4426950408889634f;
const float scale_log2 = scale * kLog2e;
const Tindex *block_table = block_tables_ + seq_idx * static_cast<int>(max_num_blocks_per_seq);
const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE;
float q_reg[DIMS_PER_THREAD];
float acc[DIMS_PER_THREAD];
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
q_reg[i] = static_cast<float>(q_ptr[dim]);
acc[i] = 0.0f;
}
#if defined(__CUDA_ARCH__)
float2 q_reg2[DIMS_PER_THREAD / 2];
if constexpr (std::is_same_v<Tdata, half>) {
const int dim_base = lane * DIMS_PER_THREAD;
const half2 *q2 = reinterpret_cast<const half2 *>(q_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
q_reg2[j] = __half22float2(q2[j]);
}
}
if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
const int dim_base = lane * DIMS_PER_THREAD;
const __nv_bfloat162 *q2 = reinterpret_cast<const __nv_bfloat162 *>(q_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
q_reg2[j] = __bfloat1622float2(q2[j]);
}
}
#endif
float m = -INFINITY;
float l = 0.0f;
const int pbs = static_cast<int>(page_block_size);
// Scan only [start, end).
int t = start;
int logical_block = t / pbs;
int token_in_block = t - logical_block * pbs;
for (; t < end; ++logical_block) {
int physical_block = 0;
if (lane == 0) {
physical_block = static_cast<int>(block_table[logical_block]);
}
physical_block = __shfl_sync(0xffffffff, physical_block, 0);
const Tdata *k_base = k_cache_ + physical_block * k_batch_stride + kv_head_idx * k_head_stride;
const Tdata *v_base = v_cache_ + physical_block * v_batch_stride + kv_head_idx * v_head_stride;
const int token_end = min(pbs, end - logical_block * pbs);
for (; token_in_block < token_end && t < end; ++token_in_block, ++t) {
const Tdata *k_ptr = k_base + token_in_block * k_row_stride;
const Tdata *v_ptr = v_base + token_in_block * v_row_stride;
float qk = 0.0f;
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
const int dim_base = lane * DIMS_PER_THREAD;
const half2 *k2 = reinterpret_cast<const half2 *>(k_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 qf = q_reg2[j];
const float2 kf = __half22float2(k2[j]);
qk += qf.x * kf.x + qf.y * kf.y;
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
const int dim_base = lane * DIMS_PER_THREAD;
const __nv_bfloat162 *k2 = reinterpret_cast<const __nv_bfloat162 *>(k_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 qf = q_reg2[j];
const float2 kf = __bfloat1622float2(k2[j]);
qk += qf.x * kf.x + qf.y * kf.y;
}
} else
#endif
{
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
qk += q_reg[i] * static_cast<float>(k_ptr[dim]);
}
}
qk = warpReduceSum(qk);
float alpha = 1.0f;
float beta = 0.0f;
if (lane == 0) {
float score = qk * scale_log2;
if (alibi_slope != 0.0f) {
score += (alibi_slope * static_cast<float>(t - (seq_len - 1))) * kLog2e;
}
const float m_new = fmaxf(m, score);
alpha = exp2f(m - m_new);
beta = exp2f(score - m_new);
l = l * alpha + beta;
m = m_new;
}
alpha = __shfl_sync(0xffffffff, alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0);
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
const int dim_base = lane * DIMS_PER_THREAD;
const half2 *v2 = reinterpret_cast<const half2 *>(v_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 vf = __half22float2(v2[j]);
acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x;
acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y;
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
const int dim_base = lane * DIMS_PER_THREAD;
const __nv_bfloat162 *v2 = reinterpret_cast<const __nv_bfloat162 *>(v_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 vf = __bfloat1622float2(v2[j]);
acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x;
acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y;
}
} else
#endif
{
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
const float v_val = static_cast<float>(v_ptr[dim]);
acc[i] = acc[i] * alpha + beta * v_val;
}
}
}
token_in_block = 0;
}
const int n = gridDim.y * gridDim.x;
const int idx = (split_idx * n + seq_idx * gridDim.x + head_idx);
if (lane == 0) {
partial_m[idx] = m;
partial_l[idx] = l;
}
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
partial_acc[idx * HEAD_SIZE + dim] = acc[i];
}
}
template <typename Tdata, int HEAD_SIZE>
__device__ void flashAttentionDecodeSplitKvCombineWarpKernel(
Tdata *out_,
const float *partial_acc, // [num_splits, num_seqs, num_heads, head_size]
const float *partial_m, // [num_splits, num_seqs, num_heads]
const float *partial_l, // [num_splits, num_seqs, num_heads]
int num_splits,
ptrdiff_t o_stride) {
const int seq_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int lane = threadIdx.x;
constexpr int kWarpSize = 32;
static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32.");
constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize;
const int n = gridDim.y * gridDim.x;
const int base = (seq_idx * gridDim.x + head_idx);
float m = -INFINITY;
if (lane == 0) {
for (int s = 0; s < num_splits; ++s) {
m = fmaxf(m, partial_m[s * n + base]);
}
}
m = __shfl_sync(0xffffffff, m, 0);
float l = 0.0f;
if (lane == 0) {
for (int s = 0; s < num_splits; ++s) {
const float ms = partial_m[s * n + base];
const float ls = partial_l[s * n + base];
if (ls > 0.0f) {
l += ls * exp2f(ms - m);
}
}
}
l = __shfl_sync(0xffffffff, l, 0);
const float inv_l = 1.0f / (l + 1e-6f);
// Combine acc for each dim.
Tdata *out_ptr = out_ + seq_idx * o_stride + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
float acc = 0.0f;
for (int s = 0; s < num_splits; ++s) {
const float ms = partial_m[s * n + base];
const float w = exp2f(ms - m);
acc += partial_acc[(s * n + base) * HEAD_SIZE + dim] * w;
}
const float o = acc * inv_l;
if constexpr (std::is_same_v<Tdata, half>) {
out_ptr[dim] = __float2half_rn(o);
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
out_ptr[dim] = __float2bfloat16_rn(o);
} else {
out_ptr[dim] = static_cast<Tdata>(o);
}
}
}
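// Split-KV launch sketch (assuming thin __global__ wrappers): pass 1 runs the
// split kernel with grid = dim3(num_heads, num_seqs, num_splits), pass 2 runs
// the combine kernel with grid = dim3(num_heads, num_seqs), both with one warp
// per block. The float workspace holds num_splits * num_seqs * num_heads *
// (HEAD_SIZE + 2) values: HEAD_SIZE for partial_acc plus one each for
// partial_m and partial_l.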
// Split-KV decode with a CTA tile kernel (FA2-style): each CTA scans a shard of KV,
// writes partial (m, l, acc) to workspace, then a combine kernel merges splits.
template <typename Tindex, typename Tdata, int HEAD_SIZE, int CTA_THREADS, int TOKENS_PER_TILE>
__device__ void flashAttentionDecodeSplitKvCtaKernel(
float *partial_acc, // [num_splits, num_seqs, num_heads, head_size]
float *partial_m, // [num_splits, num_seqs, num_heads]
float *partial_l, // [num_splits, num_seqs, num_heads]
const Tdata *q_,
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const Tindex *cache_lens_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
int num_splits) {
constexpr int kWarpSize = 32;
static_assert(CTA_THREADS % kWarpSize == 0, "CTA_THREADS must be a multiple of 32.");
static_assert(TOKENS_PER_TILE > 0 && TOKENS_PER_TILE <= 16, "TOKENS_PER_TILE should stay small.");
constexpr int NUM_WARPS = CTA_THREADS / kWarpSize;
static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4.");
static_assert(HEAD_SIZE % CTA_THREADS == 0, "HEAD_SIZE must be divisible by CTA_THREADS.");
constexpr int kPack = HEAD_SIZE / CTA_THREADS; // 2 (64@32t, 128@64t) or 4 (128@32t)
static_assert(kPack == 2 || kPack == 4, "v0.4 split-kv CTA kernel supports kPack=2/4 only.");
constexpr int kPackedDims = CTA_THREADS;
constexpr int kComputeWarps = (kPackedDims + kWarpSize - 1) / kWarpSize;
const int seq_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int split_idx = static_cast<int>(blockIdx.z);
const int tid = threadIdx.x;
const int lane = tid % kWarpSize;
const int warp_id = tid / kWarpSize;
const int seq_len = static_cast<int>(cache_lens_[seq_idx]);
if (seq_len <= 0 || num_splits <= 0) {
return;
}
// Split the [0, seq_len) range into num_splits contiguous shards.
const int shard = (seq_len + num_splits - 1) / num_splits;
const int start = split_idx * shard;
const int end = min(seq_len, start + shard);
const int n = gridDim.y * gridDim.x;
const int idx = (split_idx * n + seq_idx * gridDim.x + head_idx);
if (start >= end) {
// Empty shard => write neutral element.
if (tid == 0) {
partial_m[idx] = -INFINITY;
partial_l[idx] = 0.0f;
}
const int dim = tid * kPack;
if constexpr (kPack == 2) {
partial_acc[idx * HEAD_SIZE + dim + 0] = 0.0f;
partial_acc[idx * HEAD_SIZE + dim + 1] = 0.0f;
} else {
partial_acc[idx * HEAD_SIZE + dim + 0] = 0.0f;
partial_acc[idx * HEAD_SIZE + dim + 1] = 0.0f;
partial_acc[idx * HEAD_SIZE + dim + 2] = 0.0f;
partial_acc[idx * HEAD_SIZE + dim + 3] = 0.0f;
}
return;
}
const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / static_cast<int>(num_kv_heads);
const int kv_head_idx = head_idx / num_queries_per_kv;
const Tindex *block_table = block_tables_ + seq_idx * static_cast<int>(max_num_blocks_per_seq);
const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE;
const int dim = tid * kPack;
float q0 = 0.0f, q1 = 0.0f, q2 = 0.0f, q3 = 0.0f;
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
if constexpr (kPack == 2) {
const half2 qh2 = *reinterpret_cast<const half2 *>(q_ptr + dim);
const float2 qf = __half22float2(qh2);
q0 = qf.x;
q1 = qf.y;
} else {
const half2 qh2_0 = *reinterpret_cast<const half2 *>(q_ptr + dim + 0);
const half2 qh2_1 = *reinterpret_cast<const half2 *>(q_ptr + dim + 2);
const float2 qf0 = __half22float2(qh2_0);
const float2 qf1 = __half22float2(qh2_1);
q0 = qf0.x;
q1 = qf0.y;
q2 = qf1.x;
q3 = qf1.y;
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
if constexpr (kPack == 2) {
const __nv_bfloat162 qb2 = *reinterpret_cast<const __nv_bfloat162 *>(q_ptr + dim);
const float2 qf = __bfloat1622float2(qb2);
q0 = qf.x;
q1 = qf.y;
} else {
const __nv_bfloat162 qb2_0 = *reinterpret_cast<const __nv_bfloat162 *>(q_ptr + dim + 0);
const __nv_bfloat162 qb2_1 = *reinterpret_cast<const __nv_bfloat162 *>(q_ptr + dim + 2);
const float2 qf0 = __bfloat1622float2(qb2_0);
const float2 qf1 = __bfloat1622float2(qb2_1);
q0 = qf0.x;
q1 = qf0.y;
q2 = qf1.x;
q3 = qf1.y;
}
} else
#endif
{
q0 = static_cast<float>(q_ptr[dim + 0]);
q1 = static_cast<float>(q_ptr[dim + 1]);
if constexpr (kPack == 4) {
q2 = static_cast<float>(q_ptr[dim + 2]);
q3 = static_cast<float>(q_ptr[dim + 3]);
}
}
float acc0 = 0.0f, acc1 = 0.0f, acc2 = 0.0f, acc3 = 0.0f;
float m = -INFINITY;
float l = 0.0f;
__shared__ float warp_sums[TOKENS_PER_TILE][kComputeWarps];
__shared__ float alpha_shared;
__shared__ float weights_shared[TOKENS_PER_TILE];
const int pbs = static_cast<int>(page_block_size);
const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
constexpr float kLog2e = 1.4426950408889634f;
const float scale_log2 = scale * kLog2e;
static_assert(sizeof(Tdata) == 2, "CTA split-kv kernel assumes fp16/bf16.");
constexpr int CHUNK_ELEMS = 8; // 8 * 2 bytes = 16 bytes.
constexpr int CHUNKS = HEAD_SIZE / CHUNK_ELEMS;
constexpr int LOADS_PER_TILE = CHUNKS * TOKENS_PER_TILE;
constexpr int STAGES = 3;
__shared__ __align__(16) Tdata sh_k[STAGES][TOKENS_PER_TILE][HEAD_SIZE];
__shared__ __align__(16) Tdata sh_v[STAGES][TOKENS_PER_TILE][HEAD_SIZE];
const int first_block = start / pbs;
const int last_block = (end - 1) / pbs;
for (int logical_block = first_block; logical_block <= last_block; ++logical_block) {
const int physical_block = static_cast<int>(block_table[logical_block]);
const Tdata *k_base = k_cache_ + physical_block * k_batch_stride + kv_head_idx * k_head_stride;
const Tdata *v_base = v_cache_ + physical_block * v_batch_stride + kv_head_idx * v_head_stride;
const int t_base = logical_block * pbs;
const int token_begin = (logical_block == first_block) ? (start - t_base) : 0;
const int token_end = (logical_block == last_block) ? (end - t_base) : pbs;
const int token_count = token_end - token_begin;
if (token_count <= 0) {
continue;
}
const int num_tiles = (token_count + TOKENS_PER_TILE - 1) / TOKENS_PER_TILE;
int pending_groups = 0;
const int preload = min(STAGES, num_tiles);
for (int ti = 0; ti < preload; ++ti) {
const int token_in_block = token_begin + ti * TOKENS_PER_TILE;
const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block);
for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) {
const int tok = li / CHUNKS;
const int chunk = li - tok * CHUNKS;
const int off = chunk * CHUNK_ELEMS;
if (tok < tile_n) {
const Tdata *k_src = k_base + (token_in_block + tok) * k_row_stride + off;
const Tdata *v_src = v_base + (token_in_block + tok) * v_row_stride + off;
cpAsyncCaSharedGlobal16(&sh_k[ti][tok][off], k_src);
cpAsyncCaSharedGlobal16(&sh_v[ti][tok][off], v_src);
} else {
reinterpret_cast<uint4 *>(&sh_k[ti][tok][off])[0] = make_uint4(0, 0, 0, 0);
reinterpret_cast<uint4 *>(&sh_v[ti][tok][off])[0] = make_uint4(0, 0, 0, 0);
}
}
cpAsyncCommit();
++pending_groups;
}
int desired_pending = pending_groups - 1;
if (desired_pending < 0) {
desired_pending = 0;
}
if (desired_pending > (STAGES - 1)) {
desired_pending = (STAGES - 1);
}
cpAsyncWaitGroupRt(desired_pending);
pending_groups = desired_pending;
if constexpr (NUM_WARPS == 1) {
__syncwarp();
} else {
__syncthreads();
}
for (int tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
const int buf = tile_idx % STAGES;
const int token_in_block = token_begin + tile_idx * TOKENS_PER_TILE;
const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block);
float partial[TOKENS_PER_TILE];
#pragma unroll
for (int j = 0; j < TOKENS_PER_TILE; ++j) {
if (j < tile_n) {
float k0 = 0.0f, k1 = 0.0f, k2 = 0.0f, k3 = 0.0f;
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
if constexpr (kPack == 2) {
const half2 kh2 = *reinterpret_cast<const half2 *>(&sh_k[buf][j][dim]);
const float2 kf = __half22float2(kh2);
k0 = kf.x;
k1 = kf.y;
} else {
const half2 kh2_0 = *reinterpret_cast<const half2 *>(&sh_k[buf][j][dim + 0]);
const half2 kh2_1 = *reinterpret_cast<const half2 *>(&sh_k[buf][j][dim + 2]);
const float2 kf0 = __half22float2(kh2_0);
const float2 kf1 = __half22float2(kh2_1);
k0 = kf0.x;
k1 = kf0.y;
k2 = kf1.x;
k3 = kf1.y;
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
if constexpr (kPack == 2) {
const __nv_bfloat162 kb2 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_k[buf][j][dim]);
const float2 kf = __bfloat1622float2(kb2);
k0 = kf.x;
k1 = kf.y;
} else {
const __nv_bfloat162 kb2_0 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_k[buf][j][dim + 0]);
const __nv_bfloat162 kb2_1 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_k[buf][j][dim + 2]);
const float2 kf0 = __bfloat1622float2(kb2_0);
const float2 kf1 = __bfloat1622float2(kb2_1);
k0 = kf0.x;
k1 = kf0.y;
k2 = kf1.x;
k3 = kf1.y;
}
} else
#endif
{
k0 = static_cast<float>(sh_k[buf][j][dim + 0]);
k1 = static_cast<float>(sh_k[buf][j][dim + 1]);
if constexpr (kPack == 4) {
k2 = static_cast<float>(sh_k[buf][j][dim + 2]);
k3 = static_cast<float>(sh_k[buf][j][dim + 3]);
}
}
if constexpr (kPack == 2) {
partial[j] = fmaf(q0, k0, q1 * k1);
} else {
partial[j] = fmaf(q0, k0, fmaf(q1, k1, fmaf(q2, k2, q3 * k3)));
}
} else {
partial[j] = 0.0f;
}
}
#pragma unroll
for (int j = 0; j < TOKENS_PER_TILE; ++j) {
const float sum = warpReduceSum(partial[j]);
if (lane == 0 && warp_id < kComputeWarps) {
warp_sums[j][warp_id] = sum;
}
}
if constexpr (NUM_WARPS == 1) {
__syncwarp();
} else {
__syncthreads();
}
if (warp_id == 0) {
float score = -INFINITY;
if (lane < TOKENS_PER_TILE && lane < tile_n) {
float qk = 0.0f;
#pragma unroll
for (int w = 0; w < kComputeWarps; ++w) {
qk += warp_sums[lane][w];
}
const int t = t_base + token_in_block + lane;
score = qk * scale_log2;
if (alibi_slope != 0.0f) {
score += (alibi_slope * static_cast<float>(t - (seq_len - 1))) * kLog2e;
}
}
float tile_max = warpReduceMax(score);
tile_max = __shfl_sync(0xffffffff, tile_max, 0);
float m_new = 0.0f;
if (lane == 0) {
m_new = fmaxf(m, tile_max);
}
m_new = __shfl_sync(0xffffffff, m_new, 0);
float w = 0.0f;
if (lane < TOKENS_PER_TILE && lane < tile_n) {
w = exp2f(score - m_new);
}
if (lane < TOKENS_PER_TILE) {
weights_shared[lane] = (lane < tile_n) ? w : 0.0f;
}
const float tile_sum = warpReduceSum(w);
if (lane == 0) {
const float alpha = exp2f(m - m_new);
alpha_shared = alpha;
l = l * alpha + tile_sum;
m = m_new;
}
}
if constexpr (NUM_WARPS == 1) {
__syncwarp();
} else {
__syncthreads();
}
const float alpha = alpha_shared;
float sum_wv0 = 0.0f, sum_wv1 = 0.0f, sum_wv2 = 0.0f, sum_wv3 = 0.0f;
#pragma unroll
for (int j = 0; j < TOKENS_PER_TILE; ++j) {
const float w = weights_shared[j];
float v0 = 0.0f, v1 = 0.0f, v2 = 0.0f, v3 = 0.0f;
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
if constexpr (kPack == 2) {
const half2 vh2 = *reinterpret_cast<const half2 *>(&sh_v[buf][j][dim]);
const float2 vf = __half22float2(vh2);
v0 = vf.x;
v1 = vf.y;
} else {
const half2 vh2_0 = *reinterpret_cast<const half2 *>(&sh_v[buf][j][dim + 0]);
const half2 vh2_1 = *reinterpret_cast<const half2 *>(&sh_v[buf][j][dim + 2]);
const float2 vf0 = __half22float2(vh2_0);
const float2 vf1 = __half22float2(vh2_1);
v0 = vf0.x;
v1 = vf0.y;
v2 = vf1.x;
v3 = vf1.y;
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
if constexpr (kPack == 2) {
const __nv_bfloat162 vb2 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_v[buf][j][dim]);
const float2 vf = __bfloat1622float2(vb2);
v0 = vf.x;
v1 = vf.y;
} else {
const __nv_bfloat162 vb2_0 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_v[buf][j][dim + 0]);
const __nv_bfloat162 vb2_1 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_v[buf][j][dim + 2]);
const float2 vf0 = __bfloat1622float2(vb2_0);
const float2 vf1 = __bfloat1622float2(vb2_1);
v0 = vf0.x;
v1 = vf0.y;
v2 = vf1.x;
v3 = vf1.y;
}
} else
#endif
{
v0 = static_cast<float>(sh_v[buf][j][dim + 0]);
v1 = static_cast<float>(sh_v[buf][j][dim + 1]);
if constexpr (kPack == 4) {
v2 = static_cast<float>(sh_v[buf][j][dim + 2]);
v3 = static_cast<float>(sh_v[buf][j][dim + 3]);
}
}
sum_wv0 = fmaf(w, v0, sum_wv0);
sum_wv1 = fmaf(w, v1, sum_wv1);
if constexpr (kPack == 4) {
sum_wv2 = fmaf(w, v2, sum_wv2);
sum_wv3 = fmaf(w, v3, sum_wv3);
}
}
acc0 = acc0 * alpha + sum_wv0;
acc1 = acc1 * alpha + sum_wv1;
if constexpr (kPack == 4) {
acc2 = acc2 * alpha + sum_wv2;
acc3 = acc3 * alpha + sum_wv3;
}
// Make sure every warp has finished reading sh_k/sh_v for this tile before the
// prefetch below starts landing new data in the same buffer.
if constexpr (NUM_WARPS == 1) {
__syncwarp();
} else {
__syncthreads();
}
const int prefetch_tile = tile_idx + STAGES;
if (prefetch_tile < num_tiles) {
const int token_prefetch = token_begin + prefetch_tile * TOKENS_PER_TILE;
const int prefetch_n = min(TOKENS_PER_TILE, token_end - token_prefetch);
for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) {
const int tok = li / CHUNKS;
const int chunk = li - tok * CHUNKS;
const int off = chunk * CHUNK_ELEMS;
if (tok < prefetch_n) {
const Tdata *k_src = k_base + (token_prefetch + tok) * k_row_stride + off;
const Tdata *v_src = v_base + (token_prefetch + tok) * v_row_stride + off;
cpAsyncCaSharedGlobal16(&sh_k[buf][tok][off], k_src);
cpAsyncCaSharedGlobal16(&sh_v[buf][tok][off], v_src);
} else {
reinterpret_cast<uint4 *>(&sh_k[buf][tok][off])[0] = make_uint4(0, 0, 0, 0);
reinterpret_cast<uint4 *>(&sh_v[buf][tok][off])[0] = make_uint4(0, 0, 0, 0);
}
}
cpAsyncCommit();
++pending_groups;
}
if (tile_idx + 1 < num_tiles) {
int desired_pending2 = pending_groups - 1;
if (desired_pending2 < 0) {
desired_pending2 = 0;
}
if (desired_pending2 > (STAGES - 1)) {
desired_pending2 = (STAGES - 1);
}
cpAsyncWaitGroupRt(desired_pending2);
pending_groups = desired_pending2;
if constexpr (NUM_WARPS == 1) {
__syncwarp();
} else {
__syncthreads();
}
}
}
cpAsyncWaitAll();
if constexpr (NUM_WARPS == 1) {
__syncwarp();
} else {
__syncthreads();
}
}
if (tid == 0) {
partial_m[idx] = m;
partial_l[idx] = l;
}
if constexpr (kPack == 2) {
partial_acc[idx * HEAD_SIZE + dim + 0] = acc0;
partial_acc[idx * HEAD_SIZE + dim + 1] = acc1;
} else {
partial_acc[idx * HEAD_SIZE + dim + 0] = acc0;
partial_acc[idx * HEAD_SIZE + dim + 1] = acc1;
partial_acc[idx * HEAD_SIZE + dim + 2] = acc2;
partial_acc[idx * HEAD_SIZE + dim + 3] = acc3;
}
}
template <typename Tindex, typename Tdata, int HEAD_SIZE>
__device__ void flashAttentionDecodeCtaPipelinedKernel(
Tdata *out_,
const Tdata *q_,
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const Tindex *cache_lens_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride) {
constexpr int kWarpSize = 32;
static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4.");
static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32.");
constexpr int NUM_WARPS = HEAD_SIZE / kWarpSize;
const int seq_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int tid = threadIdx.x;
const int lane = tid % kWarpSize;
const int warp_id = tid / kWarpSize;
const int seq_len = static_cast<int>(cache_lens_[seq_idx]);
if (seq_len <= 0) {
return;
}
const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / static_cast<int>(num_kv_heads);
const int kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
constexpr float kLog2e = 1.4426950408889634f;
const float scale_log2 = scale * kLog2e;
const Tindex *block_table = block_tables_ + seq_idx * static_cast<int>(max_num_blocks_per_seq);
const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE;
Tdata *out_ptr = out_ + seq_idx * o_stride + head_idx * HEAD_SIZE;
const float q_val = static_cast<float>(q_ptr[tid]);
float acc = 0.0f;
float m = -INFINITY;
float l = 0.0f;
__shared__ Tdata sh_k[2][HEAD_SIZE];
__shared__ Tdata sh_v[2][HEAD_SIZE];
__shared__ float warp_sums[NUM_WARPS];
__shared__ float alpha_s;
__shared__ float beta_s;
__shared__ int physical_block_s;
constexpr int CHUNK_ELEMS = 8; // 8 * 2 bytes = 16 bytes.
constexpr int CHUNKS = HEAD_SIZE / CHUNK_ELEMS;
const int pbs = static_cast<int>(page_block_size);
// Prefetch the very first token.
int buf = 0;
int t_base = 0;
int token_in_block = 0;
int logical_block = 0;
{
if (tid == 0) {
physical_block_s = static_cast<int>(block_table[0]);
}
__syncthreads();
const Tdata *k_base = k_cache_ + physical_block_s * k_batch_stride + kv_head_idx * k_head_stride;
const Tdata *v_base = v_cache_ + physical_block_s * v_batch_stride + kv_head_idx * v_head_stride;
if (tid < CHUNKS) {
const int off = tid * CHUNK_ELEMS;
cpAsyncCaSharedGlobal16(&sh_k[buf][off], (k_base + 0 * k_row_stride) + off);
cpAsyncCaSharedGlobal16(&sh_v[buf][off], (v_base + 0 * v_row_stride) + off);
}
cpAsyncCommit();
cpAsyncWaitAll();
__syncthreads();
}
for (int t = 0; t < seq_len; ++t) {
// Prefetch the next token's K/V into the other buffer while this token is processed.
const int next_t = t + 1;
const bool has_next = next_t < seq_len;
if (has_next) {
const int next_block = next_t / pbs;
const int next_in_block = next_t - next_block * pbs;
if (next_block != logical_block) {
logical_block = next_block;
if (tid == 0) {
physical_block_s = static_cast<int>(block_table[logical_block]);
}
__syncthreads();
}
const Tdata *k_base = k_cache_ + physical_block_s * k_batch_stride + kv_head_idx * k_head_stride;
const Tdata *v_base = v_cache_ + physical_block_s * v_batch_stride + kv_head_idx * v_head_stride;
const Tdata *k_src = k_base + next_in_block * k_row_stride;
const Tdata *v_src = v_base + next_in_block * v_row_stride;
if (tid < CHUNKS) {
const int off = tid * CHUNK_ELEMS;
cpAsyncCaSharedGlobal16(&sh_k[buf ^ 1][off], k_src + off);
cpAsyncCaSharedGlobal16(&sh_v[buf ^ 1][off], v_src + off);
}
cpAsyncCommit();
}
// Dot: each thread handles one dim, reduce across head dim.
const float k_val = static_cast<float>(sh_k[buf][tid]);
float partial = q_val * k_val;
float warp_sum = warpReduceSum(partial);
if (lane == 0) {
warp_sums[warp_id] = warp_sum;
}
__syncthreads();
float qk = 0.0f;
if (warp_id == 0) {
float v = (lane < NUM_WARPS) ? warp_sums[lane] : 0.0f;
v = warpReduceSum(v);
if (lane == 0) {
qk = v;
float score = qk * scale_log2;
if (alibi_slope != 0.0f) {
score += (alibi_slope * static_cast<float>(t - (seq_len - 1))) * kLog2e;
}
const float m_new = fmaxf(m, score);
const float alpha = exp2f(m - m_new);
const float beta = exp2f(score - m_new);
l = l * alpha + beta;
m = m_new;
alpha_s = alpha;
beta_s = beta;
}
}
__syncthreads();
const float alpha = alpha_s;
const float beta = beta_s;
const float v_val = static_cast<float>(sh_v[buf][tid]);
acc = acc * alpha + beta * v_val;
if (has_next) {
cpAsyncWaitAll();
__syncthreads();
buf ^= 1;
}
}
__shared__ float inv_l_s;
if (tid == 0) {
inv_l_s = 1.0f / (l + 1e-6f);
}
__syncthreads();
out_ptr[tid] = static_cast<Tdata>(acc * inv_l_s);
}
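// Launch sketch for the pipelined CTA kernel above (assuming a thin __global__
// wrapper): grid = dim3(num_heads, num_seqs), block = dim3(HEAD_SIZE), so each
// thread owns exactly one head dim and the double-buffered sh_k/sh_v rows hold
// one token each.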
template <typename Tindex, typename Tdata, int HEAD_SIZE, int CTA_THREADS, int TOKENS_PER_TILE>
__device__ void flashAttentionDecodeCtaKernel(
Tdata *out_,
const Tdata *q_,
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const Tindex *cache_lens_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride) {
constexpr int kWarpSize = 32;
static_assert(CTA_THREADS % kWarpSize == 0, "CTA_THREADS must be a multiple of 32.");
static_assert(TOKENS_PER_TILE > 0 && TOKENS_PER_TILE <= 16, "TOKENS_PER_TILE should stay small.");
constexpr int NUM_WARPS = CTA_THREADS / kWarpSize;
const int seq_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int tid = threadIdx.x;
const int lane = tid % kWarpSize;
const int warp_id = tid / kWarpSize;
// Each thread owns a small packed vector of head dims. This lets us shrink the
// CTA to 1-2 warps and reduce block-wide synchronization overhead.
static_assert(HEAD_SIZE % CTA_THREADS == 0, "HEAD_SIZE must be divisible by CTA_THREADS.");
constexpr int kPack = HEAD_SIZE / CTA_THREADS; // 2 (64@32t, 128@64t) or 4 (128@32t)
static_assert(kPack == 2 || kPack == 4, "v0.4 CTA tile kernel supports kPack=2/4 only.");
constexpr int kPackedDims = CTA_THREADS;
constexpr int kComputeWarps = (kPackedDims + kWarpSize - 1) / kWarpSize;
const int dim = tid * kPack;
const int seq_len = static_cast<int>(cache_lens_[seq_idx]);
if (seq_len <= 0) {
return;
}
const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / static_cast<int>(num_kv_heads);
const int kv_head_idx = head_idx / num_queries_per_kv;
const Tindex *block_table = block_tables_ + seq_idx * static_cast<int>(max_num_blocks_per_seq);
// q/out are [num_seqs, num_heads, head_size]
const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE;
Tdata *out_ptr = out_ + seq_idx * o_stride + head_idx * HEAD_SIZE;
float q0 = 0.0f;
float q1 = 0.0f;
float q2 = 0.0f;
float q3 = 0.0f;
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
if constexpr (kPack == 2) {
const half2 qh2 = *reinterpret_cast<const half2 *>(q_ptr + dim);
const float2 qf = __half22float2(qh2);
q0 = qf.x;
q1 = qf.y;
} else {
const half2 qh2_0 = *reinterpret_cast<const half2 *>(q_ptr + dim + 0);
const half2 qh2_1 = *reinterpret_cast<const half2 *>(q_ptr + dim + 2);
const float2 qf0 = __half22float2(qh2_0);
const float2 qf1 = __half22float2(qh2_1);
q0 = qf0.x;
q1 = qf0.y;
q2 = qf1.x;
q3 = qf1.y;
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
if constexpr (kPack == 2) {
const __nv_bfloat162 qb2 = *reinterpret_cast<const __nv_bfloat162 *>(q_ptr + dim);
const float2 qf = __bfloat1622float2(qb2);
q0 = qf.x;
q1 = qf.y;
} else {
const __nv_bfloat162 qb2_0 = *reinterpret_cast<const __nv_bfloat162 *>(q_ptr + dim + 0);
const __nv_bfloat162 qb2_1 = *reinterpret_cast<const __nv_bfloat162 *>(q_ptr + dim + 2);
const float2 qf0 = __bfloat1622float2(qb2_0);
const float2 qf1 = __bfloat1622float2(qb2_1);
q0 = qf0.x;
q1 = qf0.y;
q2 = qf1.x;
q3 = qf1.y;
}
} else
#endif
{
q0 = static_cast<float>(q_ptr[dim + 0]);
q1 = static_cast<float>(q_ptr[dim + 1]);
if constexpr (kPack == 4) {
q2 = static_cast<float>(q_ptr[dim + 2]);
q3 = static_cast<float>(q_ptr[dim + 3]);
}
}
float acc0 = 0.0f;
float acc1 = 0.0f;
float acc2 = 0.0f;
float acc3 = 0.0f;
float m = -INFINITY;
float l = 0.0f;
// Only the compute warps contribute QK partial sums. Keeping this array
// compact reduces shared-memory traffic and bank pressure.
__shared__ float warp_sums[TOKENS_PER_TILE][kComputeWarps];
__shared__ float alpha_shared;
__shared__ float weights_shared[TOKENS_PER_TILE];
const int pbs = static_cast<int>(page_block_size);
static_assert(sizeof(Tdata) == 2, "CTA tile kernel assumes 16B chunks map to 8 elements for fp16/bf16.");
constexpr int CHUNK_ELEMS = 8; // 8 * 2 bytes = 16 bytes.
constexpr int CHUNKS = HEAD_SIZE / CHUNK_ELEMS;
constexpr int LOADS_PER_TILE = CHUNKS * TOKENS_PER_TILE;
// Multi-stage cp.async pipeline. Using >= 3 stages allows us to keep
// multiple groups in-flight and overlap global->shared copies with compute.
constexpr int STAGES = 3;
__shared__ __align__(16) Tdata sh_k[STAGES][TOKENS_PER_TILE][HEAD_SIZE];
__shared__ __align__(16) Tdata sh_v[STAGES][TOKENS_PER_TILE][HEAD_SIZE];
const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
constexpr float kLog2e = 1.4426950408889634f;
const float scale_log2 = scale * kLog2e;
int t_base = 0;
for (int logical_block = 0; t_base < seq_len; ++logical_block, t_base += pbs) {
const int physical_block = static_cast<int>(block_table[logical_block]);
const Tdata *k_base = k_cache_ + physical_block * k_batch_stride + kv_head_idx * k_head_stride;
const Tdata *v_base = v_cache_ + physical_block * v_batch_stride + kv_head_idx * v_head_stride;
const int token_end = min(pbs, seq_len - t_base);
const int num_tiles = (token_end + TOKENS_PER_TILE - 1) / TOKENS_PER_TILE;
if (num_tiles <= 0) {
continue;
}
int pending_groups = 0;
const int preload = min(STAGES, num_tiles);
for (int ti = 0; ti < preload; ++ti) {
const int token_in_block = ti * TOKENS_PER_TILE;
const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block);
for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) {
const int tok = li / CHUNKS;
const int chunk = li - tok * CHUNKS;
const int off = chunk * CHUNK_ELEMS;
if (tok < tile_n) {
const Tdata *k_src = k_base + (token_in_block + tok) * k_row_stride + off;
const Tdata *v_src = v_base + (token_in_block + tok) * v_row_stride + off;
cpAsyncCaSharedGlobal16(&sh_k[ti][tok][off], k_src);
cpAsyncCaSharedGlobal16(&sh_v[ti][tok][off], v_src);
} else {
reinterpret_cast<uint4 *>(&sh_k[ti][tok][off])[0] = make_uint4(0, 0, 0, 0);
reinterpret_cast<uint4 *>(&sh_v[ti][tok][off])[0] = make_uint4(0, 0, 0, 0);
}
}
cpAsyncCommit();
++pending_groups;
}
// Ensure tile 0 is ready. We want to keep up to (STAGES - 1) groups
// in flight for overlap, but still make forward progress in the tail
// when we stop issuing new prefetch groups.
int desired_pending = pending_groups - 1;
if (desired_pending < 0) {
desired_pending = 0;
}
if (desired_pending > (STAGES - 1)) {
desired_pending = (STAGES - 1);
}
cpAsyncWaitGroupRt(desired_pending);
pending_groups = desired_pending;
if constexpr (NUM_WARPS == 1) {
__syncwarp();
} else {
__syncthreads();
}
for (int tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
const int buf = tile_idx % STAGES;
const int token_in_block = tile_idx * TOKENS_PER_TILE;
const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block);
float partial[TOKENS_PER_TILE];
#pragma unroll
for (int j = 0; j < TOKENS_PER_TILE; ++j) {
if (j < tile_n) {
float k0 = 0.0f;
float k1 = 0.0f;
float k2 = 0.0f;
float k3 = 0.0f;
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
if constexpr (kPack == 2) {
const half2 kh2 = *reinterpret_cast<const half2 *>(&sh_k[buf][j][dim]);
const float2 kf = __half22float2(kh2);
k0 = kf.x;
k1 = kf.y;
} else {
const half2 kh2_0 = *reinterpret_cast<const half2 *>(&sh_k[buf][j][dim + 0]);
const half2 kh2_1 = *reinterpret_cast<const half2 *>(&sh_k[buf][j][dim + 2]);
const float2 kf0 = __half22float2(kh2_0);
const float2 kf1 = __half22float2(kh2_1);
k0 = kf0.x;
k1 = kf0.y;
k2 = kf1.x;
k3 = kf1.y;
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
if constexpr (kPack == 2) {
const __nv_bfloat162 kb2 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_k[buf][j][dim]);
const float2 kf = __bfloat1622float2(kb2);
k0 = kf.x;
k1 = kf.y;
} else {
const __nv_bfloat162 kb2_0 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_k[buf][j][dim + 0]);
const __nv_bfloat162 kb2_1 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_k[buf][j][dim + 2]);
const float2 kf0 = __bfloat1622float2(kb2_0);
const float2 kf1 = __bfloat1622float2(kb2_1);
k0 = kf0.x;
k1 = kf0.y;
k2 = kf1.x;
k3 = kf1.y;
}
} else
#endif
{
k0 = static_cast<float>(sh_k[buf][j][dim + 0]);
k1 = static_cast<float>(sh_k[buf][j][dim + 1]);
if constexpr (kPack == 4) {
k2 = static_cast<float>(sh_k[buf][j][dim + 2]);
k3 = static_cast<float>(sh_k[buf][j][dim + 3]);
}
}
if constexpr (kPack == 2) {
partial[j] = fmaf(q0, k0, q1 * k1);
} else {
partial[j] = fmaf(q0, k0, fmaf(q1, k1, fmaf(q2, k2, q3 * k3)));
}
} else {
partial[j] = 0.0f;
}
}
#pragma unroll
for (int j = 0; j < TOKENS_PER_TILE; ++j) {
float sum = warpReduceSum(partial[j]);
// Only compute warps contribute to qk; load-only warps would
// otherwise write zeros and increase reduction overhead.
if (lane == 0 && warp_id < kComputeWarps) {
warp_sums[j][warp_id] = sum;
}
}
if constexpr (NUM_WARPS == 1) {
__syncwarp();
} else {
__syncthreads();
}
if (warp_id == 0) {
// Distribute token-wise score computation across lanes to avoid
// serial loops in lane0. TOKENS_PER_TILE <= 16 by construction.
float score = -INFINITY;
if (lane < TOKENS_PER_TILE && lane < tile_n) {
float qk = 0.0f;
#pragma unroll
for (int w = 0; w < kComputeWarps; ++w) {
qk += warp_sums[lane][w];
}
const int t = t_base + token_in_block + lane;
score = qk * scale_log2;
if (alibi_slope != 0.0f) {
score += (alibi_slope * static_cast<float>(t - (seq_len - 1))) * kLog2e;
}
}
float tile_max = warpReduceMax(score);
tile_max = __shfl_sync(0xffffffff, tile_max, 0);
float m_new = 0.0f;
if (lane == 0) {
m_new = fmaxf(m, tile_max);
}
m_new = __shfl_sync(0xffffffff, m_new, 0);
float w = 0.0f;
if (lane < TOKENS_PER_TILE && lane < tile_n) {
w = exp2f(score - m_new);
}
if (lane < TOKENS_PER_TILE) {
weights_shared[lane] = (lane < tile_n) ? w : 0.0f;
}
float tile_sum = warpReduceSum(w);
if (lane == 0) {
const float alpha = exp2f(m - m_new);
alpha_shared = alpha;
l = l * alpha + tile_sum;
m = m_new;
}
}
if constexpr (NUM_WARPS == 1) {
__syncwarp();
} else {
__syncthreads();
}
const float alpha = alpha_shared;
float sum_wv0 = 0.0f;
float sum_wv1 = 0.0f;
float sum_wv2 = 0.0f;
float sum_wv3 = 0.0f;
#pragma unroll
for (int j = 0; j < TOKENS_PER_TILE; ++j) {
const float w = weights_shared[j];
float v0 = 0.0f;
float v1 = 0.0f;
float v2 = 0.0f;
float v3 = 0.0f;
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
if constexpr (kPack == 2) {
const half2 vh2 = *reinterpret_cast<const half2 *>(&sh_v[buf][j][dim]);
const float2 vf = __half22float2(vh2);
v0 = vf.x;
v1 = vf.y;
} else {
const half2 vh2_0 = *reinterpret_cast<const half2 *>(&sh_v[buf][j][dim + 0]);
const half2 vh2_1 = *reinterpret_cast<const half2 *>(&sh_v[buf][j][dim + 2]);
const float2 vf0 = __half22float2(vh2_0);
const float2 vf1 = __half22float2(vh2_1);
v0 = vf0.x;
v1 = vf0.y;
v2 = vf1.x;
v3 = vf1.y;
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
if constexpr (kPack == 2) {
const __nv_bfloat162 vb2 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_v[buf][j][dim]);
const float2 vf = __bfloat1622float2(vb2);
v0 = vf.x;
v1 = vf.y;
} else {
const __nv_bfloat162 vb2_0 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_v[buf][j][dim + 0]);
const __nv_bfloat162 vb2_1 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_v[buf][j][dim + 2]);
const float2 vf0 = __bfloat1622float2(vb2_0);
const float2 vf1 = __bfloat1622float2(vb2_1);
v0 = vf0.x;
v1 = vf0.y;
v2 = vf1.x;
v3 = vf1.y;
}
} else
#endif
{
v0 = static_cast<float>(sh_v[buf][j][dim + 0]);
v1 = static_cast<float>(sh_v[buf][j][dim + 1]);
if constexpr (kPack == 4) {
v2 = static_cast<float>(sh_v[buf][j][dim + 2]);
v3 = static_cast<float>(sh_v[buf][j][dim + 3]);
}
}
sum_wv0 = fmaf(w, v0, sum_wv0);
sum_wv1 = fmaf(w, v1, sum_wv1);
if constexpr (kPack == 4) {
sum_wv2 = fmaf(w, v2, sum_wv2);
sum_wv3 = fmaf(w, v3, sum_wv3);
}
}
acc0 = acc0 * alpha + sum_wv0;
acc1 = acc1 * alpha + sum_wv1;
if constexpr (kPack == 4) {
acc2 = acc2 * alpha + sum_wv2;
acc3 = acc3 * alpha + sum_wv3;
}
// Prefetch the tile that will reuse this buffer (STAGES steps ahead). Sync first
// so no warp is still reading sh_k/sh_v of the current tile when the new copies
// start landing in the same buffer.
if constexpr (NUM_WARPS == 1) {
__syncwarp();
} else {
__syncthreads();
}
const int prefetch_tile = tile_idx + STAGES;
if (prefetch_tile < num_tiles) {
const int token_prefetch = prefetch_tile * TOKENS_PER_TILE;
const int prefetch_n = min(TOKENS_PER_TILE, token_end - token_prefetch);
for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) {
const int tok = li / CHUNKS;
const int chunk = li - tok * CHUNKS;
const int off = chunk * CHUNK_ELEMS;
if (tok < prefetch_n) {
const Tdata *k_src = k_base + (token_prefetch + tok) * k_row_stride + off;
const Tdata *v_src = v_base + (token_prefetch + tok) * v_row_stride + off;
cpAsyncCaSharedGlobal16(&sh_k[buf][tok][off], k_src);
cpAsyncCaSharedGlobal16(&sh_v[buf][tok][off], v_src);
} else {
reinterpret_cast<uint4 *>(&sh_k[buf][tok][off])[0] = make_uint4(0, 0, 0, 0);
reinterpret_cast<uint4 *>(&sh_v[buf][tok][off])[0] = make_uint4(0, 0, 0, 0);
}
}
cpAsyncCommit();
++pending_groups;
}
if (tile_idx + 1 < num_tiles) {
// Before consuming the next tile, ensure at least one group
// completes. In steady state we keep (STAGES - 1) in flight; in
// the tail (no more prefetches) we gradually drain.
int desired_pending = pending_groups - 1;
if (desired_pending < 0) {
desired_pending = 0;
}
if (desired_pending > (STAGES - 1)) {
desired_pending = (STAGES - 1);
}
cpAsyncWaitGroupRt(desired_pending);
pending_groups = desired_pending;
if constexpr (NUM_WARPS == 1) {
__syncwarp();
} else {
__syncthreads();
}
}
}
// Drain any in-flight async copies before moving to the next paged block.
cpAsyncWaitAll();
if constexpr (NUM_WARPS == 1) {
__syncwarp();
} else {
__syncthreads();
}
}
__shared__ float inv_l_shared;
if (tid == 0) {
inv_l_shared = 1.0f / (l + 1e-6f);
}
if constexpr (NUM_WARPS == 1) {
__syncwarp();
} else {
__syncthreads();
}
{
const float s = inv_l_shared;
const float o0 = acc0 * s;
const float o1 = acc1 * s;
const float o2 = acc2 * s;
const float o3 = acc3 * s;
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
out_ptr[dim + 0] = __float2half_rn(o0);
out_ptr[dim + 1] = __float2half_rn(o1);
if constexpr (kPack == 4) {
out_ptr[dim + 2] = __float2half_rn(o2);
out_ptr[dim + 3] = __float2half_rn(o3);
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
out_ptr[dim + 0] = __float2bfloat16_rn(o0);
out_ptr[dim + 1] = __float2bfloat16_rn(o1);
if constexpr (kPack == 4) {
out_ptr[dim + 2] = __float2bfloat16_rn(o2);
out_ptr[dim + 3] = __float2bfloat16_rn(o3);
}
} else
#endif
{
out_ptr[dim + 0] = static_cast<Tdata>(o0);
out_ptr[dim + 1] = static_cast<Tdata>(o1);
if constexpr (kPack == 4) {
out_ptr[dim + 2] = static_cast<Tdata>(o2);
out_ptr[dim + 3] = static_cast<Tdata>(o3);
}
}
}
}
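// Launch sketch for the CTA tile kernel above (assuming a thin __global__
// wrapper): grid = dim3(num_heads, num_seqs), block = dim3(CTA_THREADS).
// Supported pairings per the static_asserts: HEAD_SIZE=64 with 32 threads
// (kPack=2), HEAD_SIZE=128 with 64 threads (kPack=2), or HEAD_SIZE=128 with
// 32 threads (kPack=4).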
// GQA/MQA fused decode kernel: one CTA computes outputs for NGROUPS query heads that
// share the same KV head. This reduces redundant K/V reads when num_heads > num_kv_heads.
//
// v0.4: implemented for head_dim=128 and NGROUPS=4 (common case: 32 Q heads / 8 KV heads).
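//
// With NGROUPS=4 each K/V tile is staged into shared memory once and reused by
// four query heads, so global K/V traffic is roughly a quarter of running the
// per-head CTA kernel over the same heads, at the cost of 4x the per-thread
// register state for q/acc/m/l.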
template <typename Tindex, typename Tdata, int HEAD_SIZE, int CTA_THREADS, int TOKENS_PER_TILE, int NGROUPS>
__device__ void flashAttentionDecodeCtaGqaKernel(
Tdata *out_,
const Tdata *q_,
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const Tindex *cache_lens_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride) {
constexpr int kWarpSize = 32;
static_assert(HEAD_SIZE == 128, "v0.4 GQA fused CTA kernel is implemented for head_size=128 only.");
static_assert(NGROUPS == 4, "v0.4 GQA fused CTA kernel is implemented for NGROUPS=4 only.");
static_assert(CTA_THREADS % kWarpSize == 0, "CTA_THREADS must be a multiple of 32.");
static_assert(TOKENS_PER_TILE > 0 && TOKENS_PER_TILE <= 16, "TOKENS_PER_TILE should stay small.");
constexpr int NUM_WARPS = CTA_THREADS / kWarpSize;
// Pack dims per thread. For head_dim=128 and CTA_THREADS=64, kPack=2.
static_assert(HEAD_SIZE % CTA_THREADS == 0, "HEAD_SIZE must be divisible by CTA_THREADS.");
constexpr int kPack = HEAD_SIZE / CTA_THREADS;
static_assert(kPack == 2, "v0.4 GQA fused CTA kernel expects kPack=2.");
constexpr int kPackedDims = CTA_THREADS;
constexpr int kComputeWarps = (kPackedDims + kWarpSize - 1) / kWarpSize;
const int seq_idx = blockIdx.y;
const int kv_head_idx = blockIdx.x;
const int tid = threadIdx.x;
const int lane = tid % kWarpSize;
const int warp_id = tid / kWarpSize;
const int dim = tid * kPack;
const int seq_len = static_cast<int>(cache_lens_[seq_idx]);
if (seq_len <= 0) {
return;
}
// v0.4 limitation: alibi slopes are per query head; support can be added later.
if (alibi_slopes_ != nullptr) {
return;
}
const Tindex *block_table = block_tables_ + seq_idx * static_cast<int>(max_num_blocks_per_seq);
// q/out are [num_seqs, num_heads, head_size]. For a KV head, we handle NGROUPS query heads:
// q_head = kv_head * NGROUPS + g
float q0[NGROUPS];
float q1[NGROUPS];
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
#pragma unroll
for (int g = 0; g < NGROUPS; ++g) {
const int q_head = kv_head_idx * NGROUPS + g;
const Tdata *q_ptr = q_ + seq_idx * q_stride + q_head * HEAD_SIZE;
const half2 qh2 = *reinterpret_cast<const half2 *>(q_ptr + dim);
const float2 qf = __half22float2(qh2);
q0[g] = qf.x;
q1[g] = qf.y;
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
#pragma unroll
for (int g = 0; g < NGROUPS; ++g) {
const int q_head = kv_head_idx * NGROUPS + g;
const Tdata *q_ptr = q_ + seq_idx * q_stride + q_head * HEAD_SIZE;
const __nv_bfloat162 qb2 = *reinterpret_cast<const __nv_bfloat162 *>(q_ptr + dim);
const float2 qf = __bfloat1622float2(qb2);
q0[g] = qf.x;
q1[g] = qf.y;
}
} else
#endif
{
#pragma unroll
for (int g = 0; g < NGROUPS; ++g) {
const int q_head = kv_head_idx * NGROUPS + g;
const Tdata *q_ptr = q_ + seq_idx * q_stride + q_head * HEAD_SIZE;
q0[g] = static_cast<float>(q_ptr[dim + 0]);
q1[g] = static_cast<float>(q_ptr[dim + 1]);
}
}
float acc0[NGROUPS];
float acc1[NGROUPS];
float m[NGROUPS];
float l[NGROUPS];
#pragma unroll
for (int g = 0; g < NGROUPS; ++g) {
acc0[g] = 0.0f;
acc1[g] = 0.0f;
m[g] = -INFINITY;
l[g] = 0.0f;
}
__shared__ float warp_sums[NGROUPS][TOKENS_PER_TILE][kComputeWarps];
__shared__ float alpha_shared[NGROUPS];
__shared__ float weights_shared[NGROUPS][TOKENS_PER_TILE];
const int pbs = static_cast<int>(page_block_size);
constexpr float kLog2e = 1.4426950408889634f;
const float scale_log2 = scale * kLog2e;
static_assert(sizeof(Tdata) == 2, "CTA GQA kernel assumes fp16/bf16.");
constexpr int CHUNK_ELEMS = 8; // 8 * 2 bytes = 16 bytes.
constexpr int CHUNKS = HEAD_SIZE / CHUNK_ELEMS;
constexpr int LOADS_PER_TILE = CHUNKS * TOKENS_PER_TILE;
constexpr int STAGES = 3;
__shared__ __align__(16) Tdata sh_k[STAGES][TOKENS_PER_TILE][HEAD_SIZE];
__shared__ __align__(16) Tdata sh_v[STAGES][TOKENS_PER_TILE][HEAD_SIZE];
int t_base = 0;
for (int logical_block = 0; t_base < seq_len; ++logical_block, t_base += pbs) {
const int physical_block = static_cast<int>(block_table[logical_block]);
const Tdata *k_base = k_cache_ + physical_block * k_batch_stride + kv_head_idx * k_head_stride;
const Tdata *v_base = v_cache_ + physical_block * v_batch_stride + kv_head_idx * v_head_stride;
const int token_end = min(pbs, seq_len - t_base);
const int num_tiles = (token_end + TOKENS_PER_TILE - 1) / TOKENS_PER_TILE;
if (num_tiles <= 0) {
continue;
}
int pending_groups = 0;
const int preload = min(STAGES, num_tiles);
for (int ti = 0; ti < preload; ++ti) {
const int token_in_block = ti * TOKENS_PER_TILE;
const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block);
for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) {
const int tok = li / CHUNKS;
const int chunk = li - tok * CHUNKS;
const int off = chunk * CHUNK_ELEMS;
if (tok < tile_n) {
const Tdata *k_src = k_base + (token_in_block + tok) * k_row_stride + off;
const Tdata *v_src = v_base + (token_in_block + tok) * v_row_stride + off;
cpAsyncCaSharedGlobal16(&sh_k[ti][tok][off], k_src);
cpAsyncCaSharedGlobal16(&sh_v[ti][tok][off], v_src);
} else {
reinterpret_cast<uint4 *>(&sh_k[ti][tok][off])[0] = make_uint4(0, 0, 0, 0);
reinterpret_cast<uint4 *>(&sh_v[ti][tok][off])[0] = make_uint4(0, 0, 0, 0);
}
}
cpAsyncCommit();
++pending_groups;
}
int desired_pending = pending_groups - 1;
if (desired_pending < 0) {
desired_pending = 0;
}
if (desired_pending > (STAGES - 1)) {
desired_pending = (STAGES - 1);
}
cpAsyncWaitGroupRt(desired_pending);
pending_groups = desired_pending;
if constexpr (NUM_WARPS == 1) {
__syncwarp();
} else {
__syncthreads();
}
for (int tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
const int buf = tile_idx % STAGES;
const int token_in_block = tile_idx * TOKENS_PER_TILE;
const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block);
// Compute QK partial sums for each group and each token in the tile.
float partial_qk[NGROUPS][TOKENS_PER_TILE];
#pragma unroll
for (int g = 0; g < NGROUPS; ++g) {
#pragma unroll
for (int j = 0; j < TOKENS_PER_TILE; ++j) {
if (j < tile_n) {
float k0 = 0.0f;
float k1 = 0.0f;
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
const half2 kh2 = *reinterpret_cast<const half2 *>(&sh_k[buf][j][dim]);
const float2 kf = __half22float2(kh2);
k0 = kf.x;
k1 = kf.y;
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
const __nv_bfloat162 kb2 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_k[buf][j][dim]);
const float2 kf = __bfloat1622float2(kb2);
k0 = kf.x;
k1 = kf.y;
} else
#endif
{
k0 = static_cast<float>(sh_k[buf][j][dim + 0]);
k1 = static_cast<float>(sh_k[buf][j][dim + 1]);
}
partial_qk[g][j] = fmaf(q0[g], k0, q1[g] * k1);
} else {
partial_qk[g][j] = 0.0f;
}
}
}
#pragma unroll
for (int g = 0; g < NGROUPS; ++g) {
#pragma unroll
for (int j = 0; j < TOKENS_PER_TILE; ++j) {
const float sum = warpReduceSum(partial_qk[g][j]);
if (lane == 0 && warp_id < kComputeWarps) {
warp_sums[g][j][warp_id] = sum;
}
}
}
if constexpr (NUM_WARPS == 1) {
__syncwarp();
} else {
__syncthreads();
}
if (warp_id == 0) {
#pragma unroll
for (int g = 0; g < NGROUPS; ++g) {
float score = -INFINITY;
if (lane < TOKENS_PER_TILE && lane < tile_n) {
float qk = 0.0f;
#pragma unroll
for (int w = 0; w < kComputeWarps; ++w) {
qk += warp_sums[g][lane][w];
}
score = qk * scale_log2;
}
float tile_max = warpReduceMax(score);
tile_max = __shfl_sync(0xffffffff, tile_max, 0);
float m_new = 0.0f;
if (lane == 0) {
m_new = fmaxf(m[g], tile_max);
}
m_new = __shfl_sync(0xffffffff, m_new, 0);
float w = 0.0f;
if (lane < TOKENS_PER_TILE && lane < tile_n) {
w = exp2f(score - m_new);
}
if (lane < TOKENS_PER_TILE) {
weights_shared[g][lane] = (lane < tile_n) ? w : 0.0f;
}
const float tile_sum = warpReduceSum(w);
if (lane == 0) {
const float alpha = exp2f(m[g] - m_new);
alpha_shared[g] = alpha;
l[g] = l[g] * alpha + tile_sum;
m[g] = m_new;
}
}
}
if constexpr (NUM_WARPS == 1) {
__syncwarp();
} else {
__syncthreads();
}
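// Warp 0 has now advanced the online-softmax state for this tile: thread 0 keeps the
// running max m[g] and sum l[g] (in base 2), while alpha_shared[g] and
// weights_shared[g][j] let every thread rescale its V accumulators and fold in the
// new tile's contribution.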
float alpha[NGROUPS];
float sum_wv0[NGROUPS];
float sum_wv1[NGROUPS];
#pragma unroll
for (int g = 0; g < NGROUPS; ++g) {
alpha[g] = alpha_shared[g];
sum_wv0[g] = 0.0f;
sum_wv1[g] = 0.0f;
}
#pragma unroll
for (int j = 0; j < TOKENS_PER_TILE; ++j) {
float v0 = 0.0f;
float v1 = 0.0f;
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
const half2 vh2 = *reinterpret_cast<const half2 *>(&sh_v[buf][j][dim]);
const float2 vf = __half22float2(vh2);
v0 = vf.x;
v1 = vf.y;
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
const __nv_bfloat162 vb2 = *reinterpret_cast<const __nv_bfloat162 *>(&sh_v[buf][j][dim]);
const float2 vf = __bfloat1622float2(vb2);
v0 = vf.x;
v1 = vf.y;
} else
#endif
{
v0 = static_cast<float>(sh_v[buf][j][dim + 0]);
v1 = static_cast<float>(sh_v[buf][j][dim + 1]);
}
#pragma unroll
for (int g = 0; g < NGROUPS; ++g) {
const float w = weights_shared[g][j];
sum_wv0[g] = fmaf(w, v0, sum_wv0[g]);
sum_wv1[g] = fmaf(w, v1, sum_wv1[g]);
}
}
#pragma unroll
for (int g = 0; g < NGROUPS; ++g) {
acc0[g] = acc0[g] * alpha[g] + sum_wv0[g];
acc1[g] = acc1[g] * alpha[g] + sum_wv1[g];
}
const int prefetch_tile = tile_idx + STAGES;
if (prefetch_tile < num_tiles) {
const int token_prefetch = prefetch_tile * TOKENS_PER_TILE;
const int prefetch_n = min(TOKENS_PER_TILE, token_end - token_prefetch);
for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) {
const int tok = li / CHUNKS;
const int chunk = li - tok * CHUNKS;
const int off = chunk * CHUNK_ELEMS;
if (tok < prefetch_n) {
const Tdata *k_src = k_base + (token_prefetch + tok) * k_row_stride + off;
const Tdata *v_src = v_base + (token_prefetch + tok) * v_row_stride + off;
cpAsyncCaSharedGlobal16(&sh_k[buf][tok][off], k_src);
cpAsyncCaSharedGlobal16(&sh_v[buf][tok][off], v_src);
} else {
reinterpret_cast<uint4 *>(&sh_k[buf][tok][off])[0] = make_uint4(0, 0, 0, 0);
reinterpret_cast<uint4 *>(&sh_v[buf][tok][off])[0] = make_uint4(0, 0, 0, 0);
}
}
cpAsyncCommit();
++pending_groups;
}
if (tile_idx + 1 < num_tiles) {
int desired_pending2 = pending_groups - 1;
if (desired_pending2 < 0) {
desired_pending2 = 0;
}
if (desired_pending2 > (STAGES - 1)) {
desired_pending2 = (STAGES - 1);
}
cpAsyncWaitGroupRt(desired_pending2);
pending_groups = desired_pending2;
if constexpr (NUM_WARPS == 1) {
__syncwarp();
} else {
__syncthreads();
}
}
}
cpAsyncWaitAll();
if constexpr (NUM_WARPS == 1) {
__syncwarp();
} else {
__syncthreads();
}
}
// Write outputs for each group.
__shared__ float inv_l_shared[NGROUPS];
// Only thread 0 (warp 0, lane 0) holds the updated l[g]; publish all groups from it.
if (tid == 0) {
for (int g = 0; g < NGROUPS; ++g) {
inv_l_shared[g] = 1.0f / (l[g] + 1e-6f);
}
}
if constexpr (NUM_WARPS == 1) {
__syncwarp();
} else {
__syncthreads();
}
#pragma unroll
for (int g = 0; g < NGROUPS; ++g) {
const int q_head = kv_head_idx * NGROUPS + g;
Tdata *out_ptr = out_ + seq_idx * o_stride + q_head * HEAD_SIZE;
const float s = inv_l_shared[g];
const float o0 = acc0[g] * s;
const float o1 = acc1[g] * s;
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
out_ptr[dim + 0] = __float2half_rn(o0);
out_ptr[dim + 1] = __float2half_rn(o1);
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
out_ptr[dim + 0] = __float2bfloat16_rn(o0);
out_ptr[dim + 1] = __float2bfloat16_rn(o1);
} else
#endif
{
out_ptr[dim + 0] = static_cast<Tdata>(o0);
out_ptr[dim + 1] = static_cast<Tdata>(o1);
}
}
}
} // namespace op::paged_attention::cuda
#endif // __PAGED_ATTENTION_KERNEL_V2_CUH__
@@ -13,92 +13,171 @@ class PagedAttentionInfo {
PagedAttentionInfo() = default;
public:
// --- Data Types and Scale ---
infiniDtype_t dtype;
infiniDtype_t index_dtype;
float scale;
// --- Shape Dimensions ---
size_t num_seqs;
size_t num_heads;
size_t num_kv_heads;
size_t head_size;
size_t page_block_size;
size_t max_num_blocks_per_seq;
// --- Strides for Memory Layout ---
ptrdiff_t q_stride;
ptrdiff_t k_batch_stride;
ptrdiff_t k_row_stride;
ptrdiff_t k_head_stride;
ptrdiff_t v_batch_stride;
ptrdiff_t v_row_stride;
ptrdiff_t v_head_stride;
ptrdiff_t o_stride;
ptrdiff_t block_table_batch_stride;
ptrdiff_t cache_lens_stride;
static utils::Result<PagedAttentionInfo> create(
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t cache_lens_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto dtype = q_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (q_desc->ndim() != 3 || out_desc->ndim() != 3) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (k_cache_desc->ndim() != 4 || v_cache_desc->ndim() != 4) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (block_tables_desc->ndim() != 2 || cache_lens_desc->ndim() != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
CHECK_OR_RETURN(q_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
CHECK_OR_RETURN(out_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
CHECK_OR_RETURN(k_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
CHECK_OR_RETURN(v_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
const auto block_tables_dt = block_tables_desc->dtype();
const auto cache_lens_dt = cache_lens_desc->dtype();
const bool debug_dtype = (std::getenv("INFINIOP_FLASH_DEBUG_DTYPE") != nullptr);
const bool block_tables_ok = (block_tables_dt == INFINI_DTYPE_I64) || (block_tables_dt == INFINI_DTYPE_I32) || (block_tables_dt == INFINI_DTYPE_U32);
const bool cache_lens_ok = (cache_lens_dt == INFINI_DTYPE_I64) || (cache_lens_dt == INFINI_DTYPE_I32) || (cache_lens_dt == INFINI_DTYPE_U32);
if (!(block_tables_ok && cache_lens_ok)) {
if (debug_dtype) {
std::fprintf(stderr,
"[flash_attention] Bad index dtype: block_tables=%d cache_lens=%d (expected I32/I64/U32)\n",
static_cast<int>(block_tables_dt), static_cast<int>(cache_lens_dt));
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (block_tables_dt != cache_lens_dt) {
// Keep them consistent to simplify backend dispatch.
if (debug_dtype) {
std::fprintf(stderr,
"[flash_attention] Mismatched index dtype: block_tables=%d cache_lens=%d\n",
static_cast<int>(block_tables_dt), static_cast<int>(cache_lens_dt));
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
CHECK_OR_RETURN(block_tables_desc->stride(1) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
CHECK_OR_RETURN(cache_lens_desc->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
if (alibi_slopes_desc.has_value() && alibi_slopes_desc.value() != nullptr) {
if (alibi_slopes_desc.value()->dtype() != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (alibi_slopes_desc.value()->ndim() != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
CHECK_OR_RETURN(alibi_slopes_desc.value()->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
}
// Shapes
auto q_shape = q_desc->shape();
auto k_shape = k_cache_desc->shape();
const size_t num_seqs = q_shape[0];
const size_t num_heads = q_shape[1];
const size_t head_size = q_shape[2];
const size_t num_blocks = k_shape[0];
(void)num_blocks;
const size_t page_block_size = k_shape[2];
const size_t num_kv_heads = k_shape[1];
// if (page_block_size % 256 != 0) {
// printf("paged block size %zu\n", page_block_size);
// return INFINI_STATUS_BAD_TENSOR_SHAPE;
// }
if (head_size != 64 && head_size != 128) {
// First build only targets common FA2 head dims (expand later).
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (num_heads % num_kv_heads != 0) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (v_cache_desc->shape()[0] != k_shape[0] || v_cache_desc->shape()[1] != k_shape[1] || v_cache_desc->shape()[2] != k_shape[2] || v_cache_desc->shape()[3] != k_shape[3]) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (out_desc->shape()[0] != q_shape[0] || out_desc->shape()[1] != q_shape[1] || out_desc->shape()[2] != q_shape[2]) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (cache_lens_desc->shape()[0] != num_seqs) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
const size_t max_num_blocks_per_seq = block_tables_desc->shape()[1];
// Strides (in elements)
const ptrdiff_t q_stride = q_desc->stride(0);
const ptrdiff_t o_stride = out_desc->stride(0);
const ptrdiff_t k_batch_stride = k_cache_desc->stride(0);
const ptrdiff_t k_row_stride = k_cache_desc->stride(2);
const ptrdiff_t k_head_stride = k_cache_desc->stride(1);
const ptrdiff_t v_batch_stride = v_cache_desc->stride(0);
const ptrdiff_t v_row_stride = v_cache_desc->stride(2);
const ptrdiff_t v_head_stride = v_cache_desc->stride(1);
const ptrdiff_t block_table_batch_stride = block_tables_desc->stride(0);
const ptrdiff_t cache_lens_stride = cache_lens_desc->stride(0);
return utils::Result<PagedAttentionInfo>(PagedAttentionInfo{
dtype,
block_tables_dt,
scale,
num_seqs,
num_heads,
num_kv_heads,
head_size,
page_block_size,
max_num_blocks_per_seq,
q_stride,
k_batch_stride,
k_row_stride,
k_head_stride,
v_batch_stride,
v_row_stride,
v_head_stride,
o_stride,
block_table_batch_stride,
cache_lens_stride,
});
}
};
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../cuda/kernel_v2.cuh"
namespace op::paged_attention::nvidia {
namespace {
constexpr int kMaxSplits = 8;
constexpr size_t ceilDiv(size_t a, size_t b) {
return (a + b - 1) / b;
}
inline int getSmCount() {
int device = 0;
if (cudaGetDevice(&device) != cudaSuccess) {
return 0;
}
int sm_count = 0;
if (cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device) != cudaSuccess) {
return 0;
}
return sm_count;
}
// A lightweight FA2-style "waves" heuristic.
//
// Important: our split-kv kernel shards the KV sequence length, so the main "work"
// dimension is tokens, not the number of pages. We use an upper bound for seqlen_k
// (max pages * page size), which matches common decode microbenchmarks where all
// sequences share the same cache length.
inline int chooseNumSplitsHeuristic(size_t num_heads, size_t num_seqs, size_t seqlen_k, int sm_count) {
if (sm_count <= 0) {
return 1;
}
if (num_heads == 0 || num_seqs == 0) {
return 1;
}
if (seqlen_k <= 256) {
return 1;
}
const size_t base_blocks = num_heads * num_seqs;
int best_splits = 1;
// Baseline: one kernel, base_blocks CTAs, each scanning seqlen_k tokens.
size_t best_score = (ceilDiv(base_blocks, static_cast<size_t>(sm_count)) * seqlen_k);
size_t prev_work_per_block = seqlen_k;
for (int s = 2; s <= kMaxSplits; ++s) {
const size_t blocks = base_blocks * static_cast<size_t>(s);
const size_t waves_split = ceilDiv(blocks, static_cast<size_t>(sm_count));
const size_t work_per_block = ceilDiv(seqlen_k, static_cast<size_t>(s));
// If this split count doesn't reduce per-block work vs the previous split, it's effectively redundant.
if (work_per_block == prev_work_per_block) {
continue;
}
prev_work_per_block = work_per_block;
// Combine is one extra kernel with base_blocks blocks; approximate as one more wave unit.
const size_t waves_combine = ceilDiv(base_blocks, static_cast<size_t>(sm_count));
const size_t score = waves_split * work_per_block + waves_combine;
if (score < best_score) {
best_score = score;
best_splits = s;
}
}
return best_splits;
}
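// Illustrative numbers (not measured): with sm_count=108, num_heads=32, num_seqs=4
// (base_blocks=128) and seqlen_k=4096, the scores come out to 2*4096=8192 for s=1,
// 5*1024+2=5122 for s=4, 6*820+2=4922 for s=5, and 10*512+2=5122 for s=8, so the
// heuristic settles on num_splits=5.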
} // namespace
inline bool envBool(const char *name) {
if (const char *env = std::getenv(name)) {
return (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
}
return false;
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128Warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride) {
op::paged_attention::cuda::flashAttentionDecodeWarpKernel<Tindex, Tdata, 128>(
out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128Cta(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride) {
// Default CTA variant (lower overhead).
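// The trailing template arguments select head_size=128, 64 threads per CTA and
// 8 tokens per KV tile; the Cta32 / Tile16 wrappers below only vary those last two.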
op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 128, 64, 8>(
out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128CtaTile16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride) {
op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 128, 64, 16>(
out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128Cta32(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride) {
// Experimental 1-warp CTA variant for head_dim=128 (kPack=4).
op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 128, 32, 8>(
out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128Cta32Tile16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride) {
op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 128, 32, 16>(
out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128CtaGqa4(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride) {
// GQA fused kernel: CTA computes 4 query heads for one KV head (head_dim=128).
op::paged_attention::cuda::flashAttentionDecodeCtaGqaKernel<Tindex, Tdata, 128, 64, 8, 4>(
out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKv(
float *partial_acc,
float *partial_m,
float *partial_l,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
int num_splits) {
op::paged_attention::cuda::flashAttentionDecodeSplitKvWarpKernel<Tindex, Tdata, 128>(
partial_acc, partial_m, partial_l,
q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, num_splits);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKvCta(
float *partial_acc,
float *partial_m,
float *partial_l,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
int num_splits) {
op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel<Tindex, Tdata, 128, 64, 8>(
partial_acc, partial_m, partial_l,
q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, num_splits);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKvCtaTile16(
float *partial_acc,
float *partial_m,
float *partial_l,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
int num_splits) {
op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel<Tindex, Tdata, 128, 64, 16>(
partial_acc, partial_m, partial_l,
q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, num_splits);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKvCta32(
float *partial_acc,
float *partial_m,
float *partial_l,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
int num_splits) {
op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel<Tindex, Tdata, 128, 32, 8>(
partial_acc, partial_m, partial_l,
q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, num_splits);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKvCta32Tile16(
float *partial_acc,
float *partial_m,
float *partial_l,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
int num_splits) {
op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel<Tindex, Tdata, 128, 32, 16>(
partial_acc, partial_m, partial_l,
q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, num_splits);
}
template <typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKvCombine(
Tdata *out,
const float *partial_acc,
const float *partial_m,
const float *partial_l,
int num_splits,
ptrdiff_t o_stride) {
op::paged_attention::cuda::flashAttentionDecodeSplitKvCombineWarpKernel<Tdata, 128>(
out, partial_acc, partial_m, partial_l, num_splits, o_stride);
}
template <typename Tindex>
infiniStatus_t launch_decode_hd128_impl(
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
infiniDtype_t dtype,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
cudaStream_t stream) {
// Default decode config (2026-01-22):
// decode_flash_cta8_64_gqa_splitkv_4
// Users can override any knob via the corresponding INFINIOP_FLASH_* env vars.
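// For example, any of these can be set to steer dispatch (unset keeps the defaults):
//   INFINIOP_FLASH_DECODE_KERNEL=warp   - per-head warp kernel instead of the CTA one
//   INFINIOP_FLASH_CTA_TILE=16          - 16-token KV tiles in the CTA kernels
//   INFINIOP_FLASH_CTA_THREADS=32       - 1-warp CTA variants
//   INFINIOP_FLASH_DECODE_SPLITKV=auto  - let the waves heuristic decide whether to split
//   INFINIOP_FLASH_NUM_SPLITS=auto      - also pick the split count heuristically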
bool use_cta = true;
if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_KERNEL")) {
// Backward-compatible: any non-"cta" value means "warp".
use_cta = (std::strcmp(env, "cta") == 0);
}
bool use_gqa_fused = true;
if (const char *env = std::getenv("INFINIOP_FLASH_GQA_FUSED")) {
if (std::strcmp(env, "0") == 0 || std::strcmp(env, "false") == 0) {
use_gqa_fused = false;
} else {
use_gqa_fused = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
}
}
int cta_tile = 8;
if (const char *env = std::getenv("INFINIOP_FLASH_CTA_TILE")) {
const int v = std::atoi(env);
if (v == 8 || v == 16) {
cta_tile = v;
}
}
int cta_threads = 64;
if (const char *env = std::getenv("INFINIOP_FLASH_CTA_THREADS")) {
const int v = std::atoi(env);
if (v == 32 || v == 64) {
cta_threads = v;
}
}
dim3 block(use_cta ? static_cast<uint32_t>(cta_threads) : 32);
bool use_split = true;
bool use_split_auto = false;
if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_SPLITKV")) {
if (std::strcmp(env, "auto") == 0) {
use_split_auto = true;
use_split = false;
} else {
if (std::strcmp(env, "0") == 0 || std::strcmp(env, "false") == 0) {
use_split = false;
} else {
use_split = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
}
}
}
int num_splits = 4;
bool fixed_num_splits = true;
if (const char *env = std::getenv("INFINIOP_FLASH_NUM_SPLITS")) {
if (std::strcmp(env, "auto") == 0) {
fixed_num_splits = false;
} else {
num_splits = std::atoi(env);
fixed_num_splits = (num_splits > 0);
}
}
if (num_splits < 1) {
num_splits = 1;
}
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
const bool debug_dispatch = envBool("INFINIOP_FLASH_DEBUG_DISPATCH");
auto dump_dispatch = [&](const char *path) {
if (!debug_dispatch) {
return;
}
// Avoid spamming: only print when the key dispatch signature changes.
struct Sig {
const char *path;
int dtype;
size_t heads;
size_t kv_heads;
size_t seqs;
size_t pbs;
size_t max_blocks;
int cta_tile;
int cta_threads;
int split;
int split_auto;
int num_splits;
int fixed;
int gqa_fused;
};
static Sig last{};
static bool has_last = false;
Sig cur{
path,
static_cast<int>(dtype),
num_heads,
num_kv_heads,
num_seqs,
page_block_size,
max_num_blocks_per_seq,
cta_tile,
cta_threads,
static_cast<int>(use_split),
static_cast<int>(use_split_auto),
num_splits,
static_cast<int>(fixed_num_splits),
static_cast<int>(use_gqa_fused),
};
if (has_last && cur.path == last.path && cur.dtype == last.dtype && cur.heads == last.heads && cur.kv_heads == last.kv_heads && cur.seqs == last.seqs && cur.pbs == last.pbs && cur.max_blocks == last.max_blocks && cur.cta_tile == last.cta_tile && cur.cta_threads == last.cta_threads && cur.split == last.split && cur.split_auto == last.split_auto && cur.num_splits == last.num_splits && cur.fixed == last.fixed && cur.gqa_fused == last.gqa_fused) {
return;
}
last = cur;
has_last = true;
fprintf(stderr,
"[INFINIOP][paged_attention][hd128] dispatch: path=%s dtype=%d heads=%zu kv_heads=%zu seqs=%zu "
"pbs=%zu max_blocks=%zu cta_tile=%d cta_threads=%d split=%d split_auto=%d num_splits=%d fixed=%d gqa_fused=%d\n",
path, static_cast<int>(dtype), num_heads, num_kv_heads, num_seqs,
page_block_size, max_num_blocks_per_seq, cta_tile, cta_threads,
static_cast<int>(use_split), static_cast<int>(use_split_auto), num_splits, static_cast<int>(fixed_num_splits),
static_cast<int>(use_gqa_fused));
};
// Split-kv auto mode: decide whether to split based on a heuristic.
if (use_split_auto) {
// Approximate seqlen_k by the per-seq KV capacity (paged KV upper bound).
const size_t seqlen_k = max_num_blocks_per_seq * page_block_size;
const int sm_count = getSmCount();
num_splits = chooseNumSplitsHeuristic(num_heads, num_seqs, seqlen_k, sm_count);
if (const char *dbg = std::getenv("INFINIOP_FLASH_DEBUG_SPLITS")) {
if (std::strcmp(dbg, "1") == 0 || std::strcmp(dbg, "true") == 0) {
static size_t last_seqlen_k = 0;
if (last_seqlen_k != seqlen_k) {
last_seqlen_k = seqlen_k;
fprintf(stderr,
"[INFINIOP][paged_attention] splitkv auto(mode): sm=%d heads=%zu seqs=%zu seqlen_k~%zu -> num_splits=%d\n",
sm_count, num_heads, num_seqs, seqlen_k, num_splits);
}
}
}
// If auto picks 1, fall back to non-split to avoid extra workspace and kernel overhead.
use_split = (num_splits > 1);
}
// Optional: fuse GQA groups (4) when seqlen_q=1 decode and alibi is disabled.
// This reuses K/V loads across query heads that share the same KV head.
// Controlled by INFINIOP_FLASH_GQA_FUSED (default: enabled).
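// Example: num_heads=32 with num_kv_heads=8 (4 query heads per KV head) launches an
// (8, num_seqs) grid in which each CTA streams one KV head's pages once for all four
// query heads; any other ratio, alibi, or split-kv falls through to the paths below.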
if (use_gqa_fused && use_cta && !use_split && alibi_slopes == nullptr && num_kv_heads > 0 && num_heads == num_kv_heads * 4) {
dump_dispatch("cta_gqa_fused");
dim3 grid_gqa(static_cast<uint64_t>(num_kv_heads), static_cast<uint64_t>(num_seqs), 1);
if (dtype == INFINI_DTYPE_F16) {
flashAttentionDecodeHd128CtaGqa4<Tindex, half><<<grid_gqa, 64, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, nullptr,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
return INFINI_STATUS_SUCCESS;
}
if (dtype == INFINI_DTYPE_BF16) {
flashAttentionDecodeHd128CtaGqa4<Tindex, __nv_bfloat16><<<grid_gqa, 64, 0, stream>>>(
static_cast<__nv_bfloat16 *>(out),
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, nullptr,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
return INFINI_STATUS_SUCCESS;
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
dim3 grid(static_cast<uint64_t>(num_heads), static_cast<uint64_t>(num_seqs), 1);
if (use_split) {
dump_dispatch(use_cta ? "splitkv_cta" : "splitkv_warp");
if (!fixed_num_splits) {
// Approximate seqlen_k by the per-seq KV capacity (paged KV upper bound).
const size_t seqlen_k = max_num_blocks_per_seq * page_block_size;
const int sm_count = getSmCount();
num_splits = chooseNumSplitsHeuristic(num_heads, num_seqs, seqlen_k, sm_count);
if (const char *dbg = std::getenv("INFINIOP_FLASH_DEBUG_SPLITS")) {
if (std::strcmp(dbg, "1") == 0 || std::strcmp(dbg, "true") == 0) {
static size_t last_seqlen_k = 0;
if (last_seqlen_k != seqlen_k) {
last_seqlen_k = seqlen_k;
fprintf(stderr,
"[INFINIOP][paged_attention] splitkv auto: sm=%d heads=%zu seqs=%zu seqlen_k~%zu -> num_splits=%d\n",
sm_count, num_heads, num_seqs, seqlen_k, num_splits);
}
}
}
}
const size_t n = num_seqs * num_heads;
const size_t acc_elems = static_cast<size_t>(kMaxSplits) * n * 128;
const size_t m_elems = static_cast<size_t>(kMaxSplits) * n;
const size_t l_elems = static_cast<size_t>(kMaxSplits) * n;
const size_t needed_bytes = (acc_elems + m_elems + l_elems) * sizeof(float);
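// Example sizing: num_seqs=4, num_heads=32 gives n=128, i.e.
// (8*128*128 + 2*8*128) floats = 133120 floats, roughly 520 KiB of workspace.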
if (workspace == nullptr || workspace_size < needed_bytes) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
float *ws = static_cast<float *>(workspace);
float *partial_acc = ws;
float *partial_m = partial_acc + acc_elems;
float *partial_l = partial_m + m_elems;
dim3 grid_split(static_cast<uint64_t>(num_heads), static_cast<uint64_t>(num_seqs), static_cast<uint64_t>(num_splits));
dim3 block_split(use_cta ? static_cast<uint32_t>(cta_threads) : 32);
if (dtype == INFINI_DTYPE_F16) {
if (use_cta) {
if (cta_threads == 32) {
if (cta_tile == 16) {
flashAttentionDecodeHd128SplitKvCta32Tile16<Tindex, half><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
} else {
flashAttentionDecodeHd128SplitKvCta32<Tindex, half><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
}
} else {
if (cta_tile == 16) {
flashAttentionDecodeHd128SplitKvCtaTile16<Tindex, half><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
} else {
flashAttentionDecodeHd128SplitKvCta<Tindex, half><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
}
}
} else {
flashAttentionDecodeHd128SplitKv<Tindex, half><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
}
flashAttentionDecodeHd128SplitKvCombine<half><<<grid, 32, 0, stream>>>(
static_cast<half *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride);
return INFINI_STATUS_SUCCESS;
}
if (dtype == INFINI_DTYPE_BF16) {
if (use_cta) {
if (cta_threads == 32) {
if (cta_tile == 16) {
flashAttentionDecodeHd128SplitKvCta32Tile16<Tindex, __nv_bfloat16><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
} else {
flashAttentionDecodeHd128SplitKvCta32<Tindex, __nv_bfloat16><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
}
} else {
if (cta_tile == 16) {
flashAttentionDecodeHd128SplitKvCtaTile16<Tindex, __nv_bfloat16><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
} else {
flashAttentionDecodeHd128SplitKvCta<Tindex, __nv_bfloat16><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
}
}
} else {
flashAttentionDecodeHd128SplitKv<Tindex, __nv_bfloat16><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
}
flashAttentionDecodeHd128SplitKvCombine<__nv_bfloat16><<<grid, 32, 0, stream>>>(
static_cast<__nv_bfloat16 *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride);
return INFINI_STATUS_SUCCESS;
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
dump_dispatch(use_cta ? "cta_nosplit" : "warp_nosplit");
if (dtype == INFINI_DTYPE_F16) {
if (use_cta) {
if (cta_tile == 16) {
if (cta_threads == 32) {
flashAttentionDecodeHd128Cta32Tile16<Tindex, half><<<grid, block, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
} else {
flashAttentionDecodeHd128CtaTile16<Tindex, half><<<grid, block, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
} else {
if (cta_threads == 32) {
flashAttentionDecodeHd128Cta32<Tindex, half><<<grid, block, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
} else {
flashAttentionDecodeHd128Cta<Tindex, half><<<grid, block, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
}
} else {
flashAttentionDecodeHd128Warp<Tindex, half><<<grid, block, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
return INFINI_STATUS_SUCCESS;
}
if (dtype == INFINI_DTYPE_BF16) {
if (use_cta) {
if (cta_tile == 16) {
if (cta_threads == 32) {
flashAttentionDecodeHd128Cta32Tile16<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
static_cast<__nv_bfloat16 *>(out),
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
} else {
flashAttentionDecodeHd128CtaTile16<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
static_cast<__nv_bfloat16 *>(out),
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
} else {
if (cta_threads == 32) {
flashAttentionDecodeHd128Cta32<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
static_cast<__nv_bfloat16 *>(out),
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
} else {
flashAttentionDecodeHd128Cta<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
static_cast<__nv_bfloat16 *>(out),
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
}
} else {
flashAttentionDecodeHd128Warp<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
static_cast<__nv_bfloat16 *>(out),
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
return INFINI_STATUS_SUCCESS;
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
infiniStatus_t launch_decode_hd128_i64(
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
infiniDtype_t dtype,
const int64_t *block_tables,
const int64_t *cache_lens,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
cudaStream_t stream) {
return launch_decode_hd128_impl<int64_t>(
workspace, workspace_size,
out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride,
k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}
infiniStatus_t launch_decode_hd128_i32(
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
infiniDtype_t dtype,
const int32_t *block_tables,
const int32_t *cache_lens,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
cudaStream_t stream) {
return launch_decode_hd128_impl<int32_t>(
workspace, workspace_size,
out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride,
k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}
infiniStatus_t launch_decode_hd128_u32(
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
infiniDtype_t dtype,
const uint32_t *block_tables,
const uint32_t *cache_lens,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
cudaStream_t stream) {
return launch_decode_hd128_impl<uint32_t>(
workspace, workspace_size,
out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride,
k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}
} // namespace op::paged_attention::nvidia
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../cuda/kernel_v2.cuh"
namespace op::paged_attention::nvidia {
namespace {
constexpr int kMaxSplits = 8;
constexpr size_t ceilDiv(size_t a, size_t b) {
return (a + b - 1) / b;
}
inline int getSmCount() {
int device = 0;
if (cudaGetDevice(&device) != cudaSuccess) {
return 0;
}
int sm_count = 0;
if (cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device) != cudaSuccess) {
return 0;
}
return sm_count;
}
// A lightweight FA2-style "waves" heuristic.
//
// Important: our split-kv kernel shards the KV sequence length, so the main "work"
// dimension is tokens, not the number of pages. We use an upper bound for seqlen_k
// (max pages * page size), which matches common decode microbenchmarks where all
// sequences share the same cache length.
inline int chooseNumSplitsHeuristic(size_t num_heads, size_t num_seqs, size_t seqlen_k, int sm_count) {
if (sm_count <= 0) {
return 1;
}
if (num_heads == 0 || num_seqs == 0) {
return 1;
}
if (seqlen_k <= 256) {
return 1;
}
const size_t base_blocks = num_heads * num_seqs;
int best_splits = 1;
// Baseline: one kernel, base_blocks CTAs, each scanning seqlen_k tokens.
size_t best_score = (ceilDiv(base_blocks, static_cast<size_t>(sm_count)) * seqlen_k);
size_t prev_work_per_block = seqlen_k;
for (int s = 2; s <= kMaxSplits; ++s) {
const size_t blocks = base_blocks * static_cast<size_t>(s);
const size_t waves_split = ceilDiv(blocks, static_cast<size_t>(sm_count));
const size_t work_per_block = ceilDiv(seqlen_k, static_cast<size_t>(s));
// If this split count doesn't reduce per-block work vs the previous split, it's effectively redundant.
if (work_per_block == prev_work_per_block) {
continue;
}
prev_work_per_block = work_per_block;
// Combine is one extra kernel with base_blocks blocks; approximate as one more wave unit.
const size_t waves_combine = ceilDiv(base_blocks, static_cast<size_t>(sm_count));
const size_t score = waves_split * work_per_block + waves_combine;
if (score < best_score) {
best_score = score;
best_splits = s;
}
}
return best_splits;
}
} // namespace
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd64Warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride) {
op::paged_attention::cuda::flashAttentionDecodeWarpKernel<Tindex, Tdata, 64>(
out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd64Cta(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride) {
// Default CTA variant (lower overhead).
op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 64, 32, 8>(
out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd64CtaTile16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride) {
op::paged_attention::cuda::flashAttentionDecodeCtaKernel<Tindex, Tdata, 64, 32, 16>(
out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, o_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd64SplitKv(
float *partial_acc,
float *partial_m,
float *partial_l,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
int num_splits) {
op::paged_attention::cuda::flashAttentionDecodeSplitKvWarpKernel<Tindex, Tdata, 64>(
partial_acc, partial_m, partial_l,
q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride,
k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride,
v_head_stride, num_splits);
}
template <typename Tdata>
INFINIOP_CUDA_KERNEL flashAttentionDecodeHd64SplitKvCombine(
Tdata *out,
const float *partial_acc,
const float *partial_m,
const float *partial_l,
int num_splits,
ptrdiff_t o_stride) {
op::paged_attention::cuda::flashAttentionDecodeSplitKvCombineWarpKernel<Tdata, 64>(
out, partial_acc, partial_m, partial_l, num_splits, o_stride);
}
template <typename Tindex>
infiniStatus_t launch_decode_hd64_impl(
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
infiniDtype_t dtype,
const Tindex *block_tables,
const Tindex *cache_lens,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
cudaStream_t stream) {
dim3 grid(static_cast<uint64_t>(num_heads), static_cast<uint64_t>(num_seqs), 1);
bool use_cta = false;
if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_KERNEL")) {
use_cta = (std::strcmp(env, "cta") == 0);
}
int cta_tile = 8;
if (const char *env = std::getenv("INFINIOP_FLASH_CTA_TILE")) {
const int v = std::atoi(env);
if (v == 8 || v == 16) {
cta_tile = v;
}
}
// For head_dim=64 we use a 1-warp CTA (32 threads) with packed loads.
dim3 block(32);
bool use_split = false;
if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_SPLITKV")) {
use_split = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
}
int num_splits = 4;
bool fixed_num_splits = false;
if (const char *env = std::getenv("INFINIOP_FLASH_NUM_SPLITS")) {
if (std::strcmp(env, "auto") == 0) {
fixed_num_splits = false;
} else {
num_splits = std::atoi(env);
fixed_num_splits = (num_splits > 0);
}
}
if (num_splits < 1) {
num_splits = 1;
}
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
if (use_split) {
if (use_cta) {
// We currently only implement the split-kv path with warp kernels.
// The CTA kernel is a separate non-split implementation.
static bool warned = false;
if (!warned) {
warned = true;
fprintf(stderr,
"[INFINIOP][paged_attention] split-kv is enabled; ignoring INFINIOP_FLASH_DECODE_KERNEL=cta "
"(CTA split-kv not implemented yet)\n");
}
}
if (!fixed_num_splits) {
// Approximate seqlen_k by the per-seq KV capacity (paged KV upper bound).
const size_t seqlen_k = max_num_blocks_per_seq * page_block_size;
const int sm_count = getSmCount();
num_splits = chooseNumSplitsHeuristic(num_heads, num_seqs, seqlen_k, sm_count);
if (const char *dbg = std::getenv("INFINIOP_FLASH_DEBUG_SPLITS")) {
if (std::strcmp(dbg, "1") == 0 || std::strcmp(dbg, "true") == 0) {
static size_t last_seqlen_k = 0;
if (last_seqlen_k != seqlen_k) {
last_seqlen_k = seqlen_k;
fprintf(stderr,
"[INFINIOP][paged_attention] splitkv auto: sm=%d heads=%zu seqs=%zu seqlen_k~%zu -> num_splits=%d\n",
sm_count, num_heads, num_seqs, seqlen_k, num_splits);
}
}
}
}
const size_t n = num_seqs * num_heads;
const size_t acc_elems = static_cast<size_t>(kMaxSplits) * n * 64;
const size_t m_elems = static_cast<size_t>(kMaxSplits) * n;
const size_t l_elems = static_cast<size_t>(kMaxSplits) * n;
const size_t needed_bytes = (acc_elems + m_elems + l_elems) * sizeof(float);
if (workspace == nullptr || workspace_size < needed_bytes) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
float *ws = static_cast<float *>(workspace);
float *partial_acc = ws;
float *partial_m = partial_acc + acc_elems;
float *partial_l = partial_m + m_elems;
dim3 grid_split(static_cast<uint64_t>(num_heads), static_cast<uint64_t>(num_seqs), static_cast<uint64_t>(num_splits));
dim3 block_split(32);
if (dtype == INFINI_DTYPE_F16) {
flashAttentionDecodeHd64SplitKv<Tindex, half><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
flashAttentionDecodeHd64SplitKvCombine<half><<<grid, 32, 0, stream>>>(
static_cast<half *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride);
return INFINI_STATUS_SUCCESS;
}
if (dtype == INFINI_DTYPE_BF16) {
flashAttentionDecodeHd64SplitKv<Tindex, __nv_bfloat16><<<grid_split, block_split, 0, stream>>>(
partial_acc, partial_m, partial_l,
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, num_splits);
flashAttentionDecodeHd64SplitKvCombine<__nv_bfloat16><<<grid, 32, 0, stream>>>(
static_cast<__nv_bfloat16 *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride);
return INFINI_STATUS_SUCCESS;
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (dtype == INFINI_DTYPE_F16) {
if (use_cta) {
if (cta_tile == 16) {
flashAttentionDecodeHd64CtaTile16<Tindex, half><<<grid, block, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
} else {
flashAttentionDecodeHd64Cta<Tindex, half><<<grid, block, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
} else {
flashAttentionDecodeHd64Warp<Tindex, half><<<grid, block, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
return INFINI_STATUS_SUCCESS;
}
if (dtype == INFINI_DTYPE_BF16) {
if (use_cta) {
if (cta_tile == 16) {
flashAttentionDecodeHd64CtaTile16<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
static_cast<__nv_bfloat16 *>(out),
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
} else {
flashAttentionDecodeHd64Cta<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
static_cast<__nv_bfloat16 *>(out),
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
} else {
flashAttentionDecodeHd64Warp<Tindex, __nv_bfloat16><<<grid, block, 0, stream>>>(
static_cast<__nv_bfloat16 *>(out),
static_cast<const __nv_bfloat16 *>(q),
static_cast<const __nv_bfloat16 *>(k_cache),
static_cast<const __nv_bfloat16 *>(v_cache),
block_tables, cache_lens, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
q_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride);
}
return INFINI_STATUS_SUCCESS;
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
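// The thin wrappers below only fix the block-table index type; all dispatch happens in
// launch_decode_hd64_impl above.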
infiniStatus_t launch_decode_hd64_i64(
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
infiniDtype_t dtype,
const int64_t *block_tables,
const int64_t *cache_lens,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
cudaStream_t stream) {
return launch_decode_hd64_impl<int64_t>(
workspace, workspace_size,
out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride,
k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}
infiniStatus_t launch_decode_hd64_i32(
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
infiniDtype_t dtype,
const int32_t *block_tables,
const int32_t *cache_lens,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
cudaStream_t stream) {
return launch_decode_hd64_impl<int32_t>(
workspace, workspace_size,
out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride,
k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}
infiniStatus_t launch_decode_hd64_u32(
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
infiniDtype_t dtype,
const uint32_t *block_tables,
const uint32_t *cache_lens,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t q_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
cudaStream_t stream) {
return launch_decode_hd64_impl<uint32_t>(
workspace, workspace_size,
out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride,
k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream);
}
} // namespace op::paged_attention::nvidia
#include <cub/block/block_reduce.cuh>
#include <cuda_runtime.h>
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "paged_attention_nvidia.cuh"
template <typename Tdata, typename Tcompute, size_t HEAD_SIZE, size_t NUM_THREADS>
INFINIOP_CUDA_KERNEL pagedAttention(
Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
const int64_t *block_tables, const int64_t *seq_lens, const float *alibi_slopes,
const size_t num_kv_heads, const float scale, const size_t max_num_blocks_per_seq,
const size_t block_size,
const ptrdiff_t q_stride,
const ptrdiff_t kv_block_stride,
const ptrdiff_t kv_head_stride,
const ptrdiff_t o_stride) {
op::paged_attention::cuda::pagedAttentionKernel<Tdata, Tcompute, HEAD_SIZE, NUM_THREADS>(
out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, num_kv_heads, scale,
max_num_blocks_per_seq, block_size, q_stride, kv_block_stride, kv_head_stride, o_stride);
}
namespace op::paged_attention::nvidia {
infiniStatus_t launch_decode_hd64_i64(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
infiniDtype_t dtype, const int64_t *block_tables, const int64_t *cache_lens, const float *alibi_slopes,
size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
cudaStream_t stream);
infiniStatus_t launch_decode_hd64_i32(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
infiniDtype_t dtype, const int32_t *block_tables, const int32_t *cache_lens, const float *alibi_slopes,
size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
cudaStream_t stream);
infiniStatus_t launch_decode_hd64_u32(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
infiniDtype_t dtype, const uint32_t *block_tables, const uint32_t *cache_lens, const float *alibi_slopes,
size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
cudaStream_t stream);
infiniStatus_t launch_decode_hd128_i64(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
infiniDtype_t dtype, const int64_t *block_tables, const int64_t *cache_lens, const float *alibi_slopes,
size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
cudaStream_t stream);
infiniStatus_t launch_decode_hd128_i32(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
infiniDtype_t dtype, const int32_t *block_tables, const int32_t *cache_lens, const float *alibi_slopes,
size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
cudaStream_t stream);
infiniStatus_t launch_decode_hd128_u32(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
infiniDtype_t dtype, const uint32_t *block_tables, const uint32_t *cache_lens, const float *alibi_slopes,
size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size,
ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride,
cudaStream_t stream);
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
@@ -40,108 +79,284 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t cache_lens_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto info = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, seq_lens_desc, alibi_slopes_desc, scale);
CHECK_RESULT(info);
auto info_res = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, cache_lens_desc, alibi_slopes_desc, scale);
CHECK_RESULT(info_res);
auto info = info_res.take();
// Reserve workspace for optional split-kv decode (partial acc + m/l).
// Workspace is independent of runtime env toggles; kernels will clamp num_splits <= kMaxSplits.
constexpr size_t kMaxSplits = 8;
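// Per (split, seq, head): head_size floats of partial output plus one running max (m) and one
// running sum (l), matching the layout consumed by the split-kv launchers.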
const size_t per_split = info.num_seqs * info.num_heads * (info.head_size + 2) * sizeof(float);
const size_t workspace_bytes = kMaxSplits * per_split;
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
info, workspace_bytes, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <size_t HEAD_SIZE, size_t NUM_THREADS>
infiniStatus_t launchKernel(void *out, const void *q, const void *k_cache, const void *v_cache,
infiniDtype_t dtype,
const void *block_tables, const void *seq_lens, const void *alibi_slopes,
size_t num_heads, size_t num_seqs,
size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t block_size,
ptrdiff_t q_stride, ptrdiff_t kv_block_stride, ptrdiff_t kv_head_stride, ptrdiff_t o_stride,
cudaStream_t stream) {
dim3 grid(uint64_t(num_heads), uint64_t(num_seqs), 1);
dim3 block(NUM_THREADS);
size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(float);
if (dtype == INFINI_DTYPE_F16) {
pagedAttention<half, float, HEAD_SIZE, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(half *)out,
(const half *)q, (const half *)k_cache, (const half *)v_cache,
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
scale, max_num_blocks_per_seq, block_size,
q_stride, kv_block_stride, kv_head_stride, o_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
pagedAttention<__nv_bfloat16, float, HEAD_SIZE, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(__nv_bfloat16 *)out, (const __nv_bfloat16 *)q, (const __nv_bfloat16 *)k_cache, (const __nv_bfloat16 *)v_cache,
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
scale, max_num_blocks_per_seq, block_size,
q_stride, kv_block_stride, kv_head_stride, o_stride);
} else if (dtype == INFINI_DTYPE_F32) {
pagedAttention<float, float, HEAD_SIZE, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(float *)out, (const float *)q, (const float *)k_cache, (const float *)v_cache,
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
scale, max_num_blocks_per_seq, block_size,
q_stride, kv_block_stride, kv_head_stride, o_stride);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables, const void *seq_lens, const void *alibi_slopes,
const void *block_tables, const void *cache_lens, const void *alibi_slopes,
void *stream_) const {
cudaStream_t stream = (cudaStream_t)stream_;
#define LAUNCH_HEADSIZE_BLOCKSIZE(__H_SIZE, __B_SIZE) \
launchKernel<__H_SIZE, __B_SIZE>( \
out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes, \
_info.num_heads, _info.num_seqs, \
_info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size, \
_info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, \
stream);
#define SWITCH_HEAD_SIZE(__B_SIZE) \
switch (_info.head_size) { \
case 16: \
LAUNCH_HEADSIZE_BLOCKSIZE(16, __B_SIZE) \
break; \
case 32: \
LAUNCH_HEADSIZE_BLOCKSIZE(32, __B_SIZE) \
break; \
case 64: \
LAUNCH_HEADSIZE_BLOCKSIZE(64, __B_SIZE) \
break; \
case 128: \
LAUNCH_HEADSIZE_BLOCKSIZE(128, __B_SIZE) \
break; \
case 256: \
LAUNCH_HEADSIZE_BLOCKSIZE(256, __B_SIZE) \
break; \
default: \
return INFINI_STATUS_BAD_TENSOR_SHAPE; \
}
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_1024)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_512)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_4096)
bool need_workspace = false;
if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_SPLITKV")) {
// "auto" may enable split-kv depending on the runtime heuristic.
need_workspace = (std::strcmp(env, "auto") == 0) || (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
// Keep hd64 behavior unchanged, but for hd128 we default to split-kv decode, which needs workspace.
need_workspace = (_info.head_size == 128);
}
if (need_workspace && workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
#undef LAUNCH_HEADSIZE_BLOCKSIZE
#undef SWITCH_HEAD_SIZE
auto stream = static_cast<cudaStream_t>(stream_);
return INFINI_STATUS_SUCCESS;
const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast<const float *>(alibi_slopes);
if (_info.index_dtype == INFINI_DTYPE_I64) {
const auto *block_table_i64 = static_cast<const int64_t *>(block_tables);
const auto *cache_lens_i64 = static_cast<const int64_t *>(cache_lens);
switch (_info.head_size) {
case 64:
return launch_decode_hd64_i64(
workspace, workspace_size,
out, q, k_cache, v_cache, _info.dtype,
block_table_i64, cache_lens_i64, alibi_ptr,
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
_info.max_num_blocks_per_seq, _info.page_block_size,
_info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
_info.o_stride, stream);
case 128:
return launch_decode_hd128_i64(
workspace, workspace_size,
out, q, k_cache, v_cache, _info.dtype,
block_table_i64, cache_lens_i64, alibi_ptr,
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
_info.max_num_blocks_per_seq, _info.page_block_size,
_info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
_info.o_stride, stream);
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
if (_info.index_dtype == INFINI_DTYPE_I32) {
const auto *block_table_i32 = static_cast<const int32_t *>(block_tables);
const auto *cache_lens_i32 = static_cast<const int32_t *>(cache_lens);
switch (_info.head_size) {
case 64:
return launch_decode_hd64_i32(
workspace, workspace_size,
out, q, k_cache, v_cache, _info.dtype,
block_table_i32, cache_lens_i32, alibi_ptr,
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
_info.max_num_blocks_per_seq, _info.page_block_size,
_info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
_info.o_stride, stream);
case 128:
return launch_decode_hd128_i32(
workspace, workspace_size,
out, q, k_cache, v_cache, _info.dtype,
block_table_i32, cache_lens_i32, alibi_ptr,
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
_info.max_num_blocks_per_seq, _info.page_block_size,
_info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
_info.o_stride, stream);
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
if (_info.index_dtype == INFINI_DTYPE_U32) {
const auto *block_table_u32 = static_cast<const uint32_t *>(block_tables);
const auto *cache_lens_u32 = static_cast<const uint32_t *>(cache_lens);
switch (_info.head_size) {
case 64:
return launch_decode_hd64_u32(
workspace, workspace_size,
out, q, k_cache, v_cache, _info.dtype,
block_table_u32, cache_lens_u32, alibi_ptr,
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
_info.max_num_blocks_per_seq, _info.page_block_size,
_info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
_info.o_stride, stream);
case 128:
return launch_decode_hd128_u32(
workspace, workspace_size,
out, q, k_cache, v_cache, _info.dtype,
block_table_u32, cache_lens_u32, alibi_ptr,
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale,
_info.max_num_blocks_per_seq, _info.page_block_size,
_info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride,
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride,
_info.o_stride, stream);
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::paged_attention::nvidia
// #include <cub/block/block_reduce.cuh>
// #include "../../../devices/nvidia/nvidia_common.cuh"
// #include "../../../devices/nvidia/nvidia_kernel_common.cuh"
// #include "../../../reduce/cuda/reduce.cuh"
// #include "../cuda/kernel.cuh"
// #include "paged_attention_nvidia.cuh"
// template <typename Tdata, typename Tcompute, size_t HEAD_SIZE, size_t NUM_THREADS>
// INFINIOP_CUDA_KERNEL pagedAttention(
// Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
// const int64_t *block_tables, const int64_t *seq_lens, const float *alibi_slopes,
// const size_t num_kv_heads, const float scale, const size_t max_num_blocks_per_seq,
// const size_t block_size,
// const ptrdiff_t q_stride,
// const ptrdiff_t kv_block_stride,
// const ptrdiff_t kv_head_stride,
// const ptrdiff_t o_stride) {
// op::paged_attention::cuda::pagedAttentionKernel<Tdata, Tcompute, HEAD_SIZE, NUM_THREADS>(
// out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, num_kv_heads, scale,
// max_num_blocks_per_seq, block_size, q_stride, kv_block_stride, kv_head_stride, o_stride);
// }
// namespace op::paged_attention::nvidia {
// struct Descriptor::Opaque {
// std::shared_ptr<device::nvidia::Handle::Internal> internal;
// };
// Descriptor::~Descriptor() {
// delete _opaque;
// }
// infiniStatus_t Descriptor::create(
// infiniopHandle_t handle,
// Descriptor **desc_ptr,
// infiniopTensorDescriptor_t out_desc,
// infiniopTensorDescriptor_t q_desc,
// infiniopTensorDescriptor_t k_cache_desc,
// infiniopTensorDescriptor_t v_cache_desc,
// infiniopTensorDescriptor_t block_tables_desc,
// infiniopTensorDescriptor_t seq_lens_desc,
// const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
// float scale) {
// auto info = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, seq_lens_desc, alibi_slopes_desc, scale);
// CHECK_RESULT(info);
// *desc_ptr = new Descriptor(
// new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
// info.take(), 0, handle->device, handle->device_id);
// return INFINI_STATUS_SUCCESS;
// }
// template <size_t HEAD_SIZE, size_t NUM_THREADS>
// infiniStatus_t launchKernel(void *out, const void *q, const void *k_cache, const void *v_cache,
// infiniDtype_t dtype,
// const void *block_tables, const void *seq_lens, const void *alibi_slopes,
// size_t num_heads, size_t num_seqs,
// size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t block_size,
// ptrdiff_t q_stride, ptrdiff_t kv_block_stride, ptrdiff_t kv_head_stride, ptrdiff_t o_stride,
// cudaStream_t stream) {
// dim3 grid(uint64_t(num_heads), uint64_t(num_seqs), 1);
// dim3 block(NUM_THREADS);
// size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(float);
// if (dtype == INFINI_DTYPE_F16) {
// pagedAttention<half, float, HEAD_SIZE, NUM_THREADS>
// <<<grid, block, shared_mem_size, stream>>>(
// (half *)out,
// (const half *)q, (const half *)k_cache, (const half *)v_cache,
// (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
// scale, max_num_blocks_per_seq, block_size,
// q_stride, kv_block_stride, kv_head_stride, o_stride);
// } else if (dtype == INFINI_DTYPE_BF16) {
// pagedAttention<__nv_bfloat16, float, HEAD_SIZE, NUM_THREADS>
// <<<grid, block, shared_mem_size, stream>>>(
// (__nv_bfloat16 *)out, (const __nv_bfloat16 *)q, (const __nv_bfloat16 *)k_cache, (const __nv_bfloat16 *)v_cache,
// (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
// scale, max_num_blocks_per_seq, block_size,
// q_stride, kv_block_stride, kv_head_stride, o_stride);
// } else if (dtype == INFINI_DTYPE_F32) {
// pagedAttention<float, float, HEAD_SIZE, NUM_THREADS>
// <<<grid, block, shared_mem_size, stream>>>(
// (float *)out, (const float *)q, (const float *)k_cache, (const float *)v_cache,
// (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
// scale, max_num_blocks_per_seq, block_size,
// q_stride, kv_block_stride, kv_head_stride, o_stride);
// } else {
// return INFINI_STATUS_BAD_TENSOR_DTYPE;
// }
// return INFINI_STATUS_SUCCESS;
// }
// infiniStatus_t Descriptor::calculate(
// void *workspace, size_t workspace_size,
// void *out, const void *q, const void *k_cache, const void *v_cache,
// const void *block_tables, const void *seq_lens, const void *alibi_slopes,
// void *stream_) const {
// cudaStream_t stream = (cudaStream_t)stream_;
// #define LAUNCH_HEADSIZE_BLOCKSIZE(__H_SIZE, __B_SIZE) \
// launchKernel<__H_SIZE, __B_SIZE>( \
// out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes, \
// _info.num_heads, _info.num_seqs, \
// _info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size, \
// _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, \
// stream);
// #define SWITCH_HEAD_SIZE(__B_SIZE) \
// switch (_info.head_size) { \
// case 16: \
// LAUNCH_HEADSIZE_BLOCKSIZE(16, __B_SIZE) \
// break; \
// case 32: \
// LAUNCH_HEADSIZE_BLOCKSIZE(32, __B_SIZE) \
// break; \
// case 64: \
// LAUNCH_HEADSIZE_BLOCKSIZE(64, __B_SIZE) \
// break; \
// case 128: \
// LAUNCH_HEADSIZE_BLOCKSIZE(128, __B_SIZE) \
// break; \
// case 256: \
// LAUNCH_HEADSIZE_BLOCKSIZE(256, __B_SIZE) \
// break; \
// default: \
// return INFINI_STATUS_BAD_TENSOR_SHAPE; \
// }
// if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
// SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_1024)
// } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
// SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_512)
// } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
// SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_4096)
// } else {
// return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
// }
// #undef LAUNCH_HEADSIZE_BLOCKSIZE
// #undef SWITCH_HEAD_SIZE
// return INFINI_STATUS_SUCCESS;
// }
// } // namespace op::paged_attention::nvidia
#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__
#define __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <mma.h>
#include <cstdint>
#include <type_traits>
// Reuse warp-level primitives and math helpers from decode flash_attention kernels.
#include "../../paged_attention/cuda/kernel_v2.cuh"
namespace op::paged_attention_prefill::cuda {
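// Map a flattened query-token index to its sequence id by binary-searching the prefix sums in
// cu_seqlens_q (num_seqs + 1 entries). Example: cu_seqlens_q = {0, 3, 7} maps token 5 to seq 1.
// Out-of-range indices fall back to 0.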
__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *cu_seqlens_q, size_t num_seqs) {
size_t low = 0, high = (num_seqs == 0) ? 0 : (num_seqs - 1);
while (low <= high) {
size_t mid = (low + high) >> 1;
const size_t start = static_cast<size_t>(cu_seqlens_q[mid]);
const size_t end = static_cast<size_t>(cu_seqlens_q[mid + 1]);
if (token_idx >= start && token_idx < end) {
return mid;
} else if (token_idx < start) {
if (mid == 0) {
break;
}
high = mid - 1;
} else {
low = mid + 1;
}
}
return 0;
}
template <typename Tindex, typename Tdata, int HEAD_SIZE>
__device__ void PagedAttentionPrefillWarpKernel(
Tdata *out_,
const Tdata *q_,
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
constexpr int kWarpSize = 32;
static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4.");
static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32.");
constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize;
const int lane = threadIdx.x;
const int head_idx = static_cast<int>(blockIdx.x);
const int seq_idx = static_cast<int>(blockIdx.y);
const int q_token_local = static_cast<int>(blockIdx.z);
const int64_t q_start = cu_seqlens_q_[seq_idx];
const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
const int q_len = static_cast<int>(q_end - q_start);
if (q_token_local >= q_len) {
return;
}
const int kv_len_total = static_cast<int>(total_kv_lens_[seq_idx]);
const int history_len = kv_len_total - q_len;
const int allowed_k_len = history_len + q_token_local + 1;
if (allowed_k_len <= 0) {
return;
}
const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / static_cast<int>(num_kv_heads);
const int kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
constexpr float kLog2e = 1.4426950408889634f;
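// Softmax runs in base 2: the scale is pre-multiplied by log2(e) so the online update can use
// exp2f() instead of expf().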
const float scale_log2 = scale * kLog2e;
const int64_t q_token = q_start + static_cast<int64_t>(q_token_local);
const Tdata *q_ptr = q_ + q_token * q_stride + static_cast<int64_t>(head_idx) * q_head_stride;
Tdata *out_ptr = out_ + q_token * o_stride + static_cast<int64_t>(head_idx) * o_head_stride;
const Tindex *block_table = block_tables_ + static_cast<int64_t>(seq_idx) * static_cast<int64_t>(block_table_batch_stride);
float q_reg[DIMS_PER_THREAD];
float acc[DIMS_PER_THREAD];
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
q_reg[i] = static_cast<float>(q_ptr[dim]);
acc[i] = 0.0f;
}
#if defined(__CUDA_ARCH__)
float2 q_reg2[DIMS_PER_THREAD / 2];
if constexpr (std::is_same_v<Tdata, half>) {
const int dim_base = lane * DIMS_PER_THREAD;
const half2 *q2 = reinterpret_cast<const half2 *>(q_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
q_reg2[j] = __half22float2(q2[j]);
}
}
if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
const int dim_base = lane * DIMS_PER_THREAD;
const __nv_bfloat162 *q2 = reinterpret_cast<const __nv_bfloat162 *>(q_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
q_reg2[j] = __bfloat1622float2(q2[j]);
}
}
#endif
float m = -INFINITY;
float l = 0.0f;
const int pbs = static_cast<int>(page_block_size);
int t_base = 0;
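// Walk the KV cache one logical page at a time; lane 0 reads each block-table entry and
// broadcasts the physical block id to the rest of the warp.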
for (int logical_block = 0; t_base < allowed_k_len; ++logical_block, t_base += pbs) {
int physical_block = 0;
if (lane == 0) {
physical_block = static_cast<int>(block_table[logical_block]);
}
physical_block = __shfl_sync(0xffffffff, physical_block, 0);
const Tdata *k_base = k_cache_ + static_cast<int64_t>(physical_block) * k_batch_stride + static_cast<int64_t>(kv_head_idx) * k_head_stride;
const Tdata *v_base = v_cache_ + static_cast<int64_t>(physical_block) * v_batch_stride + static_cast<int64_t>(kv_head_idx) * v_head_stride;
const int token_end = min(pbs, allowed_k_len - t_base);
for (int token_in_block = 0; token_in_block < token_end; ++token_in_block) {
const int t = t_base + token_in_block;
const Tdata *k_ptr = k_base + static_cast<int64_t>(token_in_block) * k_row_stride;
const Tdata *v_ptr = v_base + static_cast<int64_t>(token_in_block) * v_row_stride;
float qk = 0.0f;
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
const int dim_base = lane * DIMS_PER_THREAD;
const half2 *k2 = reinterpret_cast<const half2 *>(k_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 qf = q_reg2[j];
const float2 kf = __half22float2(k2[j]);
qk += qf.x * kf.x + qf.y * kf.y;
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
const int dim_base = lane * DIMS_PER_THREAD;
const __nv_bfloat162 *k2 = reinterpret_cast<const __nv_bfloat162 *>(k_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 qf = q_reg2[j];
const float2 kf = __bfloat1622float2(k2[j]);
qk += qf.x * kf.x + qf.y * kf.y;
}
} else
#endif
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
qk += q_reg[i] * static_cast<float>(k_ptr[dim]);
}
qk = op::paged_attention::cuda::warpReduceSum(qk);
float alpha = 1.0f;
float beta = 0.0f;
if (lane == 0) {
float score = qk * scale_log2;
if (alibi_slope != 0.0f) {
const int causal_limit = allowed_k_len - 1;
score += (alibi_slope * static_cast<float>(t - causal_limit)) * kLog2e;
}
const float m_new = fmaxf(m, score);
alpha = exp2f(m - m_new);
beta = exp2f(score - m_new);
l = l * alpha + beta;
m = m_new;
}
alpha = __shfl_sync(0xffffffff, alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0);
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
const int dim_base = lane * DIMS_PER_THREAD;
const half2 *v2 = reinterpret_cast<const half2 *>(v_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 vf = __half22float2(v2[j]);
acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x;
acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y;
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
const int dim_base = lane * DIMS_PER_THREAD;
const __nv_bfloat162 *v2 = reinterpret_cast<const __nv_bfloat162 *>(v_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 vf = __bfloat1622float2(v2[j]);
acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x;
acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y;
}
} else
#endif
{
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
const float v_val = static_cast<float>(v_ptr[dim]);
acc[i] = acc[i] * alpha + beta * v_val;
}
}
}
}
float inv_l = 0.0f;
if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f);
}
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
const float o = acc[i] * inv_l;
if constexpr (std::is_same_v<Tdata, half>) {
out_ptr[dim] = __float2half_rn(o);
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
out_ptr[dim] = __float2bfloat16_rn(o);
} else {
out_ptr[dim] = static_cast<Tdata>(o);
}
}
}
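// Standalone __global__ variant that flattens all query tokens into grid.y
// (grid = (num_heads, total_q_tokens)); each block recovers its sequence id via find_seq_id().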
template <typename Tindex, typename Tdata, int HEAD_SIZE>
__global__ void PagedAttentionPrefillWarpGlobalKernel(
Tdata *out_,
const Tdata *q_,
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
constexpr int kWarpSize = 32;
static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4.");
static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32.");
constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize;
const int lane = threadIdx.x;
const size_t head_idx = static_cast<size_t>(blockIdx.x);
const size_t global_token_idx = static_cast<size_t>(blockIdx.y);
if (lane >= kWarpSize || head_idx >= num_heads || global_token_idx >= total_q_tokens) {
return;
}
const size_t seq_idx = find_seq_id(global_token_idx, cu_seqlens_q_, num_seqs);
const int64_t q_start = cu_seqlens_q_[seq_idx];
const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
const int q_len = static_cast<int>(q_end - q_start);
const int q_token_local = static_cast<int>(global_token_idx - static_cast<size_t>(q_start));
if (q_token_local < 0 || q_token_local >= q_len) {
return;
}
const int kv_len_total = static_cast<int>(total_kv_lens_[seq_idx]);
const int history_len = kv_len_total - q_len;
const int allowed_k_len = history_len + q_token_local + 1;
if (allowed_k_len <= 0) {
return;
}
const int num_queries_per_kv = static_cast<int>(num_heads / num_kv_heads);
const int kv_head_idx = static_cast<int>(head_idx) / num_queries_per_kv;
const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
constexpr float kLog2e = 1.4426950408889634f;
const float scale_log2 = scale * kLog2e;
const Tdata *q_ptr = q_ + static_cast<int64_t>(global_token_idx) * q_stride + static_cast<int64_t>(head_idx) * q_head_stride;
Tdata *out_ptr = out_ + static_cast<int64_t>(global_token_idx) * o_stride + static_cast<int64_t>(head_idx) * o_head_stride;
const Tindex *block_table = block_tables_ + static_cast<int64_t>(seq_idx) * static_cast<int64_t>(block_table_batch_stride);
const int pbs = static_cast<int>(page_block_size);
float q_reg[DIMS_PER_THREAD];
float acc[DIMS_PER_THREAD];
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
q_reg[i] = static_cast<float>(q_ptr[dim]);
acc[i] = 0.0f;
}
#if defined(__CUDA_ARCH__)
float2 q_reg2[DIMS_PER_THREAD / 2];
if constexpr (std::is_same_v<Tdata, half>) {
const int dim_base = lane * DIMS_PER_THREAD;
const half2 *q2 = reinterpret_cast<const half2 *>(q_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
q_reg2[j] = __half22float2(q2[j]);
}
}
if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
const int dim_base = lane * DIMS_PER_THREAD;
const __nv_bfloat162 *q2 = reinterpret_cast<const __nv_bfloat162 *>(q_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
q_reg2[j] = __bfloat1622float2(q2[j]);
}
}
#endif
float m = -INFINITY;
float l = 0.0f;
// Iterate by pages to avoid per-token division/mod and redundant block_table loads.
int t_base = 0;
for (int logical_block = 0; t_base < allowed_k_len; ++logical_block, t_base += pbs) {
const int32_t phys = static_cast<int32_t>(block_table[logical_block]);
const Tdata *k_base = k_cache_ + static_cast<int64_t>(phys) * k_batch_stride + static_cast<int64_t>(kv_head_idx) * k_head_stride;
const Tdata *v_base = v_cache_ + static_cast<int64_t>(phys) * v_batch_stride + static_cast<int64_t>(kv_head_idx) * v_head_stride;
const int token_end = min(pbs, allowed_k_len - t_base);
for (int token_in_block = 0; token_in_block < token_end; ++token_in_block) {
const int t = t_base + token_in_block;
const Tdata *k_ptr = k_base + static_cast<int64_t>(token_in_block) * k_row_stride;
const Tdata *v_ptr = v_base + static_cast<int64_t>(token_in_block) * v_row_stride;
float qk = 0.0f;
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
const int dim_base = lane * DIMS_PER_THREAD;
const half2 *k2 = reinterpret_cast<const half2 *>(k_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 qf = q_reg2[j];
const float2 kf = __half22float2(k2[j]);
qk += qf.x * kf.x + qf.y * kf.y;
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
const int dim_base = lane * DIMS_PER_THREAD;
const __nv_bfloat162 *k2 = reinterpret_cast<const __nv_bfloat162 *>(k_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 qf = q_reg2[j];
const float2 kf = __bfloat1622float2(k2[j]);
qk += qf.x * kf.x + qf.y * kf.y;
}
} else
#endif
{
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
qk += q_reg[i] * static_cast<float>(k_ptr[dim]);
}
}
qk = op::paged_attention::cuda::warpReduceSum(qk);
float alpha = 1.0f;
float beta = 0.0f;
if (lane == 0) {
float score = qk * scale_log2;
if (alibi_slope != 0.0f) {
const int causal_limit = allowed_k_len - 1;
score += (alibi_slope * static_cast<float>(t - causal_limit)) * kLog2e;
}
const float m_new = fmaxf(m, score);
alpha = exp2f(m - m_new);
beta = exp2f(score - m_new);
l = l * alpha + beta;
m = m_new;
}
alpha = __shfl_sync(0xffffffff, alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0);
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
const int dim_base = lane * DIMS_PER_THREAD;
const half2 *v2 = reinterpret_cast<const half2 *>(v_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 vf = __half22float2(v2[j]);
acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x;
acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y;
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
const int dim_base = lane * DIMS_PER_THREAD;
const __nv_bfloat162 *v2 = reinterpret_cast<const __nv_bfloat162 *>(v_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 vf = __bfloat1622float2(v2[j]);
acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x;
acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y;
}
} else
#endif
{
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
const float v_val = static_cast<float>(v_ptr[dim]);
acc[i] = acc[i] * alpha + beta * v_val;
}
}
}
}
float inv_l = 0.0f;
if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f);
}
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
const float o = acc[i] * inv_l;
if constexpr (std::is_same_v<Tdata, half>) {
out_ptr[dim] = __float2half_rn(o);
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
out_ptr[dim] = __float2bfloat16_rn(o);
} else {
out_ptr[dim] = static_cast<Tdata>(o);
}
}
}
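// Straightforward three-pass reference kernel (max score, exp-sum, weighted V accumulation),
// one thread per (token, head, dim) output element; no online softmax.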
template <typename Tindex, typename Tdata, typename Tcompute, int HEAD_SIZE>
__global__ void PagedAttentionPrefillReferenceKernel(
Tdata *out_,
const Tdata *q_,
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_heads,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
size_t num_seqs) {
const size_t global_token_idx = static_cast<size_t>(blockIdx.x);
const size_t head_idx = static_cast<size_t>(blockIdx.y);
const size_t dim_idx = static_cast<size_t>(threadIdx.x);
if (dim_idx >= HEAD_SIZE || head_idx >= num_heads) {
return;
}
const size_t seq_idx = find_seq_id(global_token_idx, cu_seqlens_q_, num_seqs);
const size_t q_token_idx = global_token_idx - static_cast<size_t>(cu_seqlens_q_[seq_idx]);
const size_t q_len = static_cast<size_t>(cu_seqlens_q_[seq_idx + 1] - cu_seqlens_q_[seq_idx]);
const size_t total_kv_len = static_cast<size_t>(total_kv_lens_[seq_idx]);
const size_t history_len = total_kv_len - q_len;
const size_t causal_limit = history_len + q_token_idx;
const size_t num_queries_per_kv = num_heads / num_kv_heads;
const size_t kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
const Tdata *q_vec = q_ + static_cast<int64_t>(global_token_idx) * q_stride + static_cast<int64_t>(head_idx) * q_head_stride;
Tdata *out_ptr = out_ + static_cast<int64_t>(global_token_idx) * o_stride + static_cast<int64_t>(head_idx) * o_head_stride;
const Tindex *block_table = block_tables_ + static_cast<int64_t>(seq_idx) * static_cast<int64_t>(block_table_batch_stride);
const size_t pbs = page_block_size;
Tcompute max_score = -INFINITY;
for (size_t t = 0; t <= causal_limit; ++t) {
const size_t page = t / pbs;
const size_t off = t - page * pbs;
const ptrdiff_t phys = static_cast<ptrdiff_t>(block_table[page]);
const Tdata *k_vec = k_cache_ + static_cast<int64_t>(phys) * k_batch_stride + static_cast<int64_t>(off) * k_row_stride + static_cast<int64_t>(kv_head_idx) * k_head_stride;
Tcompute score = 0;
for (size_t d = 0; d < HEAD_SIZE; ++d) {
score += static_cast<Tcompute>(q_vec[d]) * static_cast<Tcompute>(k_vec[d]);
}
score *= static_cast<Tcompute>(scale);
if (alibi_slope != 0.0f) {
score += static_cast<Tcompute>(alibi_slope * static_cast<float>(static_cast<int64_t>(t) - static_cast<int64_t>(causal_limit))); // signed difference: t and causal_limit are size_t, so t - causal_limit would wrap
}
if (score > max_score) {
max_score = score;
}
}
Tcompute sum_exp = 0;
for (size_t t = 0; t <= causal_limit; ++t) {
const size_t page = t / pbs;
const size_t off = t - page * pbs;
const ptrdiff_t phys = static_cast<ptrdiff_t>(block_table[page]);
const Tdata *k_vec = k_cache_ + static_cast<int64_t>(phys) * k_batch_stride + static_cast<int64_t>(off) * k_row_stride + static_cast<int64_t>(kv_head_idx) * k_head_stride;
Tcompute score = 0;
for (size_t d = 0; d < HEAD_SIZE; ++d) {
score += static_cast<Tcompute>(q_vec[d]) * static_cast<Tcompute>(k_vec[d]);
}
score *= static_cast<Tcompute>(scale);
if (alibi_slope != 0.0f) {
score += static_cast<Tcompute>(alibi_slope * static_cast<float>(static_cast<int64_t>(t) - static_cast<int64_t>(causal_limit)));
}
sum_exp += static_cast<Tcompute>(expf(static_cast<float>(score - max_score)));
}
const Tcompute inv_sum = static_cast<Tcompute>(1.0f) / (sum_exp + static_cast<Tcompute>(1e-6f));
Tcompute acc = 0;
for (size_t t = 0; t <= causal_limit; ++t) {
const size_t page = t / pbs;
const size_t off = t - page * pbs;
const ptrdiff_t phys = static_cast<ptrdiff_t>(block_table[page]);
const Tdata *k_vec = k_cache_ + static_cast<int64_t>(phys) * k_batch_stride + static_cast<int64_t>(off) * k_row_stride + static_cast<int64_t>(kv_head_idx) * k_head_stride;
Tcompute score = 0;
for (size_t d = 0; d < HEAD_SIZE; ++d) {
score += static_cast<Tcompute>(q_vec[d]) * static_cast<Tcompute>(k_vec[d]);
}
score *= static_cast<Tcompute>(scale);
if (alibi_slope != 0.0f) {
score += static_cast<Tcompute>(alibi_slope * static_cast<float>(static_cast<int64_t>(t) - static_cast<int64_t>(causal_limit)));
}
const Tcompute prob = static_cast<Tcompute>(expf(static_cast<float>(score - max_score))) * inv_sum;
const Tdata *v_vec = v_cache_ + static_cast<int64_t>(phys) * v_batch_stride + static_cast<int64_t>(off) * v_row_stride + static_cast<int64_t>(kv_head_idx) * v_head_stride;
acc += prob * static_cast<Tcompute>(v_vec[dim_idx]);
}
out_ptr[dim_idx] = static_cast<Tdata>(acc);
}
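// CTA-tiled prefill: BLOCK_M warps per CTA, one warp per query token. The CTA cooperatively
// stages a BLOCK_N-token K/V tile in shared memory, then each warp scans the tile up to its
// own causal limit.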
template <typename Tindex, typename Tdata, int HEAD_SIZE, int BLOCK_M, int BLOCK_N>
__device__ void PagedAttentionPrefillWarpCtaKernel(
Tdata *out_,
const Tdata *q_,
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4.");
static_assert(BLOCK_M > 0 && BLOCK_M <= 16, "BLOCK_M must be small (warp-per-query design).");
static_assert(BLOCK_N == 64 || BLOCK_N == 128, "BLOCK_N must be 64/128 in v0.4.");
constexpr int kWarpSize = 32;
constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize;
static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32.");
const int lane = threadIdx.x & (kWarpSize - 1);
const int warp_id = threadIdx.x / kWarpSize;
if (warp_id >= BLOCK_M) {
return;
}
const int head_idx = static_cast<int>(blockIdx.x);
const int seq_idx = static_cast<int>(blockIdx.y);
const int m_block = static_cast<int>(blockIdx.z);
const int64_t q_start = cu_seqlens_q_[seq_idx];
const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
const int q_len = static_cast<int>(q_end - q_start);
if (q_len <= 0) {
return;
}
const int m_start = m_block * BLOCK_M;
const int q_token_local = m_start + warp_id;
// IMPORTANT: do not early-return for a subset of warps in this CTA because we use __syncthreads()
// later. Tail tiles are handled by masking inactive warps.
if (m_start >= q_len) {
return; // uniform across the CTA
}
const bool is_active = (q_token_local < q_len);
const int64_t kv_len_total_i64 = total_kv_lens_[seq_idx];
const int kv_len_total = static_cast<int>(kv_len_total_i64);
// history_len = total_kv_len - q_len (KV already includes current q tokens).
const int history_len = kv_len_total - q_len;
const int allowed_k_len = is_active ? (history_len + q_token_local + 1) : 0;
const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / static_cast<int>(num_kv_heads);
const int kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
constexpr float kLog2e = 1.4426950408889634f;
const float scale_log2 = scale * kLog2e;
int64_t q_token = q_start;
if (is_active) {
q_token += static_cast<int64_t>(q_token_local);
}
const Tindex *block_table = block_tables_ + static_cast<int64_t>(seq_idx) * static_cast<int64_t>(block_table_batch_stride);
const Tdata *q_ptr = nullptr;
Tdata *out_ptr = nullptr;
if (is_active) {
q_ptr = q_ + q_token * q_stride + static_cast<int64_t>(head_idx) * q_head_stride;
out_ptr = out_ + q_token * o_stride + static_cast<int64_t>(head_idx) * o_head_stride;
}
float q_reg[DIMS_PER_THREAD];
float acc[DIMS_PER_THREAD];
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
q_reg[i] = is_active ? static_cast<float>(q_ptr[dim]) : 0.0f;
acc[i] = 0.0f;
}
#if defined(__CUDA_ARCH__)
float2 q_reg2[DIMS_PER_THREAD / 2];
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
q_reg2[j] = make_float2(q_reg[j * 2 + 0], q_reg[j * 2 + 1]);
}
#endif
float m = -INFINITY;
float l = 0.0f;
// For this CTA, we only need to scan up to the max allowed k among active warps.
const int max_q_in_tile = min(m_start + BLOCK_M, q_len);
const int max_allowed_k_len = min(history_len + max_q_in_tile, kv_len_total);
__shared__ int32_t s_phys[BLOCK_N];
__shared__ int32_t s_off[BLOCK_N];
// Ensure shared-memory tiles are aligned for half2/bfloat162 vector loads.
__shared__ __align__(16) Tdata s_k[BLOCK_N * HEAD_SIZE];
__shared__ __align__(16) Tdata s_v[BLOCK_N * HEAD_SIZE];
const int pbs = static_cast<int>(page_block_size);
for (int k_base = 0; k_base < max_allowed_k_len; k_base += BLOCK_N) {
const int tile_n = min(BLOCK_N, max_allowed_k_len - k_base);
// Precompute page mapping once per token in the tile.
for (int t = threadIdx.x; t < tile_n; t += blockDim.x) {
const int kpos = k_base + t;
const int page = (pbs == 256) ? (kpos >> 8) : (kpos / pbs);
const int off = (pbs == 256) ? (kpos & 255) : (kpos - page * pbs);
const int32_t phys = static_cast<int32_t>(block_table[page]);
s_phys[t] = phys;
s_off[t] = off;
}
__syncthreads();
// Load K/V tile into shared memory (contiguous in head_dim).
const int tile_elems = tile_n * HEAD_SIZE;
for (int idx = threadIdx.x; idx < tile_elems; idx += blockDim.x) {
const int t = idx / HEAD_SIZE;
const int dim = idx - t * HEAD_SIZE;
const int32_t phys = s_phys[t];
const int32_t off = s_off[t];
const Tdata *k_base_ptr = k_cache_ + static_cast<int64_t>(phys) * k_batch_stride + static_cast<int64_t>(off) * k_row_stride + static_cast<int64_t>(kv_head_idx) * k_head_stride;
const Tdata *v_base_ptr = v_cache_ + static_cast<int64_t>(phys) * v_batch_stride + static_cast<int64_t>(off) * v_row_stride + static_cast<int64_t>(kv_head_idx) * v_head_stride;
s_k[t * HEAD_SIZE + dim] = k_base_ptr[dim];
s_v[t * HEAD_SIZE + dim] = v_base_ptr[dim];
}
__syncthreads();
// Each warp processes one query token and scans the K/V tile.
for (int t = 0; t < tile_n; ++t) {
const int kpos = k_base + t;
if (kpos >= allowed_k_len) {
break;
}
const Tdata *k_ptr = s_k + t * HEAD_SIZE;
const Tdata *v_ptr = s_v + t * HEAD_SIZE;
float qk = 0.0f;
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
const int dim_base = lane * DIMS_PER_THREAD;
const half2 *k2 = reinterpret_cast<const half2 *>(k_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 qf = q_reg2[j];
const float2 kf = __half22float2(k2[j]);
qk += qf.x * kf.x + qf.y * kf.y;
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
const int dim_base = lane * DIMS_PER_THREAD;
const __nv_bfloat162 *k2 = reinterpret_cast<const __nv_bfloat162 *>(k_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 qf = q_reg2[j];
const float2 kf = __bfloat1622float2(k2[j]);
qk += qf.x * kf.x + qf.y * kf.y;
}
} else
#endif
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
qk += q_reg[i] * static_cast<float>(k_ptr[dim]);
}
qk = op::paged_attention::cuda::warpReduceSum(qk);
float alpha = 1.0f;
float beta = 0.0f;
if (lane == 0) {
float score = qk * scale_log2;
if (alibi_slope != 0.0f) {
// Causal prefill: last position is (allowed_k_len - 1) for this query.
score += (alibi_slope * static_cast<float>(kpos - (allowed_k_len - 1))) * kLog2e;
}
const float m_new = fmaxf(m, score);
alpha = exp2f(m - m_new);
beta = exp2f(score - m_new);
l = l * alpha + beta;
m = m_new;
}
alpha = __shfl_sync(0xffffffff, alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0);
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
const int dim_base = lane * DIMS_PER_THREAD;
const half2 *v2 = reinterpret_cast<const half2 *>(v_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 vf = __half22float2(v2[j]);
acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x;
acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y;
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
const int dim_base = lane * DIMS_PER_THREAD;
const __nv_bfloat162 *v2 = reinterpret_cast<const __nv_bfloat162 *>(v_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 vf = __bfloat1622float2(v2[j]);
acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x;
acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y;
}
} else
#endif
{
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
const float v_val = static_cast<float>(v_ptr[dim]);
acc[i] = acc[i] * alpha + beta * v_val;
}
}
}
__syncthreads();
}
float inv_l = 0.0f;
if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f);
}
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
const float out_val = acc[i] * inv_l;
if (!is_active) {
continue;
}
if constexpr (std::is_same_v<Tdata, half>) {
out_ptr[dim] = __float2half_rn(out_val);
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
out_ptr[dim] = __float2bfloat16_rn(out_val);
} else {
out_ptr[dim] = static_cast<Tdata>(out_val);
}
}
}
// Pipelined CTA kernel (FA2-style): stage K/V loads with cp.async and overlap global->shared
// copies with compute.
//
// Design notes:
// - Keep shared memory <= 48KB for compatibility with multi-arch builds that include SM75.
// - Iterate by paged blocks (logical pages) so each tile stays within one physical block and
// avoids per-token (page, off) mapping arrays in shared memory.
// - One warp computes one query token (same as warpcta kernels). Warps with shorter causal
// limits simply mask the tail tokens but still participate in CTA-wide barriers.
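// A minimal sketch of the intended pipeline shape (illustrative only; issue_tile_load() and
// compute_on_tile() are placeholders, the real tile loads/compute live in the kernel body):
//   for (int s = 0; s < STAGES - 1; ++s) { issue_tile_load(s); cpAsyncCommit(); }   // prefetch
//   for (int tile = 0; tile < num_tiles; ++tile) {
//       cpAsyncWaitGroupRt(STAGES - 2);   // oldest in-flight tile has landed
//       __syncthreads();
//       const int next = tile + STAGES - 1;
//       if (next < num_tiles) { issue_tile_load(next % STAGES); cpAsyncCommit(); }
//       compute_on_tile(tile % STAGES);
//       __syncthreads();                  // this buffer may be overwritten next iteration
//   }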
template <typename Tindex, typename Tdata, int HEAD_SIZE, int BLOCK_M, int TOKENS_PER_TILE, int STAGES>
__device__ void PagedAttentionPrefillWarpCtaKernelPipelined(
Tdata *out_,
const Tdata *q_,
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4.");
static_assert(BLOCK_M > 0 && BLOCK_M <= 16, "BLOCK_M must be <= 16.");
static_assert(TOKENS_PER_TILE == 32, "Pipelined CTA kernel currently assumes TOKENS_PER_TILE == 32.");
static_assert(STAGES >= 2 && STAGES <= 3, "STAGES must be 2 or 3.");
static_assert(sizeof(Tdata) == 2, "Pipelined CTA kernel supports only fp16/bf16.");
constexpr int kWarpSize = 32;
static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32.");
constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize;
const int lane = threadIdx.x & (kWarpSize - 1);
const int warp_id = threadIdx.x / kWarpSize;
if (warp_id >= BLOCK_M) {
return;
}
const int head_idx = static_cast<int>(blockIdx.x);
const int seq_idx = static_cast<int>(blockIdx.y);
const int m_block = static_cast<int>(blockIdx.z);
const int64_t q_start = cu_seqlens_q_[seq_idx];
const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
const int q_len = static_cast<int>(q_end - q_start);
if (q_len <= 0) {
return;
}
const int m_start = m_block * BLOCK_M;
const int q_token_local = m_start + warp_id;
// Uniform return for empty tail CTAs (avoid deadlock with __syncthreads).
if (m_start >= q_len) {
return;
}
const bool is_active = (q_token_local < q_len);
const int kv_len_total = static_cast<int>(total_kv_lens_[seq_idx]);
const int history_len = kv_len_total - q_len;
const int allowed_k_len = is_active ? (history_len + q_token_local + 1) : 0;
const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / static_cast<int>(num_kv_heads);
const int kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
constexpr float kLog2e = 1.4426950408889634f;
const float scale_log2 = scale * kLog2e;
int64_t q_token = q_start;
if (is_active) {
q_token += static_cast<int64_t>(q_token_local);
}
const Tindex *block_table = block_tables_ + static_cast<int64_t>(seq_idx) * static_cast<int64_t>(block_table_batch_stride);
const Tdata *q_ptr = nullptr;
Tdata *out_ptr = nullptr;
if (is_active) {
q_ptr = q_ + q_token * q_stride + static_cast<int64_t>(head_idx) * q_head_stride;
out_ptr = out_ + q_token * o_stride + static_cast<int64_t>(head_idx) * o_head_stride;
}
float q_reg[DIMS_PER_THREAD];
float acc[DIMS_PER_THREAD];
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
q_reg[i] = is_active ? static_cast<float>(q_ptr[dim]) : 0.0f;
acc[i] = 0.0f;
}
#if defined(__CUDA_ARCH__)
float2 q_reg2[DIMS_PER_THREAD / 2];
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
q_reg2[j] = make_float2(q_reg[j * 2 + 0], q_reg[j * 2 + 1]);
}
#endif
float m = -INFINITY;
float l = 0.0f;
// For this CTA, scan KV up to the max causal limit among active warps.
const int max_q_in_tile = min(m_start + BLOCK_M, q_len);
const int max_allowed_k_len = min(history_len + max_q_in_tile, kv_len_total);
if (max_allowed_k_len <= 0) {
// Nothing to attend to (should be rare). Produce zeros.
if (is_active) {
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
out_ptr[dim] = Tdata{};
}
}
return;
}
// cp.async uses 16B chunks; for fp16/bf16 that's 8 elements.
constexpr int CHUNK_ELEMS = 8;
constexpr int CHUNKS = HEAD_SIZE / CHUNK_ELEMS;
constexpr int LOADS_PER_TILE = CHUNKS * TOKENS_PER_TILE;
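    // For HEAD_SIZE=128 this is 16 chunks/token * 32 tokens = 512 16-byte copies per K tile
    // (and another 512 for V), spread across all threads of the CTA.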
// Multi-stage pipeline buffers.
__shared__ __align__(16) Tdata sh_k[STAGES][TOKENS_PER_TILE][HEAD_SIZE];
__shared__ __align__(16) Tdata sh_v[STAGES][TOKENS_PER_TILE][HEAD_SIZE];
// Per-warp scratch for tile-wise softmax (scores over TOKENS_PER_TILE).
// We keep scores in shared so each lane can load its token score (lane -> token index),
// then weights are broadcast via warp shuffles to avoid extra shared-memory traffic.
__shared__ float sh_scores[BLOCK_M][TOKENS_PER_TILE];
// Store Q in shared (per warp). This enables more tile-level parallelism in score
// computation without expensive cross-lane shuffles of Q registers.
__shared__ __align__(16) Tdata sh_q[BLOCK_M][HEAD_SIZE];
const int pbs = static_cast<int>(page_block_size);
const int tid = threadIdx.x;
// Populate per-warp Q shared tile once.
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
sh_q[warp_id][dim] = is_active ? q_ptr[dim] : Tdata{};
}
__syncwarp();
int t_base = 0;
for (int logical_block = 0; t_base < max_allowed_k_len; ++logical_block, t_base += pbs) {
const int physical_block = static_cast<int>(block_table[logical_block]);
const Tdata *k_base = k_cache_ + static_cast<int64_t>(physical_block) * k_batch_stride + static_cast<int64_t>(kv_head_idx) * k_head_stride;
const Tdata *v_base = v_cache_ + static_cast<int64_t>(physical_block) * v_batch_stride + static_cast<int64_t>(kv_head_idx) * v_head_stride;
const int token_end = min(pbs, max_allowed_k_len - t_base);
const int num_tiles = (token_end + TOKENS_PER_TILE - 1) / TOKENS_PER_TILE;
if (num_tiles <= 0) {
continue;
}
int pending_groups = 0;
const int preload = min(STAGES, num_tiles);
for (int ti = 0; ti < preload; ++ti) {
const int token_in_block = ti * TOKENS_PER_TILE;
const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block);
for (int li = tid; li < LOADS_PER_TILE; li += blockDim.x) {
const int tok = li / CHUNKS;
const int chunk = li - tok * CHUNKS;
const int off = chunk * CHUNK_ELEMS;
if (tok < tile_n) {
const Tdata *k_src = k_base + static_cast<int64_t>(token_in_block + tok) * k_row_stride + off;
const Tdata *v_src = v_base + static_cast<int64_t>(token_in_block + tok) * v_row_stride + off;
op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_k[ti][tok][off], k_src);
op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_v[ti][tok][off], v_src);
} else {
reinterpret_cast<uint4 *>(&sh_k[ti][tok][off])[0] = make_uint4(0, 0, 0, 0);
reinterpret_cast<uint4 *>(&sh_v[ti][tok][off])[0] = make_uint4(0, 0, 0, 0);
}
}
op::paged_attention::cuda::cpAsyncCommit();
++pending_groups;
}
int desired_pending = pending_groups - 1;
if (desired_pending < 0) {
desired_pending = 0;
}
if (desired_pending > (STAGES - 1)) {
desired_pending = (STAGES - 1);
}
op::paged_attention::cuda::cpAsyncWaitGroupRt(desired_pending);
pending_groups = desired_pending;
__syncthreads();
for (int tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
const int buf = tile_idx % STAGES;
const int token_in_block = tile_idx * TOKENS_PER_TILE;
const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block);
const int global_k_base = t_base + token_in_block;
// Tile-wise online softmax (more FA2-like than per-token update):
// 1) Compute scores for this tile (masked to each warp's causal limit).
// 2) Compute tile max + sumexp.
// 3) Accumulate weighted V for the tile.
// 4) Merge into running (m, l, acc) in a numerically stable way.
//
// NOTE: this does not yet implement MMA / full tile-level GEMM; it mainly reduces
// the serial (lane0) online-softmax update frequency from per-token to per-tile.
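            //
            // Worked form of step 4 (base-2 exponentials, matching the exp2f calls below), given the
            // tile-level reductions tile_m and tile_sumexp:
            //   m_new  = max(m, tile_m)
            //   alpha  = 2^(m - m_new), beta = 2^(tile_m - m_new)
            //   l      = l * alpha + tile_sumexp * beta
            //   acc[d] = acc[d] * alpha + beta * acc_tile[d]
            // so after the final division by l, acc equals the softmax-weighted sum of V over all
            // unmasked tokens processed so far.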
float alpha = 1.0f;
float beta = 0.0f;
float tile_sumexp = 0.0f;
float tile_m = -INFINITY;
if (allowed_k_len > 0) {
// 1) scores
// Increase tile-level parallelism vs the previous per-token loop:
// split the warp into 4 groups of 8 lanes; each group computes one token score in parallel.
constexpr int LANES_PER_GROUP = 8;
constexpr int GROUPS_PER_WARP = 4;
constexpr int DIMS_PER_GROUP_LANE = HEAD_SIZE / LANES_PER_GROUP;
static_assert(HEAD_SIZE % LANES_PER_GROUP == 0, "HEAD_SIZE must be divisible by 8.");
const int group_id = lane / LANES_PER_GROUP; // [0..3]
const int lane_g = lane & (LANES_PER_GROUP - 1); // [0..7]
const unsigned int group_mask = 0xFFu << (group_id * LANES_PER_GROUP);
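                // Illustrative mapping: lane 13 -> group_id = 1, lane_g = 5, group_mask = 0x0000FF00.
                // Each iteration of the loop below scores GROUPS_PER_WARP = 4 tokens (one per group),
                // so 8 iterations cover the 32 tokens of the tile.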
for (int j_base = 0; j_base < TOKENS_PER_TILE; j_base += GROUPS_PER_WARP) {
const int j = j_base + group_id; // token index in [0..31]
const int kpos = global_k_base + j;
const bool token_in_tile = (j < tile_n);
const bool token_unmasked = token_in_tile && (kpos < allowed_k_len);
float qk_part = 0.0f;
if (token_unmasked) {
const Tdata *k_ptr = &sh_k[buf][j][0];
const int dim_base = lane_g * DIMS_PER_GROUP_LANE;
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
const half2 *q2 = reinterpret_cast<const half2 *>(&sh_q[warp_id][dim_base]);
const half2 *k2 = reinterpret_cast<const half2 *>(k_ptr + dim_base);
#pragma unroll
for (int t = 0; t < DIMS_PER_GROUP_LANE / 2; ++t) {
const float2 qf = __half22float2(q2[t]);
const float2 kf = __half22float2(k2[t]);
qk_part += qf.x * kf.x + qf.y * kf.y;
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
const __nv_bfloat162 *q2 = reinterpret_cast<const __nv_bfloat162 *>(&sh_q[warp_id][dim_base]);
const __nv_bfloat162 *k2 = reinterpret_cast<const __nv_bfloat162 *>(k_ptr + dim_base);
#pragma unroll
for (int t = 0; t < DIMS_PER_GROUP_LANE / 2; ++t) {
const float2 qf = __bfloat1622float2(q2[t]);
const float2 kf = __bfloat1622float2(k2[t]);
qk_part += qf.x * kf.x + qf.y * kf.y;
}
} else
#endif
{
#pragma unroll
for (int t = 0; t < DIMS_PER_GROUP_LANE; ++t) {
qk_part += static_cast<float>(sh_q[warp_id][dim_base + t]) * static_cast<float>(k_ptr[dim_base + t]);
}
}
}
// Reduce within 8-lane group.
for (int offset = LANES_PER_GROUP / 2; offset > 0; offset >>= 1) {
qk_part += __shfl_down_sync(group_mask, qk_part, offset, LANES_PER_GROUP);
}
if (lane_g == 0) {
float score = -INFINITY;
if (token_unmasked) {
score = qk_part * scale_log2;
if (alibi_slope != 0.0f) {
const int causal_limit = allowed_k_len - 1;
score += (alibi_slope * static_cast<float>(kpos - causal_limit)) * kLog2e;
}
}
sh_scores[warp_id][j] = score;
}
}
__syncwarp();
// 2) tile max + sumexp (lane t corresponds to token t within the tile)
const float score_lane = (lane < tile_n) ? sh_scores[warp_id][lane] : -INFINITY;
float tile_m_tmp = op::paged_attention::cuda::warpReduceMax(score_lane);
tile_m_tmp = __shfl_sync(0xffffffff, tile_m_tmp, 0);
tile_m = tile_m_tmp;
float w_lane = 0.0f;
if (lane < tile_n && tile_m != -INFINITY) {
w_lane = exp2f(score_lane - tile_m);
}
float sumexp_tmp = op::paged_attention::cuda::warpReduceSum(w_lane);
sumexp_tmp = __shfl_sync(0xffffffff, sumexp_tmp, 0);
tile_sumexp = sumexp_tmp;
// 3) weighted V for this tile (per lane owns HEAD_SIZE/32 dims)
float acc_tile[DIMS_PER_THREAD];
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
acc_tile[i] = 0.0f;
}
if (tile_sumexp > 0.0f) {
for (int j = 0; j < tile_n; ++j) {
// Broadcast weight for token j from lane j.
const float wj = __shfl_sync(0xffffffff, w_lane, j);
const Tdata *v_ptr = &sh_v[buf][j][0];
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
const int dim_base = lane * DIMS_PER_THREAD;
const half2 *v2 = reinterpret_cast<const half2 *>(v_ptr + dim_base);
#pragma unroll
for (int jj = 0; jj < DIMS_PER_THREAD / 2; ++jj) {
const float2 vf = __half22float2(v2[jj]);
acc_tile[jj * 2 + 0] += wj * vf.x;
acc_tile[jj * 2 + 1] += wj * vf.y;
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
const int dim_base = lane * DIMS_PER_THREAD;
const __nv_bfloat162 *v2 = reinterpret_cast<const __nv_bfloat162 *>(v_ptr + dim_base);
#pragma unroll
for (int jj = 0; jj < DIMS_PER_THREAD / 2; ++jj) {
const float2 vf = __bfloat1622float2(v2[jj]);
acc_tile[jj * 2 + 0] += wj * vf.x;
acc_tile[jj * 2 + 1] += wj * vf.y;
}
} else
#endif
{
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
acc_tile[i] += wj * static_cast<float>(v_ptr[dim]);
}
}
}
}
// 4) merge tile into running (m, l, acc)
if (lane == 0) {
if (tile_sumexp > 0.0f && tile_m != -INFINITY) {
const float m_new = fmaxf(m, tile_m);
alpha = exp2f(m - m_new);
beta = exp2f(tile_m - m_new);
l = l * alpha + tile_sumexp * beta;
m = m_new;
} else {
alpha = 1.0f;
beta = 0.0f;
}
}
alpha = __shfl_sync(0xffffffff, alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0);
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
acc[i] = acc[i] * alpha + beta * acc_tile[i];
}
}
// IMPORTANT: warps in this CTA can have different allowed_k_len (due to causal mask + history),
// so they may finish the token loop at different times. We must not start prefetching into
// the circular shared-memory buffer until all warps finish consuming the current tile.
__syncthreads();
// Prefetch the tile that will reuse this buffer (STAGES steps ahead).
const int prefetch_tile = tile_idx + STAGES;
if (prefetch_tile < num_tiles) {
const int token_prefetch = prefetch_tile * TOKENS_PER_TILE;
const int prefetch_n = min(TOKENS_PER_TILE, token_end - token_prefetch);
for (int li = tid; li < LOADS_PER_TILE; li += blockDim.x) {
const int tok = li / CHUNKS;
const int chunk = li - tok * CHUNKS;
const int off = chunk * CHUNK_ELEMS;
if (tok < prefetch_n) {
const Tdata *k_src = k_base + static_cast<int64_t>(token_prefetch + tok) * k_row_stride + off;
const Tdata *v_src = v_base + static_cast<int64_t>(token_prefetch + tok) * v_row_stride + off;
op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_k[buf][tok][off], k_src);
op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_v[buf][tok][off], v_src);
} else {
reinterpret_cast<uint4 *>(&sh_k[buf][tok][off])[0] = make_uint4(0, 0, 0, 0);
reinterpret_cast<uint4 *>(&sh_v[buf][tok][off])[0] = make_uint4(0, 0, 0, 0);
}
}
op::paged_attention::cuda::cpAsyncCommit();
++pending_groups;
}
if (tile_idx + 1 < num_tiles) {
int desired_pending2 = pending_groups - 1;
if (desired_pending2 < 0) {
desired_pending2 = 0;
}
if (desired_pending2 > (STAGES - 1)) {
desired_pending2 = (STAGES - 1);
}
op::paged_attention::cuda::cpAsyncWaitGroupRt(desired_pending2);
pending_groups = desired_pending2;
__syncthreads();
}
}
op::paged_attention::cuda::cpAsyncWaitAll();
__syncthreads();
}
float inv_l = 0.0f;
if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f);
}
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
const float out_val = acc[i] * inv_l;
if (!is_active) {
continue;
}
if constexpr (std::is_same_v<Tdata, half>) {
out_ptr[dim] = __float2half_rn(out_val);
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
out_ptr[dim] = __float2bfloat16_rn(out_val);
} else {
out_ptr[dim] = static_cast<Tdata>(out_val);
}
}
}
// Split-KV prefill (FA2-style): each split scans a shard of KV and writes partial (m, l, acc)
// to workspace. A separate combine kernel merges splits into the final output.
//
// Notes:
// - Implemented for the pipelined CTA kernel family (warpcta8pipe). We split by logical paged blocks.
// - Each warp still applies its own causal limit (allowed_k_len) so correctness is preserved.
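//
// Host-side workspace sizing sketch (hypothetical helper, not part of this header's API):
// the split-KV path needs three float buffers laid out exactly as documented on the
// partial_* parameters below. All names here are local to this sketch.
struct SplitKvWorkspaceSketch {
    size_t acc_bytes; // [num_splits, total_q_tokens, num_heads, head_size] float
    size_t m_bytes;   // [num_splits, total_q_tokens, num_heads] float
    size_t l_bytes;   // [num_splits, total_q_tokens, num_heads] float
};
inline SplitKvWorkspaceSketch splitKvWorkspaceSketch(
    size_t num_splits, size_t total_q_tokens, size_t num_heads, size_t head_size) {
    const size_t n = num_splits * total_q_tokens * num_heads;
    return SplitKvWorkspaceSketch{n * head_size * sizeof(float), n * sizeof(float), n * sizeof(float)};
}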
template <typename Tindex, typename Tdata, int HEAD_SIZE, int BLOCK_M, int TOKENS_PER_TILE, int STAGES>
__device__ void PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv(
float *partial_acc, // [num_splits, total_q_tokens, num_heads, head_size]
float *partial_m, // [num_splits, total_q_tokens, num_heads]
float *partial_l, // [num_splits, total_q_tokens, num_heads]
int split_idx,
int num_splits,
int m_block,
size_t total_q_tokens,
const Tdata *q_,
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride) {
(void)max_num_blocks_per_seq;
static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4.");
static_assert(BLOCK_M > 0 && BLOCK_M <= 16, "BLOCK_M must be <= 16.");
static_assert(TOKENS_PER_TILE == 32, "Split-KV prefill assumes TOKENS_PER_TILE == 32.");
static_assert(STAGES >= 2 && STAGES <= 3, "STAGES must be 2 or 3.");
static_assert(sizeof(Tdata) == 2, "Split-KV prefill supports only fp16/bf16.");
constexpr int kWarpSize = 32;
static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32.");
constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize;
const int lane = threadIdx.x & (kWarpSize - 1);
const int warp_id = threadIdx.x / kWarpSize;
if (warp_id >= BLOCK_M) {
return;
}
const int head_idx = static_cast<int>(blockIdx.x);
const int seq_idx = static_cast<int>(blockIdx.y);
const int64_t q_start = cu_seqlens_q_[seq_idx];
const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
const int q_len = static_cast<int>(q_end - q_start);
if (q_len <= 0) {
return;
}
const int m_start = m_block * BLOCK_M;
const int q_token_local = m_start + warp_id;
if (m_start >= q_len) {
return; // uniform
}
const bool is_active = (q_token_local < q_len);
const int kv_len_total = static_cast<int>(total_kv_lens_[seq_idx]);
const int history_len = kv_len_total - q_len;
const int allowed_k_len = is_active ? (history_len + q_token_local + 1) : 0;
const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / static_cast<int>(num_kv_heads);
const int kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
constexpr float kLog2e = 1.4426950408889634f;
const float scale_log2 = scale * kLog2e;
int64_t q_token = q_start;
if (is_active) {
q_token += static_cast<int64_t>(q_token_local);
}
const size_t n = total_q_tokens * static_cast<size_t>(num_heads);
size_t base = 0;
if (is_active) {
base = static_cast<size_t>(q_token) * static_cast<size_t>(num_heads) + static_cast<size_t>(head_idx);
}
const Tindex *block_table = block_tables_ + static_cast<int64_t>(seq_idx) * static_cast<int64_t>(block_table_batch_stride);
const Tdata *q_ptr = nullptr;
if (is_active) {
q_ptr = q_ + q_token * q_stride + static_cast<int64_t>(head_idx) * q_head_stride;
}
float q_reg[DIMS_PER_THREAD];
float acc[DIMS_PER_THREAD];
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
q_reg[i] = is_active ? static_cast<float>(q_ptr[dim]) : 0.0f;
acc[i] = 0.0f;
}
float m = -INFINITY;
float l = 0.0f;
const int max_q_in_tile = min(m_start + BLOCK_M, q_len);
const int max_allowed_k_len = min(history_len + max_q_in_tile, kv_len_total);
if (max_allowed_k_len <= 0) {
if (is_active) {
const size_t idx = static_cast<size_t>(split_idx) * n + base;
if (lane == 0) {
partial_m[idx] = -INFINITY;
partial_l[idx] = 0.0f;
}
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
partial_acc[idx * HEAD_SIZE + dim] = 0.0f;
}
}
return;
}
const int pbs = static_cast<int>(page_block_size);
const int num_blocks_total = (max_allowed_k_len + pbs - 1) / pbs;
const int blocks_per_split = (num_blocks_total + num_splits - 1) / num_splits;
const int start_block = split_idx * blocks_per_split;
const int end_block = min(num_blocks_total, start_block + blocks_per_split);
if (start_block >= end_block) {
if (is_active) {
const size_t idx = static_cast<size_t>(split_idx) * n + base;
if (lane == 0) {
partial_m[idx] = -INFINITY;
partial_l[idx] = 0.0f;
}
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
partial_acc[idx * HEAD_SIZE + dim] = 0.0f;
}
}
return;
}
const int max_allowed_k_len_split = min(max_allowed_k_len, end_block * pbs);
constexpr int CHUNK_ELEMS = 8;
constexpr int CHUNKS = HEAD_SIZE / CHUNK_ELEMS;
constexpr int LOADS_PER_TILE = CHUNKS * TOKENS_PER_TILE;
__shared__ __align__(16) Tdata sh_k[STAGES][TOKENS_PER_TILE][HEAD_SIZE];
__shared__ __align__(16) Tdata sh_v[STAGES][TOKENS_PER_TILE][HEAD_SIZE];
__shared__ float sh_scores[BLOCK_M][TOKENS_PER_TILE];
const int tid = threadIdx.x;
int t_base = start_block * pbs;
for (int logical_block = start_block; t_base < max_allowed_k_len_split; ++logical_block, t_base += pbs) {
const int physical_block = static_cast<int>(block_table[logical_block]);
const Tdata *k_base = k_cache_ + static_cast<int64_t>(physical_block) * k_batch_stride + static_cast<int64_t>(kv_head_idx) * k_head_stride;
const Tdata *v_base = v_cache_ + static_cast<int64_t>(physical_block) * v_batch_stride + static_cast<int64_t>(kv_head_idx) * v_head_stride;
const int token_end = min(pbs, max_allowed_k_len_split - t_base);
const int num_tiles = (token_end + TOKENS_PER_TILE - 1) / TOKENS_PER_TILE;
if (num_tiles <= 0) {
continue;
}
int pending_groups = 0;
const int preload = min(STAGES, num_tiles);
for (int ti = 0; ti < preload; ++ti) {
const int token_in_block = ti * TOKENS_PER_TILE;
const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block);
for (int li = tid; li < LOADS_PER_TILE; li += blockDim.x) {
const int tok = li / CHUNKS;
const int chunk = li - tok * CHUNKS;
const int off = chunk * CHUNK_ELEMS;
if (tok < tile_n) {
const Tdata *k_src = k_base + static_cast<int64_t>(token_in_block + tok) * k_row_stride + off;
const Tdata *v_src = v_base + static_cast<int64_t>(token_in_block + tok) * v_row_stride + off;
op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_k[ti][tok][off], k_src);
op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_v[ti][tok][off], v_src);
} else {
reinterpret_cast<uint4 *>(&sh_k[ti][tok][off])[0] = make_uint4(0, 0, 0, 0);
reinterpret_cast<uint4 *>(&sh_v[ti][tok][off])[0] = make_uint4(0, 0, 0, 0);
}
}
op::paged_attention::cuda::cpAsyncCommit();
++pending_groups;
}
int desired_pending = pending_groups - 1;
if (desired_pending < 0) {
desired_pending = 0;
}
if (desired_pending > (STAGES - 1)) {
desired_pending = (STAGES - 1);
}
op::paged_attention::cuda::cpAsyncWaitGroupRt(desired_pending);
pending_groups = desired_pending;
__syncthreads();
for (int tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
const int buf = tile_idx % STAGES;
const int token_in_block = tile_idx * TOKENS_PER_TILE;
const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block);
const int global_k_base = t_base + token_in_block;
float alpha = 1.0f;
float beta = 0.0f;
float tile_sumexp = 0.0f;
float tile_m = -INFINITY;
float w_lane = 0.0f;
if (allowed_k_len > 0) {
// 1) scores
for (int j = 0; j < tile_n; ++j) {
const int kpos = global_k_base + j;
const bool token_unmasked = (kpos < allowed_k_len);
float qk = 0.0f;
if (token_unmasked) {
const Tdata *k_ptr = &sh_k[buf][j][0];
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
const int dim_base = lane * DIMS_PER_THREAD;
const half2 *q2 = reinterpret_cast<const half2 *>(q_ptr + dim_base);
const half2 *k2 = reinterpret_cast<const half2 *>(k_ptr + dim_base);
#pragma unroll
for (int ii = 0; ii < DIMS_PER_THREAD / 2; ++ii) {
const float2 qf = __half22float2(q2[ii]);
const float2 kf = __half22float2(k2[ii]);
qk = fmaf(qf.x, kf.x, qk);
qk = fmaf(qf.y, kf.y, qk);
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
const int dim_base = lane * DIMS_PER_THREAD;
const __nv_bfloat162 *q2 = reinterpret_cast<const __nv_bfloat162 *>(q_ptr + dim_base);
const __nv_bfloat162 *k2 = reinterpret_cast<const __nv_bfloat162 *>(k_ptr + dim_base);
#pragma unroll
for (int ii = 0; ii < DIMS_PER_THREAD / 2; ++ii) {
const float2 qf = __bfloat1622float2(q2[ii]);
const float2 kf = __bfloat1622float2(k2[ii]);
qk = fmaf(qf.x, kf.x, qk);
qk = fmaf(qf.y, kf.y, qk);
}
} else
#endif
{
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
qk = fmaf(q_reg[i], static_cast<float>(k_ptr[dim]), qk);
}
}
}
qk = op::paged_attention::cuda::warpReduceSum(qk);
if (lane == 0) {
float score = token_unmasked ? (qk * scale_log2) : -INFINITY;
if (token_unmasked && alibi_slope != 0.0f) {
const int causal_limit = allowed_k_len - 1;
score += (alibi_slope * static_cast<float>(kpos - causal_limit)) * kLog2e;
}
sh_scores[warp_id][j] = score;
}
}
__syncwarp();
// 2) tile max / sumexp
float max_tmp = -INFINITY;
if (lane < tile_n) {
max_tmp = sh_scores[warp_id][lane];
}
max_tmp = op::paged_attention::cuda::warpReduceMax(max_tmp);
max_tmp = __shfl_sync(0xffffffff, max_tmp, 0);
tile_m = max_tmp;
if (lane < tile_n) {
const float s = sh_scores[warp_id][lane];
w_lane = (s == -INFINITY) ? 0.0f : exp2f(s - tile_m);
} else {
w_lane = 0.0f;
}
float sumexp_tmp = op::paged_attention::cuda::warpReduceSum(w_lane);
sumexp_tmp = __shfl_sync(0xffffffff, sumexp_tmp, 0);
tile_sumexp = sumexp_tmp;
// 3) weighted V for this tile
float acc_tile[DIMS_PER_THREAD];
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
acc_tile[i] = 0.0f;
}
if (tile_sumexp > 0.0f) {
for (int j = 0; j < tile_n; ++j) {
const float wj = __shfl_sync(0xffffffff, w_lane, j);
const Tdata *v_ptr = &sh_v[buf][j][0];
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
const int dim_base = lane * DIMS_PER_THREAD;
const half2 *v2 = reinterpret_cast<const half2 *>(v_ptr + dim_base);
#pragma unroll
for (int jj = 0; jj < DIMS_PER_THREAD / 2; ++jj) {
const float2 vf = __half22float2(v2[jj]);
acc_tile[jj * 2 + 0] += wj * vf.x;
acc_tile[jj * 2 + 1] += wj * vf.y;
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
const int dim_base = lane * DIMS_PER_THREAD;
const __nv_bfloat162 *v2 = reinterpret_cast<const __nv_bfloat162 *>(v_ptr + dim_base);
#pragma unroll
for (int jj = 0; jj < DIMS_PER_THREAD / 2; ++jj) {
const float2 vf = __bfloat1622float2(v2[jj]);
acc_tile[jj * 2 + 0] += wj * vf.x;
acc_tile[jj * 2 + 1] += wj * vf.y;
}
} else
#endif
{
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
acc_tile[i] += wj * static_cast<float>(v_ptr[dim]);
}
}
}
}
// 4) merge tile into running (m, l, acc)
if (lane == 0) {
if (tile_sumexp > 0.0f && tile_m != -INFINITY) {
const float m_new = fmaxf(m, tile_m);
alpha = exp2f(m - m_new);
beta = exp2f(tile_m - m_new);
l = l * alpha + tile_sumexp * beta;
m = m_new;
} else {
alpha = 1.0f;
beta = 0.0f;
}
}
alpha = __shfl_sync(0xffffffff, alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0);
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
acc[i] = acc[i] * alpha + beta * acc_tile[i];
}
}
__syncthreads();
const int prefetch_tile = tile_idx + STAGES;
if (prefetch_tile < num_tiles) {
const int token_prefetch = prefetch_tile * TOKENS_PER_TILE;
const int prefetch_n = min(TOKENS_PER_TILE, token_end - token_prefetch);
for (int li = tid; li < LOADS_PER_TILE; li += blockDim.x) {
const int tok = li / CHUNKS;
const int chunk = li - tok * CHUNKS;
const int off = chunk * CHUNK_ELEMS;
if (tok < prefetch_n) {
const Tdata *k_src = k_base + static_cast<int64_t>(token_prefetch + tok) * k_row_stride + off;
const Tdata *v_src = v_base + static_cast<int64_t>(token_prefetch + tok) * v_row_stride + off;
op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_k[buf][tok][off], k_src);
op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_v[buf][tok][off], v_src);
} else {
reinterpret_cast<uint4 *>(&sh_k[buf][tok][off])[0] = make_uint4(0, 0, 0, 0);
reinterpret_cast<uint4 *>(&sh_v[buf][tok][off])[0] = make_uint4(0, 0, 0, 0);
}
}
op::paged_attention::cuda::cpAsyncCommit();
++pending_groups;
}
if (tile_idx + 1 < num_tiles) {
int desired_pending2 = pending_groups - 1;
if (desired_pending2 < 0) {
desired_pending2 = 0;
}
if (desired_pending2 > (STAGES - 1)) {
desired_pending2 = (STAGES - 1);
}
op::paged_attention::cuda::cpAsyncWaitGroupRt(desired_pending2);
pending_groups = desired_pending2;
__syncthreads();
}
}
op::paged_attention::cuda::cpAsyncWaitAll();
__syncthreads();
}
if (is_active) {
const size_t idx = static_cast<size_t>(split_idx) * n + base;
if (lane == 0) {
partial_m[idx] = m;
partial_l[idx] = l;
}
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
partial_acc[idx * HEAD_SIZE + dim] = acc[i];
}
}
}
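// Split merge performed by the combine kernel below: given per-split partials (m_s, l_s, acc_s),
//   m   = max_s m_s
//   l   = sum_s l_s * 2^(m_s - m)
//   out = (sum_s acc_s * 2^(m_s - m)) / l
// which matches the result of a single un-split pass up to floating-point rounding.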
template <typename Tdata, int HEAD_SIZE>
__device__ void PagedAttentionPrefillSplitKvCombineWarpKernel(
Tdata *out_,
const float *partial_acc, // [num_splits, total_q_tokens, num_heads, head_size]
const float *partial_m, // [num_splits, total_q_tokens, num_heads]
const float *partial_l, // [num_splits, total_q_tokens, num_heads]
int num_splits,
size_t total_q_tokens,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
const int head_idx = static_cast<int>(blockIdx.x);
const int token_idx = static_cast<int>(blockIdx.y);
const int lane = threadIdx.x;
constexpr int kWarpSize = 32;
static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32.");
constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize;
const int num_heads = gridDim.x;
const size_t n = total_q_tokens * static_cast<size_t>(num_heads);
const size_t base = static_cast<size_t>(token_idx) * static_cast<size_t>(num_heads) + static_cast<size_t>(head_idx);
float m = -INFINITY;
if (lane == 0) {
for (int s = 0; s < num_splits; ++s) {
m = fmaxf(m, partial_m[static_cast<size_t>(s) * n + base]);
}
}
m = __shfl_sync(0xffffffff, m, 0);
float l = 0.0f;
if (lane == 0) {
for (int s = 0; s < num_splits; ++s) {
const float ms = partial_m[static_cast<size_t>(s) * n + base];
const float ls = partial_l[static_cast<size_t>(s) * n + base];
if (ls > 0.0f) {
l += ls * exp2f(ms - m);
}
}
}
l = __shfl_sync(0xffffffff, l, 0);
const float inv_l = 1.0f / (l + 1e-6f);
Tdata *out_ptr = out_ + static_cast<int64_t>(token_idx) * o_stride + static_cast<int64_t>(head_idx) * o_head_stride;
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
float acc = 0.0f;
for (int s = 0; s < num_splits; ++s) {
const float ms = partial_m[static_cast<size_t>(s) * n + base];
const float w = exp2f(ms - m);
acc += partial_acc[(static_cast<size_t>(s) * n + base) * HEAD_SIZE + dim] * w;
}
const float o = acc * inv_l;
if constexpr (std::is_same_v<Tdata, half>) {
out_ptr[dim] = __float2half_rn(o);
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
out_ptr[dim] = __float2bfloat16_rn(o);
} else {
out_ptr[dim] = static_cast<Tdata>(o);
}
}
}
// Variant for large K tile where (K+V) shared memory would exceed the per-block limit on some GPUs.
// We keep K in shared memory for reuse across warps, but load V directly from global memory.
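// Shared-memory arithmetic behind this split (illustrative): with BLOCK_N = 128 and
// HEAD_SIZE = 128 in fp16, a K tile alone is 128 * 128 * 2 B = 32 KiB; staging V as well
// would add another 32 KiB and blow past the 48 KiB static budget, so V is read from
// global memory instead.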
template <typename Tindex, typename Tdata, int HEAD_SIZE, int BLOCK_M, int BLOCK_N>
__device__ void PagedAttentionPrefillWarpCtaKernelKOnly(
Tdata *out_,
const Tdata *q_,
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4.");
static_assert(BLOCK_M > 0 && BLOCK_M <= 16, "BLOCK_M must be <=16.");
static_assert(BLOCK_N > 0 && BLOCK_N <= 128, "BLOCK_N must be <=128.");
constexpr int kWarpSize = 32;
constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize;
static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32.");
const int lane = threadIdx.x & (kWarpSize - 1);
const int warp_id = threadIdx.x / kWarpSize;
if (warp_id >= BLOCK_M) {
return;
}
const int head_idx = static_cast<int>(blockIdx.x);
const int seq_idx = static_cast<int>(blockIdx.y);
const int m_block = static_cast<int>(blockIdx.z);
const int64_t q_start = cu_seqlens_q_[seq_idx];
const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
const int q_len = static_cast<int>(q_end - q_start);
if (q_len <= 0) {
return;
}
const int m_start = m_block * BLOCK_M;
const int q_token_local = m_start + warp_id;
// IMPORTANT: do not early-return for a subset of warps in this CTA because we use __syncthreads()
// later. Tail tiles are handled by masking inactive warps.
if (m_start >= q_len) {
return; // uniform across the CTA
}
const bool is_active = (q_token_local < q_len);
const int kv_len_total = static_cast<int>(total_kv_lens_[seq_idx]);
const int history_len = kv_len_total - q_len;
const int allowed_k_len = is_active ? (history_len + q_token_local + 1) : 0;
const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / static_cast<int>(num_kv_heads);
const int kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
constexpr float kLog2e = 1.4426950408889634f;
const float scale_log2 = scale * kLog2e;
int64_t q_token = q_start;
if (is_active) {
q_token += static_cast<int64_t>(q_token_local);
}
const Tindex *block_table = block_tables_ + static_cast<int64_t>(seq_idx) * static_cast<int64_t>(block_table_batch_stride);
const Tdata *q_ptr = nullptr;
Tdata *out_ptr = nullptr;
if (is_active) {
q_ptr = q_ + q_token * q_stride + static_cast<int64_t>(head_idx) * q_head_stride;
out_ptr = out_ + q_token * o_stride + static_cast<int64_t>(head_idx) * o_head_stride;
}
float q_reg[DIMS_PER_THREAD];
float acc[DIMS_PER_THREAD];
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
q_reg[i] = is_active ? static_cast<float>(q_ptr[dim]) : 0.0f;
acc[i] = 0.0f;
}
#if defined(__CUDA_ARCH__)
float2 q_reg2[DIMS_PER_THREAD / 2];
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
q_reg2[j] = make_float2(q_reg[j * 2 + 0], q_reg[j * 2 + 1]);
}
#endif
float m = -INFINITY;
float l = 0.0f;
const int max_q_in_tile = min(m_start + BLOCK_M, q_len);
const int max_allowed_k_len = min(history_len + max_q_in_tile, kv_len_total);
__shared__ int32_t s_phys[BLOCK_N];
__shared__ int32_t s_off[BLOCK_N];
__shared__ __align__(16) Tdata s_k[BLOCK_N * HEAD_SIZE];
const int pbs = static_cast<int>(page_block_size);
for (int k_base = 0; k_base < max_allowed_k_len; k_base += BLOCK_N) {
const int tile_n = min(BLOCK_N, max_allowed_k_len - k_base);
for (int t = threadIdx.x; t < tile_n; t += blockDim.x) {
const int kpos = k_base + t;
const int page = (pbs == 256) ? (kpos >> 8) : (kpos / pbs);
const int off = (pbs == 256) ? (kpos & 255) : (kpos - page * pbs);
const int32_t phys = static_cast<int32_t>(block_table[page]);
s_phys[t] = phys;
s_off[t] = off;
}
__syncthreads();
const int tile_elems = tile_n * HEAD_SIZE;
for (int idx = threadIdx.x; idx < tile_elems; idx += blockDim.x) {
const int t = idx / HEAD_SIZE;
const int dim = idx - t * HEAD_SIZE;
const int32_t phys = s_phys[t];
const int32_t off = s_off[t];
const Tdata *k_base_ptr = k_cache_ + static_cast<int64_t>(phys) * k_batch_stride + static_cast<int64_t>(off) * k_row_stride + static_cast<int64_t>(kv_head_idx) * k_head_stride;
s_k[t * HEAD_SIZE + dim] = k_base_ptr[dim];
}
__syncthreads();
for (int t = 0; t < tile_n; ++t) {
const int kpos = k_base + t;
if (kpos >= allowed_k_len) {
break;
}
const Tdata *k_ptr = s_k + t * HEAD_SIZE;
const int32_t phys = s_phys[t];
const int32_t off = s_off[t];
const Tdata *v_ptr = v_cache_ + static_cast<int64_t>(phys) * v_batch_stride + static_cast<int64_t>(off) * v_row_stride + static_cast<int64_t>(kv_head_idx) * v_head_stride;
float qk = 0.0f;
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
const int dim_base = lane * DIMS_PER_THREAD;
const half2 *k2 = reinterpret_cast<const half2 *>(k_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 qf = q_reg2[j];
const float2 kf = __half22float2(k2[j]);
qk += qf.x * kf.x + qf.y * kf.y;
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
const int dim_base = lane * DIMS_PER_THREAD;
const __nv_bfloat162 *k2 = reinterpret_cast<const __nv_bfloat162 *>(k_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 qf = q_reg2[j];
const float2 kf = __bfloat1622float2(k2[j]);
qk += qf.x * kf.x + qf.y * kf.y;
}
} else
#endif
{
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
qk += q_reg[i] * static_cast<float>(k_ptr[dim]);
}
}
qk = op::paged_attention::cuda::warpReduceSum(qk);
float alpha = 1.0f;
float beta = 0.0f;
if (lane == 0) {
float score = qk * scale_log2;
if (alibi_slope != 0.0f) {
score += (alibi_slope * static_cast<float>(kpos - (allowed_k_len - 1))) * kLog2e;
}
const float m_new = fmaxf(m, score);
alpha = exp2f(m - m_new);
beta = exp2f(score - m_new);
l = l * alpha + beta;
m = m_new;
}
alpha = __shfl_sync(0xffffffff, alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0);
#if defined(__CUDA_ARCH__)
if constexpr (std::is_same_v<Tdata, half>) {
const int dim_base = lane * DIMS_PER_THREAD;
const half2 *v2 = reinterpret_cast<const half2 *>(v_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 vf = __half22float2(v2[j]);
acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x;
acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y;
}
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
const int dim_base = lane * DIMS_PER_THREAD;
const __nv_bfloat162 *v2 = reinterpret_cast<const __nv_bfloat162 *>(v_ptr + dim_base);
#pragma unroll
for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) {
const float2 vf = __bfloat1622float2(v2[j]);
acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x;
acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y;
}
} else
#endif
{
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
const float v_val = static_cast<float>(v_ptr[dim]);
acc[i] = acc[i] * alpha + beta * v_val;
}
}
}
__syncthreads();
}
float inv_l = 0.0f;
if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f);
}
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
#pragma unroll
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
const int dim = lane * DIMS_PER_THREAD + i;
const float out_val = acc[i] * inv_l;
if (!is_active) {
continue;
}
if constexpr (std::is_same_v<Tdata, half>) {
out_ptr[dim] = __float2half_rn(out_val);
} else if constexpr (std::is_same_v<Tdata, __nv_bfloat16>) {
out_ptr[dim] = __float2bfloat16_rn(out_val);
} else {
out_ptr[dim] = static_cast<Tdata>(out_val);
}
}
}
// TensorCore (WMMA) score kernel (v0.4 experimental):
// - Target shape: head_dim=128, page_block_size=256, fp16.
// - Compute QK^T with WMMA into shared memory, then reuse the existing online-softmax + V accumulation
// pattern (SIMT) per query row.
//
// Notes:
// - This is a correctness-first kernel. It doesn't yet use MMA for PV (P * V) update.
// - We keep the same grid mapping as other prefill kernels: blockIdx = (head, seq, m_block).
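//
// Minimal WMMA usage sketch (illustrative only, never called): one warp computes a single
// 16x16 tile of Q*K^T from fp16 tiles resident in shared memory, accumulating over the head
// dimension in 16-wide steps. It assumes <mma.h> is already visible (the kernel below uses
// nvcuda::wmma); all names in this sketch are local to it, and `ld` stands for the padded
// shared-memory leading dimension (kHeadDimSmem below).
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)
__device__ __forceinline__ void wmmaQkTile16x16Sketch(
    const half *q_tile,  // 16 query rows, row-major, leading dimension ld
    const half *k_tile,  // 16 key rows, row-major, leading dimension ld
    float *scores,       // 16x16 output tile, row-major, leading dimension scores_ld
    int ld, int head_dim, int scores_ld) {
    namespace wmma = nvcuda::wmma;
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a;
    // Reading the row-major K tile as a col_major matrix_b realizes the transpose for free.
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c;
    wmma::fill_fragment(c, 0.0f);
    for (int kk = 0; kk < head_dim; kk += 16) {
        wmma::load_matrix_sync(a, q_tile + kk, ld);
        wmma::load_matrix_sync(b, k_tile + kk, ld);
        wmma::mma_sync(c, a, b, c);
    }
    wmma::store_matrix_sync(scores, c, scores_ld, wmma::mem_row_major);
}
#endif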
template <int kWarpSize, int kBlockN, int kHeadDim, int kDimsPerThread>
__device__ __forceinline__ void PagedAttentionPrefillMmaScoreUpdateRow(
int lane,
int k_base,
int allowed_k_len,
const float *scores_row, // [kBlockN]
const half *v_tile, // [kBlockN, kHeadDim]
float scale_log2,
float alibi_slope_log2,
float &m,
float &l,
float *acc) { // [kDimsPerThread]
// Max over keys in this tile.
float local_max = -INFINITY;
for (int t = lane; t < kBlockN; t += kWarpSize) {
const int kpos = k_base + t;
if (kpos >= allowed_k_len) {
continue;
}
float score = scores_row[t] * scale_log2;
if (alibi_slope_log2 != 0.0f) {
score += alibi_slope_log2 * static_cast<float>(kpos - (allowed_k_len - 1));
}
local_max = fmaxf(local_max, score);
}
float tile_m = op::paged_attention::cuda::warpReduceMax(local_max);
tile_m = __shfl_sync(0xffffffff, tile_m, 0);
// Sumexp + weighted V over keys in this tile, partitioned by lanes.
float sumexp_lane = 0.0f;
float acc_tile[kDimsPerThread] = {0.0f, 0.0f, 0.0f, 0.0f};
const int dim_base = lane * kDimsPerThread;
if (tile_m != -INFINITY) {
for (int t = lane; t < kBlockN; t += kWarpSize) {
const int kpos = k_base + t;
if (kpos >= allowed_k_len) {
continue;
}
float score = scores_row[t] * scale_log2;
if (alibi_slope_log2 != 0.0f) {
score += alibi_slope_log2 * static_cast<float>(kpos - (allowed_k_len - 1));
}
const float w = exp2f(score - tile_m);
sumexp_lane += w;
const half *v_ptr = v_tile + t * kHeadDim + dim_base;
const half2 *v2 = reinterpret_cast<const half2 *>(v_ptr);
#pragma unroll
for (int j = 0; j < kDimsPerThread / 2; ++j) {
const float2 vf = __half22float2(v2[j]);
acc_tile[j * 2 + 0] += w * vf.x;
acc_tile[j * 2 + 1] += w * vf.y;
}
}
}
float tile_sumexp = op::paged_attention::cuda::warpReduceSum(sumexp_lane);
tile_sumexp = __shfl_sync(0xffffffff, tile_sumexp, 0);
float alpha = 1.0f;
float beta = 0.0f;
if (lane == 0) {
if (tile_sumexp > 0.0f && tile_m != -INFINITY) {
const float m_new = fmaxf(m, tile_m);
alpha = exp2f(m - m_new);
beta = exp2f(tile_m - m_new);
l = l * alpha + tile_sumexp * beta;
m = m_new;
} else {
alpha = 1.0f;
beta = 0.0f;
}
}
alpha = __shfl_sync(0xffffffff, alpha, 0);
beta = __shfl_sync(0xffffffff, beta, 0);
#pragma unroll
for (int i = 0; i < kDimsPerThread; ++i) {
acc[i] = acc[i] * alpha + beta * acc_tile[i];
}
}
template <int kWarpSize, int kHeadDim, int kDimsPerThread>
__device__ __forceinline__ void PagedAttentionPrefillMmaScoreWriteRow(
int lane,
bool active,
int q_token_local,
int64_t q_start,
int head_idx,
half *out_,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
float l,
const float *acc) { // [kDimsPerThread]
if (!active) {
return;
}
float inv_l = 0.0f;
if (lane == 0) {
inv_l = 1.0f / (l + 1e-6f);
}
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
const int64_t q_token = q_start + static_cast<int64_t>(q_token_local);
half *out_ptr = out_ + q_token * o_stride + static_cast<int64_t>(head_idx) * o_head_stride;
#pragma unroll
for (int i = 0; i < kDimsPerThread; ++i) {
const int dim = lane * kDimsPerThread + i;
out_ptr[dim] = __float2half_rn(acc[i] * inv_l);
}
}
template <typename Tindex>
__device__ void PagedAttentionPrefillWarpCta8MmaHd128Kernel(
half *out_,
const half *q_,
const half *k_cache_,
const half *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
(void)max_num_blocks_per_seq;
constexpr int kWarpSize = 32;
constexpr int kWarps = 8;
constexpr int kHeadDim = 128;
// Extra padding in the K dimension to reduce shared-memory bank conflicts for ldmatrix / wmma loads.
// NOTE: FA2 uses a swizzled smem layout; padding is a smaller step that keeps our code simple.
constexpr int kHeadDimSmem = 136; // must be a multiple of 8 for wmma::load_matrix_sync
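    // Why 136: shared memory has 32 banks of 4 B. An unpadded 128-half (256 B) row stride maps
    // every row onto the same bank pattern; the extra 8 halves (16 B) shift consecutive rows by
    // 4 banks, which breaks up column-wise conflicts during the wmma loads.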
constexpr int kBlockM = 16; // 2 rows per warp
// Keep static shared memory <= 48KB for compatibility with build targets that cap SMEM at 0xC000.
// kBlockN=64 brings s_q+s_k+s_v+s_scores+s_phys/s_off down to ~43KB (including the 136-wide padding).
constexpr int kBlockN = 64;
constexpr int kDimsPerThread = kHeadDim / kWarpSize;
static_assert(kHeadDim % kWarpSize == 0, "head_dim must be divisible by 32.");
const int lane = threadIdx.x & (kWarpSize - 1);
const int warp_id = threadIdx.x / kWarpSize;
if (warp_id >= kWarps) {
return;
}
const int head_idx = static_cast<int>(blockIdx.x);
const int seq_idx = static_cast<int>(blockIdx.y);
const int m_block = static_cast<int>(blockIdx.z);
const int64_t q_start = cu_seqlens_q_[seq_idx];
const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
const int q_len = static_cast<int>(q_end - q_start);
if (q_len <= 0) {
return;
}
const int m_start = m_block * kBlockM;
// Uniform early return for empty tail tiles (avoid deadlock with __syncthreads()).
if (m_start >= q_len) {
return;
}
const int kv_len_total = static_cast<int>(total_kv_lens_[seq_idx]);
const int history_len = kv_len_total - q_len;
// Clamp max k length for this CTA based on the last active query row in the tile.
const int max_q_in_tile = min(m_start + kBlockM, q_len);
const int max_allowed_k_len = min(history_len + max_q_in_tile, kv_len_total);
const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / static_cast<int>(num_kv_heads);
const int kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
constexpr float kLog2e = 1.4426950408889634f;
const float scale_log2 = scale * kLog2e;
const float alibi_slope_log2 = alibi_slope * kLog2e;
const int pbs = static_cast<int>(page_block_size);
const Tindex *block_table = block_tables_ + static_cast<int64_t>(seq_idx) * static_cast<int64_t>(block_table_batch_stride);
// Shared memory:
// - s_q: [kBlockM, kHeadDimSmem] (padded)
// - s_k/s_v: [kBlockN, kHeadDim]
// - s_scores: [kBlockM, kBlockN] raw dot products (no scale / alibi)
__shared__ __align__(16) half s_q[kBlockM * kHeadDimSmem];
__shared__ int32_t s_phys[kBlockN];
__shared__ int32_t s_off[kBlockN];
__shared__ __align__(16) half s_k[kBlockN * kHeadDimSmem];
__shared__ __align__(16) half s_v[kBlockN * kHeadDimSmem];
__shared__ __align__(16) float s_scores[kBlockM * kBlockN];
// Load Q tile (pad inactive rows with 0).
for (int idx = threadIdx.x; idx < kBlockM * kHeadDim; idx += blockDim.x) {
const int r = idx / kHeadDim;
const int d = idx - r * kHeadDim;
const int q_token_local = m_start + r;
if (q_token_local < q_len) {
const int64_t q_token = q_start + static_cast<int64_t>(q_token_local);
const half *q_ptr = q_ + q_token * q_stride + static_cast<int64_t>(head_idx) * q_head_stride;
s_q[r * kHeadDimSmem + d] = q_ptr[d];
} else {
s_q[r * kHeadDimSmem + d] = __float2half_rn(0.0f);
}
}
__syncthreads();
// Two rows per warp: row0=warp_id, row1=warp_id+kWarps.
const int row0 = warp_id;
const int row1 = warp_id + kWarps;
const bool active0 = (row0 < kBlockM) && ((m_start + row0) < q_len);
const bool active1 = (row1 < kBlockM) && ((m_start + row1) < q_len);
const int allowed0 = active0 ? min(history_len + (m_start + row0) + 1, kv_len_total) : 0;
const int allowed1 = active1 ? min(history_len + (m_start + row1) + 1, kv_len_total) : 0;
float m0 = -INFINITY, l0 = 0.0f;
float m1 = -INFINITY, l1 = 0.0f;
float acc0[kDimsPerThread] = {0.0f, 0.0f, 0.0f, 0.0f};
float acc1[kDimsPerThread] = {0.0f, 0.0f, 0.0f, 0.0f};
// Iterate over K/V tiles.
for (int k_base = 0; k_base < max_allowed_k_len; k_base += kBlockN) {
// Map logical k positions to physical blocks for this tile (pad the tail with -1).
for (int t = threadIdx.x; t < kBlockN; t += blockDim.x) {
const int kpos = k_base + t;
if (kpos < max_allowed_k_len) {
const int page = (pbs == 256) ? (kpos >> 8) : (kpos / pbs);
const int off = (pbs == 256) ? (kpos & 255) : (kpos - page * pbs);
s_phys[t] = static_cast<int32_t>(block_table[page]);
s_off[t] = off;
} else {
s_phys[t] = -1;
s_off[t] = 0;
}
}
__syncthreads();
// Load K/V tile into shared memory (pad with 0 for inactive tokens).
for (int idx = threadIdx.x; idx < kBlockN * kHeadDim; idx += blockDim.x) {
const int t = idx / kHeadDim;
const int d = idx - t * kHeadDim;
const int32_t phys = s_phys[t];
if (phys >= 0) {
const int32_t off = s_off[t];
const half *k_ptr = k_cache_ + static_cast<int64_t>(phys) * k_batch_stride + static_cast<int64_t>(off) * k_row_stride + static_cast<int64_t>(kv_head_idx) * k_head_stride;
const half *v_ptr = v_cache_ + static_cast<int64_t>(phys) * v_batch_stride + static_cast<int64_t>(off) * v_row_stride + static_cast<int64_t>(kv_head_idx) * v_head_stride;
s_k[t * kHeadDimSmem + d] = k_ptr[d];
s_v[t * kHeadDimSmem + d] = v_ptr[d];
} else {
s_k[t * kHeadDimSmem + d] = __float2half_rn(0.0f);
s_v[t * kHeadDimSmem + d] = __float2half_rn(0.0f);
}
}
__syncthreads();
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)
// WMMA: each warp computes scores for 16 keys (one 16-column slice of the K tile) across all 16 rows.
// For kBlockN=64, only the first 4 warps participate in WMMA score computation.
namespace wmma = nvcuda::wmma;
constexpr int kNSub = kBlockN / 16;
if (warp_id < kNSub) {
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;
wmma::fill_fragment(c_frag, 0.0f);
const int n_sub = warp_id; // [0, kNSub)
const half *q_tile = s_q;
const half *k_tile = s_k + (n_sub * 16) * kHeadDimSmem;
// K loop (head_dim=128).
#pragma unroll
for (int kk = 0; kk < (kHeadDim / 16); ++kk) {
wmma::load_matrix_sync(a_frag, q_tile + kk * 16, kHeadDimSmem);
wmma::load_matrix_sync(b_frag, k_tile + kk * 16, kHeadDimSmem);
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
float *scores_tile = s_scores + n_sub * 16;
wmma::store_matrix_sync(scores_tile, c_frag, kBlockN, wmma::mem_row_major);
}
#else
// No WMMA support on this architecture: s_scores is never written on this path, so this
// kernel would produce garbage output. Host dispatch must not select it below SM70 and
// should fall back to the scalar-dot kernels above instead.
if (threadIdx.x == 0) {
// Intentionally empty.
}
#endif
__syncthreads();
// Online softmax + V update per row handled by the same warp across tiles.
if (row0 < kBlockM) {
PagedAttentionPrefillMmaScoreUpdateRow<kWarpSize, kBlockN, kHeadDim, kDimsPerThread>(
lane, k_base, allowed0, s_scores + row0 * kBlockN, s_v, scale_log2, alibi_slope_log2, m0, l0, acc0);
}
if (row1 < kBlockM) {
PagedAttentionPrefillMmaScoreUpdateRow<kWarpSize, kBlockN, kHeadDim, kDimsPerThread>(
lane, k_base, allowed1, s_scores + row1 * kBlockN, s_v, scale_log2, alibi_slope_log2, m1, l1, acc1);
}
__syncthreads();
}
// Write outputs.
if (row0 < kBlockM) {
PagedAttentionPrefillMmaScoreWriteRow<kWarpSize, kHeadDim, kDimsPerThread>(
lane, active0, m_start + row0, q_start, head_idx, out_, o_stride, o_head_stride, l0, acc0);
}
if (row1 < kBlockM) {
PagedAttentionPrefillMmaScoreWriteRow<kWarpSize, kHeadDim, kDimsPerThread>(
lane, active1, m_start + row1, q_start, head_idx, out_, o_stride, o_head_stride, l1, acc1);
}
}
} // namespace op::paged_attention_prefill::cuda
#endif
@@ -3,6 +3,7 @@
#include "../../../utils.h"
#include "../../tensor.h"
#include <cstring>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <optional>
#include <vector>
@@ -14,21 +15,30 @@ class PagedAttentionPrefillInfo {
public:
infiniDtype_t dtype;
infiniDtype_t index_dtype;
float scale;
size_t num_seqs;
size_t total_q_tokens;
size_t num_heads;
size_t num_kv_heads;
size_t head_size;
size_t block_size;
size_t page_block_size;
size_t max_num_blocks_per_seq;
size_t total_q_tokens;
size_t num_blocks;
ptrdiff_t q_stride;
ptrdiff_t q_head_stride;
ptrdiff_t kv_block_stride;
ptrdiff_t kv_head_stride;
ptrdiff_t k_batch_stride;
ptrdiff_t k_row_stride;
ptrdiff_t k_head_stride;
ptrdiff_t v_batch_stride;
ptrdiff_t v_row_stride;
ptrdiff_t v_head_stride;
ptrdiff_t o_stride;
ptrdiff_t o_head_stride;
ptrdiff_t block_table_batch_stride;
static utils::Result<PagedAttentionPrefillInfo> create(
infiniopTensorDescriptor_t out_desc,
@@ -36,89 +46,161 @@ public:
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t cum_seq_lens_q_desc,
infiniopTensorDescriptor_t total_kv_lens_desc,
infiniopTensorDescriptor_t cum_seqlens_q_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto dtype = q_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (cum_seq_lens_q_desc->dtype() != INFINI_DTYPE_I64 || seq_lens_desc->dtype() != INFINI_DTYPE_I64) {
// q/out: [total_q, heads, head_dim]
if (q_desc->ndim() != 3 || out_desc->ndim() != 3) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// FA2 paged KV layout: [num_blocks, page_block_size, kv_heads, head_dim]
if (k_cache_desc->ndim() != 4 || v_cache_desc->ndim() != 4) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (block_tables_desc->ndim() != 2 || total_kv_lens_desc->ndim() != 1 || cum_seqlens_q_desc->ndim() != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
CHECK_OR_RETURN(q_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
CHECK_OR_RETURN(out_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
CHECK_OR_RETURN(k_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
CHECK_OR_RETURN(v_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
// Index dtypes: allow I32/I64/U32 (v0.4 roadmap allows internal conversion to I32).
const auto block_tables_dt = block_tables_desc->dtype();
if (!((block_tables_dt == INFINI_DTYPE_I64) || (block_tables_dt == INFINI_DTYPE_I32) || (block_tables_dt == INFINI_DTYPE_U32))) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
// Keep it simple: require total_kv_lens + cum_seqlens_q to be int64 for now
// (matches current paged_attention_prefill signature). We will convert to int32 internally later.
if (total_kv_lens_desc->dtype() != INFINI_DTYPE_I64 || cum_seqlens_q_desc->dtype() != INFINI_DTYPE_I64) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
CHECK_OR_RETURN(block_tables_desc->stride(1) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
if (alibi_slopes_desc.has_value() && alibi_slopes_desc.value() != nullptr) {
if (alibi_slopes_desc.value()->dtype() != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (alibi_slopes_desc.value()->ndim() != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
CHECK_OR_RETURN(alibi_slopes_desc.value()->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
}
auto k_shape = k_cache_desc->shape();
auto v_shape = v_cache_desc->shape();
auto block_tables_shape = block_tables_desc->shape();
auto seq_lens_shape = seq_lens_desc->shape();
auto cum_seq_lens_q_shape = cum_seq_lens_q_desc->shape();
const auto q_shape = q_desc->shape();
const auto k_shape = k_cache_desc->shape();
if (k_shape.size() != 4 || v_shape.size() != 4) {
const size_t total_q_tokens = q_shape[0];
const size_t num_heads = q_shape[1];
const size_t head_size = q_shape[2];
const size_t num_blocks = k_shape[0];
const size_t page_block_size = k_shape[2];
const size_t num_kv_heads = k_shape[1];
if (head_size != 64 && head_size != 128) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (num_heads % num_kv_heads != 0) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (block_tables_shape.size() != 2) {
// v_cache must match the inferred K layout.
const auto v_shape = v_cache_desc->shape();
if (v_shape[0] != num_blocks || v_shape[3] != head_size) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (seq_lens_shape.size() != 1 || cum_seq_lens_q_shape.size() != 1) {
if (v_shape[1] != num_kv_heads || v_shape[2] != page_block_size) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (cum_seq_lens_q_shape[0] != seq_lens_shape[0] + 1) {
return INFINI_STATUS_BAD_PARAM;
if (v_cache_desc->shape()[0] != k_shape[0] || v_cache_desc->shape()[3] != k_shape[3]) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// Q shape: [total_tokens, heads, dim]
auto q_shape = q_desc->shape();
if (q_shape.size() != 3) {
if (out_desc->shape()[0] != q_shape[0] || out_desc->shape()[1] != q_shape[1] || out_desc->shape()[2] != q_shape[2]) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
size_t total_q_tokens = q_shape[0];
size_t num_heads = q_shape[1];
size_t head_size = q_shape[2];
if (head_size > 1024) {
const size_t num_seqs = total_kv_lens_desc->shape()[0];
if (cum_seqlens_q_desc->shape()[0] != num_seqs + 1) {
return INFINI_STATUS_BAD_PARAM;
}
size_t num_seqs = seq_lens_shape[0];
size_t num_kv_heads = k_shape[1];
size_t block_size = k_shape[2];
size_t max_num_blocks_per_seq = block_tables_shape[1];
ptrdiff_t q_stride = q_desc->stride(0);
ptrdiff_t q_head_stride = q_desc->stride(1);
ptrdiff_t kv_block_stride = k_cache_desc->stride(0);
ptrdiff_t kv_head_stride = k_cache_desc->stride(1);
ptrdiff_t o_stride = out_desc->stride(0);
const size_t max_num_blocks_per_seq = block_tables_desc->shape()[1];
// Strides (in elements)
const ptrdiff_t q_stride = q_desc->stride(0);
const ptrdiff_t q_head_stride = q_desc->stride(1);
const ptrdiff_t o_stride = out_desc->stride(0);
const ptrdiff_t o_head_stride = out_desc->stride(1);
const ptrdiff_t k_batch_stride = k_cache_desc->stride(0);
const ptrdiff_t k_row_stride = k_cache_desc->stride(2);
const ptrdiff_t k_head_stride = k_cache_desc->stride(1);
const ptrdiff_t v_batch_stride = v_cache_desc->stride(0);
const ptrdiff_t v_row_stride = v_cache_desc->stride(2);
const ptrdiff_t v_head_stride = v_cache_desc->stride(1);
const ptrdiff_t block_table_batch_stride = block_tables_desc->stride(0);
if (const char *dbg = std::getenv("INFINIOP_DEBUG_PREFILL_INFO")) {
static bool printed = false;
if (!printed && std::strcmp(dbg, "1") == 0) {
const auto bt_shape = block_tables_desc->shape();
std::fprintf(stderr,
"[infiniop][flash_attention_prefill][info] k_shape=[%zu,%zu,%zu,%zu] k_strides=[%td,%td,%td,%td] (row_stride=%td head_stride=%td)\n",
static_cast<size_t>(k_shape[0]), static_cast<size_t>(k_shape[1]),
static_cast<size_t>(k_shape[2]), static_cast<size_t>(k_shape[3]),
k_cache_desc->stride(0), k_cache_desc->stride(1), k_cache_desc->stride(2), k_cache_desc->stride(3),
k_row_stride, k_head_stride);
std::fprintf(stderr,
"[infiniop][flash_attention_prefill][info] block_tables shape=[%zu,%zu] strides=[%td,%td]\n",
static_cast<size_t>(bt_shape[0]), static_cast<size_t>(bt_shape[1]),
block_tables_desc->stride(0), block_tables_desc->stride(1));
printed = true;
}
}
    return utils::Result<PagedAttentionPrefillInfo>(PagedAttentionPrefillInfo{
        dtype,
        block_tables_dt,
        scale,
        num_seqs,
        total_q_tokens,
        num_heads,
        num_kv_heads,
        head_size,
        page_block_size,
        max_num_blocks_per_seq,
        num_blocks,
        q_stride,
        q_head_stride,
        k_batch_stride,
        k_row_stride,
        k_head_stride,
        v_batch_stride,
        v_row_stride,
        v_head_stride,
        o_stride,
        o_head_stride,
        block_table_batch_stride,
    });
}
};
} // namespace op::paged_attention_prefill
#endif
#include <cuda_fp16.h>
#include <float.h>
#include <math.h>
#include <stdint.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../cuda/kernel.cuh"
// #include "paged_attention_prefill_fa2.cuh"
#include "paged_attention_prefill_nvidia.cuh"
#include "../cuda/kernel_v2.cuh"
namespace op::paged_attention_prefill::nvidia {
namespace {
constexpr size_t ceilDiv(size_t a, size_t b) {
return (a + b - 1) / b;
}
inline const char *default_prefill_kernel(const PagedAttentionPrefillInfo &info) {
// Heuristic auto-dispatch (v0.4):
// - Prefer the pipelined + tile-wise softmax kernel on FA2-compatible block_size=256.
// - Keep a conservative fallback for other shapes / older GPUs (cp.async is a no-op below SM80).
//
// Users can always override via INFINIOP_FLASH_PREFILL_KERNEL.
if (info.page_block_size == 256 && (info.dtype == INFINI_DTYPE_F16 || info.dtype == INFINI_DTYPE_BF16)) {
if (info.head_size == 128) {
return "warpcta8pipe";
}
// For head_size=64 we keep the previous default until we have broader perf coverage.
}
return "warpcta8";
}
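// Example override (assumed shell invocation; the binary name here is hypothetical):
//   INFINIOP_FLASH_PREFILL_KERNEL=ref ./prefill_bench
// picks the reference kernel for correctness bisects; unknown names fall back to
// default_prefill_kernel() with a warning (see DISPATCH_KERNEL in Descriptor::calculate).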
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128Warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel).
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel<Tindex, Tdata, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64Warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel).
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel<Tindex, Tdata, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride,
q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 4 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 4, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 4 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 4, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 8, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8N128(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token, tile_n=128 for fewer K stages.
// Note: we keep K in shared memory but load V from global to stay within the per-block shared limit.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelKOnly<Tindex, Tdata, 128, 8, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 8, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8Pipe(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token, with cp.async pipelining.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined<Tindex, Tdata, 128, 8, 32, 2>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8Mma(
half *out,
const half *q,
const half *k_cache,
const half *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCta8MmaHd128Kernel<Tindex>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8Pipe(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 8 warps per CTA, one warp per query token, with cp.async pipelining.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined<Tindex, Tdata, 64, 8, 32, 2>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8PipeSplitKv(
float *partial_acc,
float *partial_m,
float *partial_l,
int num_splits,
size_t total_q_tokens,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride) {
// Encode (split_idx, m_block) into blockIdx.z to allow a single kernel launch:
// blockIdx.z in [0, num_splits * num_m_blocks).
const int num_m_blocks = static_cast<int>((total_q_tokens + 8 - 1) / 8);
const int bz = static_cast<int>(blockIdx.z);
const int split_idx = bz / num_m_blocks;
const int m_block = bz - split_idx * num_m_blocks;
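    // Worked example: total_q_tokens=100 gives num_m_blocks=13, so with num_splits=4,
    // blockIdx.z=30 decodes to split_idx = 30/13 = 2 and m_block = 30 - 2*13 = 4.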
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv<Tindex, Tdata, 128, 8, 32, 2>(
partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8PipeSplitKv(
float *partial_acc,
float *partial_m,
float *partial_l,
int num_splits,
size_t total_q_tokens,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride) {
const int num_m_blocks = static_cast<int>((total_q_tokens + 8 - 1) / 8);
const int bz = static_cast<int>(blockIdx.z);
const int split_idx = bz / num_m_blocks;
const int m_block = bz - split_idx * num_m_blocks;
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv<Tindex, Tdata, 64, 8, 32, 2>(
partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
}
template <typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128SplitKvCombine(
Tdata *out,
const float *partial_acc,
const float *partial_m,
const float *partial_l,
int num_splits,
size_t total_q_tokens,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel<Tdata, 128>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
}
template <typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64SplitKvCombine(
Tdata *out,
const float *partial_acc,
const float *partial_m,
const float *partial_l,
int num_splits,
size_t total_q_tokens,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel<Tdata, 64>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 16 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 128, 16, 64>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata>
INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_kv_heads,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride) {
// 16 warps per CTA, one warp per query token.
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel<Tindex, Tdata, 64, 16, 128>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
}
template <typename Tindex, typename Tdata, typename Tcompute>
infiniStatus_t launch_prefill_ref(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
    if (total_q_tokens == 0 || num_heads == 0) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
    const dim3 grid(static_cast<uint32_t>(total_q_tokens), static_cast<uint32_t>(num_heads), 1);
    const dim3 block(static_cast<uint32_t>(head_size), 1, 1);
if (head_size == 64) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel<Tindex, Tdata, Tcompute, 64>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride, q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, num_seqs);
return INFINI_STATUS_SUCCESS;
}
if (head_size == 128) {
op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel<Tindex, Tdata, Tcompute, 128>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride, q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, num_seqs);
return INFINI_STATUS_SUCCESS;
}
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warp(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
const dim3 block(32, 1, 1);
// Global-token launch:
// - dramatically reduces grid size vs the legacy (num_seqs * total_q_tokens) launch
// - matches PagedAttention varlen (cu_seqlens) mental model better
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(total_q_tokens),
1);
switch (head_size) {
case 64:
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel<Tindex, Tdata, 64>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, scale, max_num_blocks_per_seq,
page_block_size, block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel<Tindex, Tdata, 128>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, scale, max_num_blocks_per_seq,
page_block_size, block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
    constexpr int kWarps = 4;
    constexpr int kThreads = kWarps * 32;
    const dim3 block(kThreads);
    const dim3 grid(static_cast<uint32_t>(num_heads),
                    static_cast<uint32_t>(num_seqs),
                    static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
    case 128:
        PagedAttentionPrefillHd128WarpCta<Tindex, Tdata>
            <<<grid, block, 0, stream>>>(
                out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
                num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
                block_table_batch_stride,
                q_stride, q_head_stride,
                k_batch_stride, k_row_stride, k_head_stride,
                v_batch_stride, v_row_stride, v_head_stride,
                o_stride, o_head_stride);
        return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
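    // Grid sketch (illustrative numbers): num_heads=32, num_seqs=2, total_q_tokens=1000
    // gives grid = (32, 2, 125), since ceilDiv(1000, kWarps=8) = 125 token groups.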
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta8<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta8<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8pipe(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta8Pipe<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta8Pipe<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8mma(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
// Current WMMA kernel only supports fp16 + head_dim=128.
if constexpr (!std::is_same_v<Tdata, half>) {
return launch_prefill_warpcta8pipe<Tindex, Tdata>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, stream);
}
if (head_size != 128) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// Guardrail: the current WMMA-score kernel is correctness-first and can be extremely slow on long prompts.
// Allow power users to force it via INFINIOP_FLASH_PREFILL_MMA_FORCE=1.
const char *force_env = std::getenv("INFINIOP_FLASH_PREFILL_MMA_FORCE");
const bool force_mma = (force_env != nullptr) && (std::strcmp(force_env, "1") == 0);
const size_t seqlen_k_est = max_num_blocks_per_seq * page_block_size;
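    // Example: max_num_blocks_per_seq=32 with page_block_size=256 estimates seqlen_k=8192,
    // which exceeds the 4096 threshold below, so we fall back unless the force flag is set.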
if (!force_mma && seqlen_k_est > 4096) {
static bool warned = false;
if (!warned) {
std::fprintf(stderr,
"[infiniop][paged_attention_prefill] warpcta8mma is experimental and very slow for long seqlen_k (est=%zu). "
"Falling back to warpcta8pipe. Set INFINIOP_FLASH_PREFILL_MMA_FORCE=1 to override.\n",
seqlen_k_est);
warned = true;
}
return launch_prefill_warpcta8pipe<Tindex, Tdata>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, stream);
}
// WMMA requires SM70+. If not supported (or if we can't query), fall back to the pipelined SIMT kernel.
int device = 0;
cudaDeviceProp prop{};
if (cudaGetDevice(&device) == cudaSuccess && cudaGetDeviceProperties(&prop, device) == cudaSuccess) {
if (prop.major < 7) {
return launch_prefill_warpcta8pipe<Tindex, Tdata>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale,
max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride, stream);
}
}
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(16))));
PagedAttentionPrefillHd128WarpCta8Mma<Tindex>
<<<grid, block, 0, stream>>>(
static_cast<half *>(out),
static_cast<const half *>(q),
static_cast<const half *>(k_cache),
static_cast<const half *>(v_cache),
block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8pipe_splitkv(
float *partial_acc,
float *partial_m,
float *partial_l,
int num_splits,
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
constexpr int kMaxSplits = 8;
if (num_splits < 1) {
num_splits = 1;
}
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const size_t num_m_blocks = ceilDiv(total_q_tokens, static_cast<size_t>(kWarps));
// Single kernel launch with split_idx encoded in grid.z:
// blockIdx.z in [0, num_splits * num_m_blocks).
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(num_m_blocks * static_cast<size_t>(num_splits)));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta8PipeSplitKv<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
partial_acc, partial_m, partial_l, num_splits, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
break;
case 128:
PagedAttentionPrefillHd128WarpCta8PipeSplitKv<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
partial_acc, partial_m, partial_l, num_splits, total_q_tokens,
q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride);
break;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// Combine: one warp per (token, head).
const dim3 block2(32);
const dim3 grid2(static_cast<uint32_t>(num_heads), static_cast<uint32_t>(total_q_tokens), 1);
switch (head_size) {
case 64:
PagedAttentionPrefillHd64SplitKvCombine<Tdata>
<<<grid2, block2, 0, stream>>>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128SplitKvCombine<Tdata>
<<<grid2, block2, 0, stream>>>(
out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta8n128(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
constexpr int kWarps = 8;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
// Only meaningful for head_dim=128.
if (head_size != 128) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
PagedAttentionPrefillHd128WarpCta8N128<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
}
template <typename Tindex, typename Tdata>
infiniStatus_t launch_prefill_warpcta16(
Tdata *out,
const Tdata *q,
const Tdata *k_cache,
const Tdata *v_cache,
const Tindex *block_tables,
const int64_t *total_kv_lens,
const int64_t *cu_seqlens_q,
const float *alibi_slopes,
size_t num_heads,
size_t num_seqs,
size_t num_kv_heads,
size_t total_q_tokens,
size_t head_size,
float scale,
size_t max_num_blocks_per_seq,
size_t page_block_size,
ptrdiff_t block_table_batch_stride,
ptrdiff_t q_stride,
ptrdiff_t q_head_stride,
ptrdiff_t k_batch_stride,
ptrdiff_t k_row_stride,
ptrdiff_t k_head_stride,
ptrdiff_t v_batch_stride,
ptrdiff_t v_row_stride,
ptrdiff_t v_head_stride,
ptrdiff_t o_stride,
ptrdiff_t o_head_stride,
cudaStream_t stream) {
constexpr int kWarps = 16;
constexpr int kThreads = kWarps * 32;
const dim3 block(kThreads);
const dim3 grid(static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(num_seqs),
static_cast<uint32_t>(ceilDiv(total_q_tokens, static_cast<size_t>(kWarps))));
switch (head_size) {
case 64:
PagedAttentionPrefillHd64WarpCta16<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
case 128:
PagedAttentionPrefillHd128WarpCta16<Tindex, Tdata>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes,
num_kv_heads, scale, max_num_blocks_per_seq, page_block_size,
block_table_batch_stride,
q_stride, q_head_stride,
k_batch_stride, k_row_stride, k_head_stride,
v_batch_stride, v_row_stride, v_head_stride,
o_stride, o_head_stride);
return INFINI_STATUS_SUCCESS;
default:
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
} // namespace
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
@@ -68,22 +1249,87 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t cum_seq_lens_q_desc,
infiniopTensorDescriptor_t total_kv_lens_desc,
infiniopTensorDescriptor_t cum_seqlens_q_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto info = PagedAttentionPrefillInfo::create(
out_desc, q_desc, k_cache_desc, v_cache_desc,
block_tables_desc, seq_lens_desc,
cum_seq_lens_q_desc,
block_tables_desc, total_kv_lens_desc, cum_seqlens_q_desc,
alibi_slopes_desc, scale);
CHECK_RESULT(info);
// Optional split-kv prefill requires workspace for partial (m, l, acc).
// IMPORTANT: Unlike decode, prefill's total_q_tokens can be very large, so we must NOT reserve
// a huge workspace unless the user explicitly enables split-kv.
bool use_splitkv = false;
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
}
int num_splits = 1;
if (use_splitkv) {
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) {
const int v = std::atoi(env);
if (v > 0) {
num_splits = v;
}
} else {
num_splits = 4;
}
constexpr int kMaxSplits = 8;
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
}
const size_t n = info->total_q_tokens * info->num_heads;
const size_t splitkv_workspace_bytes = use_splitkv ? (static_cast<size_t>(num_splits) * n * (info->head_size + 2) * sizeof(float)) : 0;
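    // Split-kv workspace layout (floats), matching the pointer arithmetic in Descriptor::calculate():
    //   partial_acc : num_splits * total_q_tokens * num_heads * head_size
    //   partial_m   : num_splits * total_q_tokens * num_heads
    //   partial_l   : num_splits * total_q_tokens * num_heads
    // Example opt-in (assumed shell usage):
    //   INFINIOP_FLASH_PREFILL_SPLITKV=1 INFINIOP_FLASH_PREFILL_NUM_SPLITS=4 <app>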
// FA2-style kernel needs a workspace scratch for:
// - converting block_tables + total_kv_lens to int32
// - storing softmax LSE (only required to satisfy the upstream kernel contract)
// bool want_fa2 = false;
// if (const char *k_env = std::getenv("INFINIOP_FLASH_PREFILL_KERNEL")) {
// want_fa2 = (std::strcmp(k_env, "fa2") == 0);
// }
// bool fa2_materialize_kv = false;
// if (const char *env = std::getenv("INFINIOP_FA2_MATERIALIZE_PAGED_KV")) {
// fa2_materialize_kv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
// }
// size_t fa2_workspace_bytes = 0;
// // FA2 prefill supports both fp16 and bf16 inputs (head_dim=128, block_size=256).
// // Workspace sizing is identical since both are 16-bit element types.
// if (want_fa2 && (info->dtype == INFINI_DTYPE_F16 || info->dtype == INFINI_DTYPE_BF16) &&
// info->head_size == 128 && info->page_block_size == 256) {
// const size_t bt_bytes = info->num_seqs * info->max_num_blocks_per_seq * sizeof(int);
// const size_t len_bytes = info->num_seqs * sizeof(int);
// const size_t cuq_bytes = (info->num_seqs + 1) * sizeof(int);
// const size_t cuk_bytes = (info->num_seqs + 1) * sizeof(int);
// const size_t lse_bytes = info->num_heads * info->total_q_tokens * sizeof(float);
// // Add a small alignment slack since we sub-allocate with alignment.
// fa2_workspace_bytes = bt_bytes + len_bytes + cuq_bytes + cuk_bytes + lse_bytes + 64;
// // Optional: materialize paged KV into the FA2-friendly physical layout
// // [num_blocks, page_block_size, kv_heads, head_dim] (token-major) to avoid
// // extremely strided loads when the framework stores KV as
// // [num_blocks, kv_heads, page_block_size, head_dim] (head-major).
// if (fa2_materialize_kv) {
// // Materialize per-seq contiguous KV in *sequence order*:
// // [num_seqs, max_num_blocks_per_seq * page_block_size, kv_heads, head_dim].
// const size_t kv_elems =
// info->num_seqs * info->max_num_blocks_per_seq * info->page_block_size * info->num_kv_heads * info->head_size;
// const size_t kv_bytes = kv_elems * sizeof(uint16_t); // 16-bit (fp16/bf16)
// // K + V + alignment slack
// fa2_workspace_bytes += 2 * kv_bytes + 64;
// }
// }
const size_t workspace_bytes = splitkv_workspace_bytes;
// const size_t workspace_bytes = splitkv_workspace_bytes + fa2_workspace_bytes;
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
info.take(), workspace_bytes, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
@@ -92,35 +1338,379 @@ infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables,
const void *seq_lens,
const void *cum_seq_lens_q,
const void *total_kv_lens,
const void *cum_seqlens_q,
const void *alibi_slopes,
void *stream_) const {
    auto stream = static_cast<cudaStream_t>(stream_);
const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast<const float *>(alibi_slopes);
const auto *total_kv_lens_i64 = static_cast<const int64_t *>(total_kv_lens);
const auto *cu_seqlens_q_i64 = static_cast<const int64_t *>(cum_seqlens_q);
bool use_splitkv = false;
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) {
use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0);
}
int num_splits = 1;
if (use_splitkv) {
if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) {
const int v = std::atoi(env);
if (v > 0) {
num_splits = v;
}
} else {
// Conservative default; users can override.
num_splits = 4;
}
constexpr int kMaxSplits = 8;
if (num_splits > kMaxSplits) {
num_splits = kMaxSplits;
}
const size_t n = _info.total_q_tokens * _info.num_heads;
const size_t required = static_cast<size_t>(num_splits) * n * (_info.head_size + 2) * sizeof(float);
if (workspace_size < required) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
}
if (use_splitkv) {
const size_t n = _info.total_q_tokens * _info.num_heads;
float *partial_acc = static_cast<float *>(workspace);
float *partial_m = partial_acc + static_cast<size_t>(num_splits) * n * _info.head_size;
float *partial_l = partial_m + static_cast<size_t>(num_splits) * n;
// Dispatch by (Tdata, Tindex). total_kv_lens + cu_seqlens_q are currently always int64.
#define DISPATCH_SPLITKV(Tindex, Tdata, BT_PTR) \
return launch_prefill_warpcta8pipe_splitkv<Tindex, Tdata>( \
partial_acc, partial_m, partial_l, num_splits, \
static_cast<Tdata *>(out), \
static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), \
static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(BT_PTR), \
total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream)
        if (_info.dtype == INFINI_DTYPE_F16) {
if (_info.index_dtype == INFINI_DTYPE_I64) {
DISPATCH_SPLITKV(int64_t, half, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_I32) {
DISPATCH_SPLITKV(int32_t, half, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_U32) {
DISPATCH_SPLITKV(uint32_t, half, block_tables);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (_info.dtype == INFINI_DTYPE_BF16) {
if (_info.index_dtype == INFINI_DTYPE_I64) {
DISPATCH_SPLITKV(int64_t, __nv_bfloat16, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_I32) {
DISPATCH_SPLITKV(int32_t, __nv_bfloat16, block_tables);
}
if (_info.index_dtype == INFINI_DTYPE_U32) {
DISPATCH_SPLITKV(uint32_t, __nv_bfloat16, block_tables);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
#undef DISPATCH_SPLITKV
}
// Default to the fastest validated kernel for supported shapes.
// "ref" is still available for debugging/correctness bisecting.
#define DISPATCH_KERNEL(Tindex, Tdata, Tcompute) \
do { \
const char *k_env = std::getenv("INFINIOP_FLASH_PREFILL_KERNEL"); \
const char *k = (k_env == nullptr) ? default_prefill_kernel(_info) : k_env; \
if (k_env != nullptr) { \
const bool known = (std::strcmp(k, "warp") == 0) || (std::strcmp(k, "warpcta") == 0) || (std::strcmp(k, "warpcta8") == 0) || (std::strcmp(k, "warpcta8pipe") == 0) || (std::strcmp(k, "warpcta8mma") == 0) || (std::strcmp(k, "warpcta8n128") == 0) || (std::strcmp(k, "warpcta16") == 0) || (std::strcmp(k, "ref") == 0); \
if (!known) { \
const char *fallback = default_prefill_kernel(_info); \
std::fprintf(stderr, \
"[infiniop][paged_attention_prefill] WARNING: unknown kernel '%s', falling back to '%s'\n", \
k, fallback); \
k = fallback; \
} \
} \
const char *dbg = std::getenv("INFINIOP_DEBUG_PREFILL_DISPATCH"); \
static bool printed_dispatch = false; \
if (!printed_dispatch && dbg != nullptr && std::strcmp(dbg, "1") == 0) { \
std::fprintf(stderr, \
"[infiniop][paged_attention_prefill] kernel=%s (override=%s head_size=%zu block=%zu dtype=%zu)\n", \
k, \
(k_env == nullptr ? "auto" : "env"), \
static_cast<size_t>(_info.head_size), \
static_cast<size_t>(_info.page_block_size), \
static_cast<size_t>(_info.dtype)); \
printed_dispatch = true; \
} \
if (std::strcmp(k, "warp") == 0) { \
return launch_prefill_warp<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta") == 0) { \
return launch_prefill<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta8") == 0) { \
return launch_prefill_warpcta8<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta8pipe") == 0) { \
return launch_prefill_warpcta8pipe<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if constexpr (std::is_same_v<Tdata, half>) { \
if (std::strcmp(k, "warpcta8mma") == 0) { \
return launch_prefill_warpcta8mma<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
} \
if (std::strcmp(k, "warpcta8n128") == 0) { \
return launch_prefill_warpcta8n128<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "warpcta16") == 0) { \
return launch_prefill_warpcta16<Tindex, Tdata>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
if (std::strcmp(k, "ref") == 0) { \
return launch_prefill_ref<Tindex, Tdata, Tcompute>( \
static_cast<Tdata *>(out), static_cast<const Tdata *>(q), \
static_cast<const Tdata *>(k_cache), static_cast<const Tdata *>(v_cache), \
static_cast<const Tindex *>(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \
_info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \
_info.block_table_batch_stride, \
_info.q_stride, _info.q_head_stride, \
_info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \
_info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \
_info.o_stride, _info.o_head_stride, stream); \
} \
return INFINI_STATUS_BAD_PARAM; \
} while (false)
#define DISPATCH_INDEX(Tindex) \
do { \
if (_info.dtype == INFINI_DTYPE_F16) { \
DISPATCH_KERNEL(Tindex, half, float); \
} \
if (_info.dtype == INFINI_DTYPE_BF16) { \
DISPATCH_KERNEL(Tindex, __nv_bfloat16, float); \
} \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
} while (false)
if (_info.index_dtype == INFINI_DTYPE_I64) {
DISPATCH_INDEX(int64_t);
} else if (_info.index_dtype == INFINI_DTYPE_I32) {
DISPATCH_INDEX(int32_t);
} else if (_info.index_dtype == INFINI_DTYPE_U32) {
DISPATCH_INDEX(uint32_t);
}
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::paged_attention_prefill::nvidia
// #include <cuda_fp16.h>
// #include <float.h>
// #include <math.h>
// #include <stdint.h>
// #include "../../../devices/nvidia/nvidia_common.cuh"
// #include "../../../devices/nvidia/nvidia_kernel_common.cuh"
// #include "../cuda/kernel.cuh"
// #include "paged_attention_prefill_nvidia.cuh"
// template <typename Tdata, typename Tcompute>
// infiniStatus_t launchPagedAttentionPrefill(
// Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
// const int64_t *block_tables,
// const int64_t *seq_lens,
// const int64_t *cum_seq_lens_q,
// const float *alibi_slopes,
// const size_t num_heads,
// const size_t num_seqs,
// const size_t num_kv_heads,
// const float scale,
// const size_t max_num_blocks_per_seq,
// const size_t block_size,
// const size_t total_q_tokens,
// const size_t head_size,
// const ptrdiff_t kv_block_stride,
// const ptrdiff_t kv_head_stride,
// const ptrdiff_t q_stride,
// const ptrdiff_t q_head_stride,
// cudaStream_t stream) {
// if (total_q_tokens == 0 || num_heads == 0) {
// return INFINI_STATUS_BAD_TENSOR_SHAPE;
// }
// dim3 grid(total_q_tokens, num_heads);
// dim3 block(head_size);
// op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tdata, Tcompute>
// <<<grid, block, 0, stream>>>(
// out, q, k_cache, v_cache,
// block_tables, seq_lens, cum_seq_lens_q, alibi_slopes,
// num_heads, num_kv_heads, scale,
// max_num_blocks_per_seq, block_size,
// kv_block_stride, kv_head_stride,
// q_stride, q_head_stride,
// head_size,
// num_seqs);
// return INFINI_STATUS_SUCCESS;
// }
// namespace op::paged_attention_prefill::nvidia {
// struct Descriptor::Opaque {
// std::shared_ptr<device::nvidia::Handle::Internal> internal;
// };
// Descriptor::~Descriptor() {
// delete _opaque;
// }
// infiniStatus_t Descriptor::create(
// infiniopHandle_t handle,
// Descriptor **desc_ptr,
// infiniopTensorDescriptor_t out_desc,
// infiniopTensorDescriptor_t q_desc,
// infiniopTensorDescriptor_t k_cache_desc,
// infiniopTensorDescriptor_t v_cache_desc,
// infiniopTensorDescriptor_t block_tables_desc,
// infiniopTensorDescriptor_t seq_lens_desc,
// infiniopTensorDescriptor_t cum_seq_lens_q_desc,
// const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
// float scale) {
// auto info = PagedAttentionPrefillInfo::create(
// out_desc, q_desc, k_cache_desc, v_cache_desc,
// block_tables_desc, seq_lens_desc,
// cum_seq_lens_q_desc,
// alibi_slopes_desc, scale);
// CHECK_RESULT(info);
// *desc_ptr = new Descriptor(
// new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
// info.take(), 0, handle->device, handle->device_id);
// return INFINI_STATUS_SUCCESS;
// }
// infiniStatus_t Descriptor::calculate(
// void *workspace, size_t workspace_size,
// void *out, const void *q, const void *k_cache, const void *v_cache,
// const void *block_tables,
// const void *seq_lens,
// const void *cum_seq_lens_q,
// const void *alibi_slopes,
// void *stream_) const {
// cudaStream_t stream = (cudaStream_t)stream_;
// #define LAUNCH_KERNEL(Tdata, Tcompute) \
// launchPagedAttentionPrefill<Tdata, Tcompute>( \
// (Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \
// (const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q, \
// (const float *)alibi_slopes, \
// _info.num_heads, _info.num_seqs, _info.num_kv_heads, \
// _info.scale, _info.max_num_blocks_per_seq, \
// _info.block_size, _info.total_q_tokens, \
// _info.head_size, \
// _info.kv_block_stride, _info.kv_head_stride, \
// _info.q_stride, _info.q_head_stride, \
// stream)
// if (_info.dtype == INFINI_DTYPE_F16) {
// return LAUNCH_KERNEL(half, float);
// } else if (_info.dtype == INFINI_DTYPE_BF16) {
// return LAUNCH_KERNEL(__nv_bfloat16, float);
// } else if (_info.dtype == INFINI_DTYPE_F32) {
// return LAUNCH_KERNEL(float, float);
// }
// return INFINI_STATUS_BAD_TENSOR_DTYPE;
// }
// } // namespace op::paged_attention_prefill::nvidia
......
@@ -100,13 +100,12 @@ _TEST_CASES_ = [
]
# Data types for testing
-_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32]
+_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16]
# Tolerance map for different data types
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2},
    InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2},
-    InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5},
}
# Global flags for controlling test behavior
......
......
@@ -32,10 +32,9 @@ _TEST_CASES = [
    (16, 128, 128, 128, 8, 16, 4),
]
-_TENSOR_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]
+_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16]
_TOLERANCE_MAP = {
-    InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5},
    InfiniDtype.F16: {"atol": 1e-2, "rtol": 1e-2},
    InfiniDtype.BF16: {"atol": 2e-2, "rtol": 2e-2},
}
......