Commit 1ba0bcfa authored by zhushuang's avatar zhushuang
Browse files

issue/848 - feat: add paged attention prefill for nvidia gpu with test pass

parent 298feac2
......@@ -15,6 +15,9 @@
#include "infiniop/ops/lp_norm.h"
#include "infiniop/ops/mul.h"
#include "infiniop/ops/ones.h"
#include "infiniop/ops/paged_attention.h"
#include "infiniop/ops/paged_attention_prefill.h"
#include "infiniop/ops/paged_caching.h"
#include "infiniop/ops/random_sample.h"
#include "infiniop/ops/rearrange.h"
#include "infiniop/ops/relu.h"
......@@ -31,7 +34,5 @@
#include "infiniop/ops/topksoftmax.h"
#include "infiniop/ops/zeros.h"
#include "infiniop/tensor_descriptor.h"
#include "infiniop/ops/paged_attention.h"
#include "infiniop/ops/paged_caching.h"
#endif // __INFINIOP_API_H__
#ifndef __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__
#define __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__
#include "../operator_descriptor.h"
// Define an opaque handle for the Paged Attention Prefill descriptor.
typedef struct InfiniopDescriptor *infiniopPagedAttentionPrefillDescriptor_t;
/**
 * @brief Creates a descriptor for the Paged Attention Prefill operation.
 *
 * Constraints enforced by the validation code in this change:
 *  - q, out, k_cache and v_cache must share one dtype (F16, BF16 or F32).
 *  - offset and seq_lens must be I64.
 *  - head_size (last dimension of q) must be 128.
 *
 * @param handle The handle to the InfiniOP library context.
 * @param desc_ptr A pointer to store the created descriptor.
 * @param out_desc Descriptor for the output tensor.
 * @param q_desc Descriptor for the query tensor (packed/flattened),
 *        laid out as [total_tokens, num_heads, head_size].
 * @param k_cache_desc Descriptor for the global physical key cache.
 *        Dimension 1 is read as the number of KV heads.
 * @param v_cache_desc Descriptor for the global physical value cache.
 *        Dimension 2 is read as the block size.
 * @param block_tables_desc Descriptor for the block tables mapping logical to physical blocks.
 *        Dimension 1 is read as the maximum number of blocks per sequence.
 * @param cache_lens_desc Descriptor for the total sequence lengths (history + current).
 * @param seq_lens_desc Descriptor for the current prefill sequence lengths.
 *        Dimension 0 is read as the number of sequences.
 * @param offset_desc Descriptor for the start position of each sequence in the packed Q tensor.
 *        The CUDA kernel reads offset[i + 1] for sequence i, so num_seqs + 1 entries are expected.
 * @param alibi_slopes_desc Optional descriptor for the ALiBi slopes tensor. Can be NULL.
 *        NOTE: the validation code in this change rejects a non-NULL descriptor with
 *        INFINI_STATUS_BAD_PARAM (ALiBi is not supported yet).
 * @param scale The attention scaling factor.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
    infiniopHandle_t handle,
    infiniopPagedAttentionPrefillDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t block_tables_desc,
    infiniopTensorDescriptor_t cache_lens_desc,
    infiniopTensorDescriptor_t seq_lens_desc,
    infiniopTensorDescriptor_t offset_desc,
    infiniopTensorDescriptor_t alibi_slopes_desc,
    float scale);
/**
 * @brief Retrieves the workspace size required for the Paged Attention Prefill operation.
 * @param desc The Paged Attention Prefill descriptor to query.
 * @param size Output location that receives the workspace size stored in the descriptor.
 * @return infiniStatus_t INFINI_STATUS_SUCCESS on success, or
 *         INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED when no backend handles the descriptor's device.
 */
__C __export infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
    infiniopPagedAttentionPrefillDescriptor_t desc, size_t *size);
/**
 * @brief Executes the Paged Attention Prefill operation.
 *
 * The NVIDIA backend in this change interprets block_tables, cache_lens,
 * seq_lens and offset as arrays of int64_t, and alibi_slopes as an array of float.
 *
 * @param desc The Paged Attention Prefill descriptor.
 * @param workspace Pointer to the workspace memory.
 * @param workspace_size The size of the workspace.
 * @param out Pointer to the output tensor data.
 * @param q Pointer to the query tensor data (packed).
 * @param k_cache Pointer to the global key cache data.
 * @param v_cache Pointer to the global value cache data.
 * @param block_tables Pointer to the block tables data.
 * @param cache_lens Pointer to the total sequence lengths data.
 * @param seq_lens Pointer to the current prefill sequence lengths data.
 * @param offset Pointer to the sequence start offsets data.
 * @param alibi_slopes Pointer to the ALiBi slopes data. Can be NULL.
 * @param stream The CUDA/device stream for the operation.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopPagedAttentionPrefill(
    infiniopPagedAttentionPrefillDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k_cache,
    const void *v_cache,
    const void *block_tables,
    const void *cache_lens,
    const void *seq_lens,
    const void *offset,
    const void *alibi_slopes,
    void *stream);
/**
 * @brief Destroys a Paged Attention Prefill descriptor.
 * @param desc The descriptor to destroy; it must not be used afterwards.
 * @return infiniStatus_t INFINI_STATUS_SUCCESS on success, or
 *         INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED when no backend handles the descriptor's device.
 */
__C __export infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
    infiniopPagedAttentionPrefillDescriptor_t desc);
#endif // __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__
#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
#define __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
namespace op::paged_attention_prefill::cuda {
// Helper: binary-searches the offset table for the sequence that owns token_idx,
// i.e. the index s with offset[s] <= token_idx < offset[s + 1].
// offset[mid + 1] is read for every probed mid, so the table must hold
// num_seqs + 1 entries. Returns 0 when no interval contains token_idx.
__device__ __forceinline__ int find_seq_id(int token_idx, const int64_t *offset, int num_seqs) {
    int first = 0;
    int last = num_seqs - 1;
    while (first <= last) {
        const int mid = (first + last) >> 1;
        if (token_idx < offset[mid]) {
            // Token lies before this sequence: continue in the left half.
            last = mid - 1;
            continue;
        }
        if (token_idx >= offset[mid + 1]) {
            // Token lies at or after the next sequence start: continue in the right half.
            first = mid + 1;
            continue;
        }
        return mid;
    }
    return 0;
}
// Computes one attention logit: scale * dot(q_vec, k_vec), plus
// alibi_slope * rel_pos when an ALiBi slope is active (rel_pos <= 0 is the
// key position relative to the query position).
template <typename Tdata, typename Tcompute>
__device__ __forceinline__ Tcompute pagedPrefillScore(
    const Tdata *q_vec, const Tdata *k_vec, const size_t head_size,
    const float scale, const float alibi_slope, const int64_t rel_pos) {
    Tcompute score = 0.0f;
    for (size_t d = 0; d < head_size; ++d) {
        score += static_cast<Tcompute>(q_vec[d]) * static_cast<Tcompute>(k_vec[d]);
    }
    score *= static_cast<Tcompute>(scale);
    if (alibi_slope != 0.0f) {
        score += alibi_slope * static_cast<float>(rel_pos);
    }
    return score;
}

// Causal paged-attention prefill.
//
// Launch layout: grid = (total_q_tokens, num_heads); one thread per output
// element of the head, i.e. blockDim.x >= head_size (extra threads exit).
// No shared memory is used.
//
// Layout assumptions made by the indexing below:
//  - q_ and out_ are packed as [total_q_tokens, num_heads, head_size].
//  - Inside a (block, kv_head) slice of the caches, consecutive token rows are
//    head_size elements apart; V is addressed with the same strides as K.
//  - offset_ holds num_seqs + 1 entries (see find_seq_id).
//  - num_kv_heads is non-zero and divides num_heads; block_size is non-zero.
//
// Each query token attends to cache positions 0..history_len + q_token_idx.
// Softmax is computed in two passes (max, then exp-sum fused with the weighted
// V accumulation) and normalised once at the end.
template <typename Tdata, typename Tcompute>
__global__ void pagedAttentionPrefillKernel(
    Tdata *out_, const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_,
    const int64_t *block_tables_, const int64_t *cache_lens_, const int64_t *seq_lens_,
    const float *alibi_slopes_,
    const size_t num_heads, const size_t num_kv_heads, const float scale,
    const size_t max_num_blocks_per_seq, const size_t block_size,
    const ptrdiff_t kv_block_stride, const ptrdiff_t kv_head_stride,
    const size_t head_size,
    const int64_t *offset_,
    const size_t num_seqs) {
    // --- 2D grid coordinates ---
    const int global_token_idx = blockIdx.x; // index into the flattened (packed) token axis
    const int head_idx = blockIdx.y;         // query head
    const int dim_idx = threadIdx.x;         // element inside the head
    if (dim_idx >= head_size) {
        return;
    }

    // --- Locate the owning sequence via binary search over the offsets ---
    const int seq_idx = find_seq_id(global_token_idx, offset_, num_seqs);
    // Number of tokens this sequence contributes to the current prefill.
    const int64_t cur_new_len = seq_lens_[seq_idx];
    // Position of this token relative to the start of its sequence's new tokens.
    const int64_t q_token_idx = global_token_idx - offset_[seq_idx];

    const size_t row = (static_cast<size_t>(global_token_idx) * num_heads + head_idx) * head_size;
    const Tdata *q_vec = q_ + row;
    Tdata *out_ptr = out_ + row;

    // --- KV cache bookkeeping ---
    const int64_t history_len = cache_lens_[seq_idx] - cur_new_len;
    const int64_t causal_limit = history_len + q_token_idx; // last visible cache position
    const size_t num_queries_per_kv = num_heads / num_kv_heads;
    const size_t kv_head_idx = head_idx / num_queries_per_kv;
    const ptrdiff_t kv_head_base = static_cast<ptrdiff_t>(kv_head_idx) * kv_head_stride;
    const int64_t *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;
    const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];

    // Pass 1: maximum logit, for a numerically stable softmax.
    Tcompute max_score = -FLT_MAX;
    for (int64_t t = 0; t <= causal_limit; ++t) {
        const size_t pos = static_cast<size_t>(t);
        const ptrdiff_t kv_row = block_table[pos / block_size] * kv_block_stride
                               + kv_head_base
                               + static_cast<ptrdiff_t>((pos % block_size) * head_size);
        const Tcompute score = pagedPrefillScore<Tdata, Tcompute>(
            q_vec, k_cache_ + kv_row, head_size, scale, alibi_slope, t - causal_limit);
        if (score > max_score) {
            max_score = score;
        }
    }

    // Pass 2: accumulate the softmax denominator and the un-normalised weighted
    // sum of V together, so the K cache is walked only once more.
    Tcompute sum_exp = 0.0f;
    Tcompute acc = 0.0f;
    for (int64_t t = 0; t <= causal_limit; ++t) {
        const size_t pos = static_cast<size_t>(t);
        const ptrdiff_t kv_row = block_table[pos / block_size] * kv_block_stride
                               + kv_head_base
                               + static_cast<ptrdiff_t>((pos % block_size) * head_size);
        const Tcompute score = pagedPrefillScore<Tdata, Tcompute>(
            q_vec, k_cache_ + kv_row, head_size, scale, alibi_slope, t - causal_limit);
        const Tcompute weight = expf(static_cast<float>(score - max_score));
        sum_exp += weight;
        acc += weight * static_cast<Tcompute>(v_cache_[kv_row + dim_idx]);
    }

    // The maximum logit contributes exp(0) = 1, so sum_exp >= 1 whenever the
    // causal window is non-empty; an empty window produces a zero output.
    const Tcompute result = (sum_exp > static_cast<Tcompute>(0))
                              ? acc / sum_exp
                              : static_cast<Tcompute>(0);
    out_ptr[dim_idx] = static_cast<Tdata>(result);
}
} // namespace op::paged_attention_prefill::cuda
#endif
#ifndef __PAGED_ATTENTION_PREFILL_INFO_H__
#define __PAGED_ATTENTION_PREFILL_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <iostream>
#include <optional>
#include <vector>
namespace op::paged_attention_prefill {
// Validated, backend-independent description of one paged-attention prefill
// problem: dtype, problem sizes and the strides the kernels need.
// Instances are produced only through create().
class PagedAttentionPrefillInfo {
    PagedAttentionPrefillInfo() = default;

public:
    infiniDtype_t dtype;           // element type shared by q, out and both caches
    float scale;                   // attention scaling factor
    size_t num_seqs;               // number of sequences in the batch
    size_t num_heads;              // query heads
    size_t num_kv_heads;           // key/value heads
    size_t head_size;              // elements per head
    size_t block_size;             // tokens per cache block
    size_t max_num_blocks_per_seq; // row length of the block table
    size_t total_q_tokens;         // packed query tokens over all sequences
    ptrdiff_t q_stride;            // stride of q along the token axis
    ptrdiff_t kv_block_stride;     // stride of the key cache along the block axis
    ptrdiff_t kv_head_stride;      // stride of the key cache along the head axis
    ptrdiff_t o_stride;            // stride of out along the token axis

    // Validates the tensor descriptors and extracts the launch parameters.
    //
    // Returns INFINI_STATUS_BAD_TENSOR_DTYPE for unsupported or mismatching
    // dtypes, INFINI_STATUS_BAD_TENSOR_SHAPE for inconsistent ranks or sizes,
    // and INFINI_STATUS_BAD_PARAM for ALiBi slopes (not supported yet) or for
    // memory layouts the kernel cannot address.
    static utils::Result<PagedAttentionPrefillInfo> create(
        infiniopTensorDescriptor_t out_desc,
        infiniopTensorDescriptor_t q_desc,
        infiniopTensorDescriptor_t k_cache_desc,
        infiniopTensorDescriptor_t v_cache_desc,
        infiniopTensorDescriptor_t block_tables_desc,
        infiniopTensorDescriptor_t cache_lens_desc,
        infiniopTensorDescriptor_t seq_lens_desc,
        infiniopTensorDescriptor_t offset_desc,
        const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
        float scale) {
        auto dtype = q_desc->dtype();
        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
        if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        // Every index tensor is reinterpreted as int64_t by the backend, so all
        // four of them must really be I64.
        if (offset_desc->dtype() != INFINI_DTYPE_I64 || seq_lens_desc->dtype() != INFINI_DTYPE_I64
            || cache_lens_desc->dtype() != INFINI_DTYPE_I64 || block_tables_desc->dtype() != INFINI_DTYPE_I64) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        if (alibi_slopes_desc.has_value() && alibi_slopes_desc.value() != nullptr) {
            std::cerr << "[Error] PagedAttentionPrefill: ALiBi slopes are not supported yet." << std::endl;
            return INFINI_STATUS_BAD_PARAM;
        }

        // Q shape: [total_tokens, heads, dim] (3D); out must match it exactly.
        auto q_shape = q_desc->shape();
        if (q_shape.size() != 3 || out_desc->shape() != q_shape) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        size_t total_q_tokens = q_shape[0];
        size_t num_heads = q_shape[1];
        size_t head_size = q_shape[2];
        if (head_size != 128) {
            std::cerr << "[Error] PagedAttentionPrefill head_size = 128 supported, got " << head_size << std::endl;
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        // Check ranks before any dimension is indexed.
        auto k_cache_shape = k_cache_desc->shape();
        auto block_tables_shape = block_tables_desc->shape();
        auto seq_lens_shape = seq_lens_desc->shape();
        auto cache_lens_shape = cache_lens_desc->shape();
        auto offset_shape = offset_desc->shape();
        if (k_cache_shape.size() != 4 || v_cache_desc->shape() != k_cache_shape
            || block_tables_shape.size() != 2 || seq_lens_shape.size() != 1
            || cache_lens_shape.size() != 1 || offset_shape.size() != 1) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        // num_seqs comes from seq_lens; the other per-sequence tensors must agree.
        // offset carries one extra trailing entry (the end of the last sequence).
        size_t num_seqs = seq_lens_shape[0];
        if (cache_lens_shape[0] != num_seqs || offset_shape[0] != num_seqs + 1
            || block_tables_shape[0] < num_seqs) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        size_t num_kv_heads = k_cache_shape[1];
        size_t block_size = k_cache_shape[2];
        size_t max_num_blocks_per_seq = block_tables_shape[1];
        // The kernel divides by block_size and by num_heads / num_kv_heads.
        if (block_size == 0 || num_kv_heads == 0 || num_heads % num_kv_heads != 0
            || k_cache_shape[3] != head_size) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        // The kernel addresses q/out as densely packed tensors. Strides of
        // size-1 dimensions are irrelevant and therefore not checked.
        auto is_packed = [](infiniopTensorDescriptor_t desc) {
            const auto &shape = desc->shape();
            ptrdiff_t expected = 1;
            for (size_t d = shape.size(); d-- > 0;) {
                if (shape[d] != 1 && desc->stride(d) != expected) {
                    return false;
                }
                expected *= static_cast<ptrdiff_t>(shape[d]);
            }
            return true;
        };
        if (!is_packed(q_desc) || !is_packed(out_desc)) {
            return INFINI_STATUS_BAD_PARAM;
        }

        // Inside one (block, head) slice the kernel steps head_size elements
        // per token and 1 element per channel, and it addresses V with the
        // key cache strides.
        if ((block_size != 1 && k_cache_desc->stride(2) != static_cast<ptrdiff_t>(head_size))
            || k_cache_desc->stride(3) != 1) {
            return INFINI_STATUS_BAD_PARAM;
        }
        for (size_t d = 0; d < 4; ++d) {
            if (k_cache_shape[d] != 1 && v_cache_desc->stride(d) != k_cache_desc->stride(d)) {
                return INFINI_STATUS_BAD_PARAM;
            }
        }

        // Strides forwarded to the backend.
        ptrdiff_t q_stride = q_desc->stride(0);
        ptrdiff_t kv_block_stride = k_cache_desc->stride(0);
        ptrdiff_t kv_head_stride = k_cache_desc->stride(1);
        ptrdiff_t o_stride = out_desc->stride(0);

        return utils::Result<PagedAttentionPrefillInfo>(PagedAttentionPrefillInfo{
            dtype,
            scale,
            num_seqs,
            num_heads,
            num_kv_heads,
            head_size,
            block_size,
            max_num_blocks_per_seq,
            total_q_tokens,
            q_stride,
            kv_block_stride,
            kv_head_stride,
            o_stride});
    }
};
} // namespace op::paged_attention_prefill
#endif
#include <cuda_fp16.h>
#include <float.h>
#include <math.h>
#include <stdint.h>
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../cuda/kernel.cuh"
#include "paged_attention_prefill_nvidia.cuh"
// ==============================================================================
// Host wrapper to launch the global kernel
// ==============================================================================
// Validates the launch parameters and enqueues pagedAttentionPrefillKernel on
// `stream` with grid = (total_q_tokens, num_heads) and head_size threads per
// block. Only launch-time errors are reported; the call does not synchronise.
//
// Returns INFINI_STATUS_BAD_TENSOR_SHAPE for sizes the kernel cannot handle and
// INFINI_STATUS_INTERNAL_ERROR when the launch itself is rejected.
template <typename Tdata, typename Tcompute>
infiniStatus_t launchPagedAttentionPrefill(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const int64_t *block_tables, const int64_t *cache_lens, const int64_t *seq_lens,
    const int64_t *offset,
    const float *alibi_slopes,
    const size_t num_heads,
    const size_t num_seqs,
    const size_t num_kv_heads,
    const float scale,
    const size_t max_num_blocks_per_seq,
    const size_t block_size,
    const size_t total_q_tokens,
    const ptrdiff_t q_stride,
    const ptrdiff_t kv_block_stride,
    const ptrdiff_t kv_head_stride,
    const ptrdiff_t o_stride,
    const size_t head_size,
    cudaStream_t stream) {
    // The kernel indexes q/out as packed tensors; these strides are accepted
    // for interface stability but are not forwarded.
    (void)q_stride;
    (void)o_stride;

    if (total_q_tokens == 0 || num_heads == 0) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
    // The kernel computes head_idx / (num_heads / num_kv_heads) and
    // t / block_size: a zero divisor is undefined behaviour on the device, and
    // a non-divisible head ratio maps the last query heads past the KV heads.
    if (num_kv_heads == 0 || num_heads % num_kv_heads != 0 || block_size == 0 || head_size == 0) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }

    // 2D grid: X covers every packed token, Y covers every query head.
    dim3 grid(static_cast<unsigned int>(total_q_tokens), static_cast<unsigned int>(num_heads));
    dim3 block(static_cast<unsigned int>(head_size));
    op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tdata, Tcompute>
        <<<grid, block, 0, stream>>>(
            out, q, k_cache, v_cache,
            block_tables, cache_lens, seq_lens, alibi_slopes,
            num_heads, num_kv_heads, scale,
            max_num_blocks_per_seq, block_size,
            kv_block_stride, kv_head_stride,
            head_size,
            offset, num_seqs);

    // Launch-configuration errors are only visible through cudaGetLastError().
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        std::cerr << "CUDA Kernel Launch Failed: " << cudaGetErrorString(err) << std::endl;
        return INFINI_STATUS_INTERNAL_ERROR;
    }
    return INFINI_STATUS_SUCCESS;
}
namespace op::paged_attention_prefill::nvidia {
// Backend-private state of the NVIDIA descriptor: a shared reference to the
// device handle's internal object.
struct Descriptor::Opaque {
    std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
// Frees the Opaque allocated in Descriptor::create().
Descriptor::~Descriptor() {
    delete _opaque;
}
// Builds an NVIDIA descriptor: validates the tensor descriptors through
// PagedAttentionPrefillInfo::create and, on success, stores a heap-allocated
// Descriptor in *desc_ptr. The descriptor is created with a workspace size of 0.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t block_tables_desc,
    infiniopTensorDescriptor_t cache_lens_desc,
    infiniopTensorDescriptor_t seq_lens_desc,
    infiniopTensorDescriptor_t offset_desc,
    const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
    float scale) {
    auto validated = PagedAttentionPrefillInfo::create(
        out_desc, q_desc, k_cache_desc, v_cache_desc,
        block_tables_desc, cache_lens_desc, seq_lens_desc,
        offset_desc, alibi_slopes_desc, scale);
    // Propagates the validation status when the descriptors are rejected.
    CHECK_RESULT(validated);

    auto *nvidia_handle = reinterpret_cast<device::nvidia::Handle *>(handle);
    auto *opaque = new Opaque{nvidia_handle->internal()};
    *desc_ptr = new Descriptor(opaque, validated.take(), 0, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Runs the prefill on `stream_` by dispatching on the descriptor's dtype to the
// matching kernel instantiation (accumulation is always done in float).
//
// Returns INFINI_STATUS_BAD_TENSOR_SHAPE when head_size exceeds 1024 (the kernel
// uses one thread per head element), INFINI_STATUS_BAD_TENSOR_DTYPE for an
// unsupported dtype, otherwise the status reported by the launcher.
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *out, const void *q, const void *k_cache, const void *v_cache,
    const void *block_tables, const void *cache_lens, const void *seq_lens,
    const void *offset,
    const void *alibi_slopes,
    void *stream_) const {
    // This backend needs no scratch memory (the descriptor is created with a
    // workspace size of 0), so the workspace arguments are intentionally unused.
    (void)workspace;
    (void)workspace_size;

    cudaStream_t stream = (cudaStream_t)stream_;
    if (_info.head_size > 1024) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }

#define LAUNCH_KERNEL(Tdata, Tcompute)                                                            \
    launchPagedAttentionPrefill<Tdata, Tcompute>(                                                 \
        (Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache,           \
        (const int64_t *)block_tables, (const int64_t *)cache_lens, (const int64_t *)seq_lens,    \
        (const int64_t *)offset,                                                                  \
        (const float *)alibi_slopes,                                                              \
        _info.num_heads, _info.num_seqs, _info.num_kv_heads,                                      \
        _info.scale, _info.max_num_blocks_per_seq,                                                \
        _info.block_size, _info.total_q_tokens,                                                   \
        _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride,              \
        _info.head_size,                                                                          \
        stream)

    infiniStatus_t status = INFINI_STATUS_BAD_TENSOR_DTYPE;
    switch (_info.dtype) {
    case INFINI_DTYPE_F16:
        status = LAUNCH_KERNEL(half, float);
        break;
    case INFINI_DTYPE_BF16:
        status = LAUNCH_KERNEL(__nv_bfloat16, float);
        break;
    case INFINI_DTYPE_F32:
        status = LAUNCH_KERNEL(float, float);
        break;
    default:
        break;
    }
// Keep the helper macro local to this function.
#undef LAUNCH_KERNEL
    return status;
}
} // namespace op::paged_attention_prefill::nvidia
#ifndef __PAGED_ATTENTION_PREFILL_NVIDIA_H__
#define __PAGED_ATTENTION_PREFILL_NVIDIA_H__
#include "../paged_attention_prefill.h"
// Expands the DESCRIPTOR macro from ../paged_attention_prefill.h, declaring
// op::paged_attention_prefill::nvidia::Descriptor.
DESCRIPTOR(nvidia)
#endif // __PAGED_ATTENTION_PREFILL_NVIDIA_H__
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/paged_attention_prefill.h"
#ifdef ENABLE_NVIDIA_API
#include "nvidia/paged_attention_prefill_nvidia.cuh"
#endif
// C API entry point: forwards descriptor creation to the backend that matches
// handle->device. A NULL alibi_slopes_desc is passed through unchanged; the
// backend treats it as "no ALiBi".
// Returns INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED when no compiled backend
// handles the device.
__C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
    infiniopHandle_t handle,
    infiniopPagedAttentionPrefillDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t block_tables_desc,
    infiniopTensorDescriptor_t cache_lens_desc,
    infiniopTensorDescriptor_t seq_lens_desc,
    infiniopTensorDescriptor_t offset_desc,
    infiniopTensorDescriptor_t alibi_slopes_desc,
    float scale) {

#define CREATE(CASE, NAMESPACE)                                                                \
    case CASE:                                                                                 \
        return op::paged_attention_prefill::NAMESPACE::Descriptor::create(                     \
            handle,                                                                            \
            reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor **>(desc_ptr), \
            out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, cache_lens_desc,  \
            seq_lens_desc, offset_desc, alibi_slopes_desc, scale);

    switch (handle->device) {
#ifdef ENABLE_NVIDIA_API
        CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
    default:
        // Also keeps the switch well-formed when no backend is compiled in.
        break;
    }

// Keep the helper macro local to this function.
#undef CREATE
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// C API entry point: writes the workspace size recorded in the descriptor to
// *size. Returns INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED when no compiled
// backend handles the descriptor's device.
__C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
    infiniopPagedAttentionPrefillDescriptor_t desc,
    size_t *size) {

#define GET(CASE, NAMESPACE)                                                                                   \
    case CASE:                                                                                                 \
        *size = reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
        GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
    default:
        // Also keeps the switch well-formed when no backend is compiled in.
        break;
    }

// Keep the helper macro local to this function.
#undef GET
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// C API entry point: runs the prefill through the backend that owns the
// descriptor. Returns INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED when no compiled
// backend handles the descriptor's device.
__C infiniStatus_t infiniopPagedAttentionPrefill(
    infiniopPagedAttentionPrefillDescriptor_t desc,
    void *workspace, size_t workspace_size,
    void *out, const void *q, const void *k_cache, const void *v_cache,
    const void *block_tables, const void *cache_lens, const void *seq_lens,
    const void *offset,
    const void *alibi_slopes,
    void *stream) {

#define CALCULATE(CASE, NAMESPACE)                                                                      \
    case CASE:                                                                                          \
        return reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor *>(desc)->calculate( \
            workspace, workspace_size, out, q, k_cache, v_cache, block_tables,                          \
            cache_lens, seq_lens, offset, alibi_slopes, stream);

    switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
        CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
    default:
        // Also keeps the switch well-formed when no backend is compiled in.
        break;
    }

// Keep the helper macro local to this function.
#undef CALCULATE
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// C API entry point: deletes the descriptor through its concrete backend type.
// Returns INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED (and deletes nothing) when no
// compiled backend handles the descriptor's device.
__C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
    infiniopPagedAttentionPrefillDescriptor_t desc) {

#define DESTROY(CASE, NAMESPACE)                                                             \
    case CASE:                                                                               \
        delete reinterpret_cast<op::paged_attention_prefill::NAMESPACE::Descriptor *>(desc); \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
        DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
    default:
        // Also keeps the switch well-formed when no backend is compiled in.
        break;
    }

// Keep the helper macro local to this function.
#undef DESTROY
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#ifndef PAGED_ATTENTION_PREFILL_H
#define PAGED_ATTENTION_PREFILL_H
#include "../../operator.h"
#include "info.h"
// Declares op::paged_attention_prefill::<NAMESPACE>::Descriptor, the per-backend
// descriptor class for the paged-attention prefill operator.
//
// The generated class derives from InfiniopDescriptor and stores:
//   - _opaque:         pointer to a backend-defined Opaque struct,
//   - _info:           the validated PagedAttentionPrefillInfo,
//   - _workspace_size: the value returned by workspaceSize().
// The constructor is private; instances are obtained through the static
// create(). Each backend supplies the definitions of Opaque, ~Descriptor(),
// create() and calculate().
#define DESCRIPTOR(NAMESPACE)                                                             \
                                                                                          \
    namespace op::paged_attention_prefill::NAMESPACE {                                    \
    class Descriptor final : public InfiniopDescriptor {                                  \
        struct Opaque;                                                                    \
        Opaque *_opaque;                                                                  \
        PagedAttentionPrefillInfo _info;                                                  \
        size_t _workspace_size;                                                           \
                                                                                          \
        Descriptor(                                                                       \
            Opaque *opaque,                                                               \
            PagedAttentionPrefillInfo info,                                               \
            size_t workspace_size,                                                        \
            infiniDevice_t device_type,                                                   \
            int device_id)                                                                \
            : InfiniopDescriptor{device_type, device_id},                                 \
              _opaque(opaque),                                                            \
              _info(info),                                                                \
              _workspace_size(workspace_size) {}                                          \
                                                                                          \
    public:                                                                               \
        ~Descriptor();                                                                    \
                                                                                          \
        size_t workspaceSize() const { return _workspace_size; }                          \
                                                                                          \
        static infiniStatus_t create(                                                     \
            infiniopHandle_t handle,                                                      \
            Descriptor **desc_ptr,                                                        \
            infiniopTensorDescriptor_t out_desc,                                          \
            infiniopTensorDescriptor_t q_desc,                                            \
            infiniopTensorDescriptor_t k_cache_desc,                                      \
            infiniopTensorDescriptor_t v_cache_desc,                                      \
            infiniopTensorDescriptor_t block_tables_desc,                                 \
            infiniopTensorDescriptor_t cache_lens_desc,                                   \
            infiniopTensorDescriptor_t seq_lens_desc,                                     \
            infiniopTensorDescriptor_t offset_desc,                                       \
            const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,           \
            float scale);                                                                 \
                                                                                          \
        infiniStatus_t calculate(                                                         \
            void *workspace, size_t workspace_size,                                       \
            void *out, const void *q, const void *k_cache, const void *v_cache,           \
            const void *block_tables, const void *cache_lens, const void *seq_lens,       \
            const void *offset,                                                           \
            const void *alibi_slopes,                                                     \
            void *stream) const;                                                          \
    };                                                                                    \
    }
#endif // PAGED_ATTENTION_PREFILL_H
......@@ -939,6 +939,7 @@ def tanh_(lib):
infiniopOperatorDescriptor_t,
]
@OpRegister.operator
def scaled_mm_int8_(lib):
lib.infiniopCreateI8GemmDescriptor.restype = c_int32
......@@ -1061,3 +1062,50 @@ def paged_caching_(lib):
lib.infiniopDestroyPagedCachingDescriptor.argtypes = [
infiniopOperatorDescriptor_t,
]
@OpRegister.operator
def paged_attention_prefill_(lib):
    """Declare the ctypes signatures of the four paged-attention-prefill entry points."""
    status_type = c_int32

    # create(handle, &desc, 9 tensor descriptors, scale)
    create_fn = lib.infiniopCreatePagedAttentionPrefillDescriptor
    create_fn.restype = status_type
    create_fn.argtypes = (
        [infiniopHandle_t, POINTER(infiniopOperatorDescriptor_t)]
        + [infiniopTensorDescriptor_t] * 9
        + [c_float]
    )

    # get_workspace_size(desc, &size)
    workspace_fn = lib.infiniopGetPagedAttentionPrefillWorkspaceSize
    workspace_fn.restype = status_type
    workspace_fn.argtypes = [infiniopOperatorDescriptor_t, POINTER(c_size_t)]

    # run(desc, workspace, workspace_size, then 10 raw pointers)
    run_fn = lib.infiniopPagedAttentionPrefill
    run_fn.restype = status_type
    run_fn.argtypes = (
        [infiniopOperatorDescriptor_t, c_void_p, c_size_t] + [c_void_p] * 10
    )

    # destroy(desc)
    destroy_fn = lib.infiniopDestroyPagedAttentionPrefillDescriptor
    destroy_fn.restype = status_type
    destroy_fn.argtypes = [infiniopOperatorDescriptor_t]
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
LIBINFINIOP,
TestTensor,
get_test_devices,
check_error,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
InfiniDtype,
InfiniDtypeNames,
InfiniDeviceNames,
infiniopOperatorDescriptor_t,
TestWorkspace,
)
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# Each tuple is unpacked positionally into test() after (handle, device).
# Every case listed here uses num_heads == num_kv_heads and head_size == 128.
_TEST_CASES = [
    # num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds
    (1, 1, 1, 128, 8, 16, 1),
    (1, 4, 4, 128, 8, 16, 4),
    (2, 8, 8, 128, 16, 32, 2),
    (4, 16, 16, 128, 8, 64, 3),
    (8, 64, 64, 128, 8, 16, 5),
    (16, 128, 128, 128, 8, 16, 4),
]
# Data types every test case is run with.
_TENSOR_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]
# Per-dtype absolute/relative tolerances used when comparing against the reference.
_TOLERANCE_MAP = {
    InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5},
    InfiniDtype.F16: {"atol": 1e-2, "rtol": 1e-2},
    InfiniDtype.BF16: {"atol": 2e-2, "rtol": 2e-2},
}
# Defaults; overwritten from the command-line arguments in the __main__ section.
DEBUG = False
PROFILE = False
NUM_PRERUN = 5
NUM_ITERATIONS = 10
# ==============================================================================
# Helper Classes & Reference Implementation
# ==============================================================================
class SimpleCacheManager:
    """Minimal block allocator that emulates a paged KV-cache manager.

    Physical blocks are handed out in ascending order from ``free_blocks``.
    For every request the manager records the ordered list of blocks it owns
    and the number of tokens cached so far.
    """

    def __init__(self, num_blocks, block_size):
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.request_to_blocks = {}
        self.request_to_len = {}

    def allocate_slots(self, request_id, num_new_tokens):
        """Reserve cache space for ``num_new_tokens`` more tokens of a request.

        Returns ``(blocks, new_total_len)`` where ``blocks`` is the manager's
        own (live) block list of the request and ``new_total_len`` the cached
        length after the allocation.

        Raises:
            IndexError: when the pool cannot satisfy the request. The manager
                is left untouched in that case.
        """
        start_pos = self.request_to_len.get(request_id, 0)
        owned_count = len(self.request_to_blocks.get(request_id, ()))
        new_total_len = start_pos + num_new_tokens
        needed_blocks = (new_total_len + self.block_size - 1) // self.block_size
        added_blocks = max(needed_blocks - owned_count, 0)

        # Check capacity before mutating anything so that a failed allocation
        # cannot leave blocks attached to a request whose length was not updated.
        if added_blocks > len(self.free_blocks):
            raise IndexError(
                f"cache pool exhausted: need {added_blocks} free blocks, "
                f"have {len(self.free_blocks)}"
            )

        blocks = self.request_to_blocks.setdefault(request_id, [])
        # Take all new blocks in one slice instead of repeated pop(0).
        blocks.extend(self.free_blocks[:added_blocks])
        del self.free_blocks[:added_blocks]
        self.request_to_len[request_id] = new_total_len
        return blocks, new_total_len
def ref_paged_attention_multi_turn(
    query_new, k_cache, v_cache, block_tables, seq_lens, new_lens, offset, scale
):
    """Reference causal attention over a paged KV cache for packed prefill queries.

    Args:
        query_new: packed queries, ``[total_new_tokens, num_heads, head_size]``.
        k_cache, v_cache: ``[num_blocks, num_kv_heads, block_size, head_size]``.
        block_tables: ``[num_seqs, max_blocks]`` logical-to-physical block ids.
        seq_lens: total cached length per sequence (history + new tokens).
        new_lens: number of new (query) tokens per sequence.
        offset: ``num_seqs + 1`` start offsets of each sequence in ``query_new``.
        scale: factor applied to the raw attention scores.

    ``num_heads`` may be a multiple of ``num_kv_heads``; consecutive query
    heads then share one KV head.

    Returns:
        Tensor shaped like ``query_new``. Rows of sequences without new tokens
        stay zero.
    """
    block_size = k_cache.shape[2]
    group_size = query_new.shape[1] // k_cache.shape[1]
    cache_device = k_cache.device
    outputs = torch.zeros_like(query_new)

    for i in range(len(offset) - 1):
        total_len = int(seq_lens[i].item())
        num_new = int(new_lens[i].item())
        if num_new == 0:
            continue
        history_len = total_len - num_new
        start = int(offset[i].item())

        # Gather the logical token positions 0..total_len-1 from their blocks.
        # With two index tensors separated by a slice, the indexed axis comes
        # first: the result is [total_len, num_kv_heads, head_size].
        positions = torch.arange(total_len, device=cache_device)
        table = block_tables[i].to(device=cache_device, dtype=torch.long)
        block_ids = table[torch.div(positions, block_size, rounding_mode="floor")]
        in_block = positions % block_size
        K = k_cache[block_ids, :, in_block, :]
        V = v_cache[block_ids, :, in_block, :]
        if group_size > 1:
            K = K.repeat_interleave(group_size, dim=1)
            V = V.repeat_interleave(group_size, dim=1)

        Q = query_new[start : start + num_new, :, :]
        scores = torch.einsum("qhd,khd->hqk", Q, K).float() * scale

        # Query q (0-based among the new tokens) sees keys 0..history_len + q.
        key_pos = torch.arange(total_len, device=scores.device)
        query_pos = torch.arange(num_new, device=scores.device) + history_len
        hidden = key_pos.unsqueeze(0) > query_pos.unsqueeze(1)
        scores = scores.masked_fill(hidden.unsqueeze(0), float("-inf"))

        attn_weights = torch.softmax(scores, dim=-1).to(Q.dtype)
        outputs[start : start + num_new, :, :] = torch.einsum(
            "hqk,khd->qhd", attn_weights, V
        )
    return outputs
# ==============================================================================
# Test Operator Implementation
# ==============================================================================
def test(
    handle,
    device,
    num_seqs,
    num_heads,
    num_kv_heads,
    head_size,
    block_size,
    max_step_len,
    num_rounds,
    dtype=InfiniDtype.F16,
    sync=None,
):
    """Run a multi-round paged-attention-prefill test for one configuration.

    The KV cache and the block allocator persist across rounds, so every round
    after the first attends over the history written by earlier rounds. In each
    round every sequence receives between 1 and ``max_step_len`` new tokens;
    the library output is compared with ``ref_paged_attention_multi_turn``.
    """
    print(
        f"Testing PagedAttentionPrefill on {InfiniDeviceNames[device]} with "
        f"seqs:{num_seqs}, heads:{num_heads}, head_size:{head_size}, "
        f"block:{block_size}, max_step_len:{max_step_len}, num_rounds:{num_rounds}, dtype:{InfiniDtypeNames[dtype]}"
    )
    # 1. Initialize persistent resources
    num_blocks = 8192
    manager = SimpleCacheManager(num_blocks, block_size)
    scale = head_size**-0.5
    k_cache = TestTensor(
        (num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
    )
    v_cache = TestTensor(
        (num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
    )
    # Multi-turn testing loop
    for r in range(num_rounds):
        # Prepare dynamic inputs for this round
        seq_lens_cpu = torch.randint(
            1, max_step_len + 1, (num_seqs,), dtype=torch.int64
        )
        q_total_tokens = seq_lens_cpu.sum().item()
        q_packed_tensors = torch.zeros(q_total_tokens, num_heads, head_size)
        cache_lens_list = []
        all_block_tables = []
        offset_list = []
        cur_offset = 0
        for i in range(num_seqs):
            # Start of sequence i inside the packed query tensor.
            offset_list.append(cur_offset)
            cur_new_len = seq_lens_cpu[i].item()
            table, cache_len = manager.allocate_slots(i, cur_new_len)
            cache_lens_list.append(cache_len)
            all_block_tables.append(table)
            # Simulated KV insertion
            k_new = torch.randn(cur_new_len, num_kv_heads, head_size)
            v_new = torch.randn(cur_new_len, num_kv_heads, head_size)
            q_val = torch.randn(cur_new_len, num_heads, head_size)
            q_packed_tensors[cur_offset : cur_offset + cur_new_len] = q_val
            cur_offset = cur_offset + cur_new_len
            history_len = cache_len - cur_new_len
            for t in range(cur_new_len):
                # New token t is stored right after the sequence's history.
                logical_pos = history_len + t
                b_id = table[logical_pos // block_size]
                off = logical_pos % block_size
                k_cache.torch_tensor()[b_id, :, off, :] = k_new[t]
                v_cache.torch_tensor()[b_id, :, off, :] = v_new[t]
        # Trailing entry: end of the last sequence (num_seqs + 1 offsets in total).
        offset_list.append(cur_offset)
        # Push the updated reference caches into the tensors the library reads.
        k_cache.actual_tensor().copy_(k_cache._torch_tensor)
        v_cache.actual_tensor().copy_(v_cache._torch_tensor)
        # 2. Wrap tensors for Infiniop
        q_new = TestTensor.from_torch(q_packed_tensors, dtype, device)
        # out is created from the query tensor and then zeroed.
        out = TestTensor.from_torch(q_packed_tensors, dtype, device)
        out.actual_tensor().zero_()
        cache_lens = TestTensor.from_torch(
            torch.tensor(cache_lens_list, dtype=torch.int64), InfiniDtype.I64, device
        )
        seq_lens = TestTensor.from_torch(seq_lens_cpu, InfiniDtype.I64, device)
        offset = TestTensor.from_torch(
            torch.tensor(offset_list, dtype=torch.int64), InfiniDtype.I64, device
        )
        # Pad every block table with 0 up to the longest one to form a rectangle.
        max_blocks = max(len(t) for t in all_block_tables)
        padded_tables = [t + [0] * (max_blocks - len(t)) for t in all_block_tables]
        block_tables = TestTensor.from_torch(
            torch.tensor(padded_tables, dtype=torch.int64), InfiniDtype.I64, device
        )

        # 3. Reference Calculation
        def torch_paged_attention_multi_turn():
            return ref_paged_attention_multi_turn(
                q_new.torch_tensor(),
                k_cache.torch_tensor(),
                v_cache.torch_tensor(),
                block_tables.torch_tensor(),
                cache_lens.torch_tensor(),
                seq_lens.torch_tensor(),
                offset.torch_tensor(),
                scale,
            )

        ans = torch_paged_attention_multi_turn()
        # 4. Infiniop Operator Execution
        descriptor = infiniopOperatorDescriptor_t()
        check_error(
            LIBINFINIOP.infiniopCreatePagedAttentionPrefillDescriptor(
                handle,
                ctypes.byref(descriptor),
                out.descriptor,
                q_new.descriptor,
                k_cache.descriptor,
                v_cache.descriptor,
                block_tables.descriptor,
                cache_lens.descriptor,
                seq_lens.descriptor,
                offset.descriptor,
                None,  # alibi_slopes_desc
                scale,
            )
        )
        workspace_size = c_uint64(0)
        check_error(
            LIBINFINIOP.infiniopGetPagedAttentionPrefillWorkspaceSize(
                descriptor, ctypes.byref(workspace_size)
            )
        )
        workspace = TestWorkspace(workspace_size.value, device)

        def lib_attn():
            check_error(
                LIBINFINIOP.infiniopPagedAttentionPrefill(
                    descriptor,
                    workspace.data(),
                    workspace_size.value,
                    out.data(),
                    q_new.data(),
                    k_cache.data(),
                    v_cache.data(),
                    block_tables.data(),
                    cache_lens.data(),
                    seq_lens.data(),
                    offset.data(),
                    None,  # alibi_slopes
                    None,  # stream
                )
            )

        lib_attn()
        if sync:
            sync()
        # 5. Validation
        atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
        if DEBUG:
            debug(out.actual_tensor(), ans, atol=atol, rtol=rtol)
        assert torch.allclose(out.actual_tensor(), ans, atol=atol, rtol=rtol)
        # Profiling
        if PROFILE:
            profile_operation(
                f"Torch_R{r}",
                lambda: torch_paged_attention_multi_turn(),
                device,
                NUM_PRERUN,
                NUM_ITERATIONS,
            )
            profile_operation(
                f" Lib_R{r}", lambda: lib_attn(), device, NUM_PRERUN, NUM_ITERATIONS
            )
        # A descriptor is created per round, so it is released per round too.
        check_error(
            LIBINFINIOP.infiniopDestroyPagedAttentionPrefillDescriptor(descriptor)
        )
# ==============================================================================
# Main Execution
# ==============================================================================
if __name__ == "__main__":
    args = get_args()
    # Override the module-level defaults with the command-line options.
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations
    # Run every test case with every dtype on each selected device.
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
    print("\033[92mTest passed!\033[0m")
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
LIBINFINIOP,
TestTensor,
get_test_devices,
check_error,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
TestWorkspace,
InfiniDtype,
InfiniDtypeNames,
InfiniDeviceNames,
infiniopOperatorDescriptor_t,
)
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# Each tuple is unpacked positionally into test() after (handle, device).
_TEST_CASES = [
    # num_seqs, max_step_len, num_kv_heads, head_size, block_size, num_rounds
    (1, 16, 1, 128, 8, 5),
    (2, 64, 8, 128, 16, 2),
    (8, 128, 32, 128, 16, 3),
    (5, 512, 40, 128, 16, 3),
    (16, 64, 8, 128, 32, 1),
    (10, 256, 40, 128, 32, 3),
]
# Data types every test case is run with.
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32]
# The same very tight tolerance is used for every dtype.
_TOLERANCE_MAP = {
    InfiniDtype.F32: {"atol": 1e-8, "rtol": 1e-8},
    InfiniDtype.F16: {"atol": 1e-8, "rtol": 1e-8},
    InfiniDtype.BF16: {"atol": 1e-8, "rtol": 1e-8},
}
# Default run options.
DEBUG = False
PROFILE = False
NUM_PRERUN = 5
NUM_ITERATIONS = 10
# ==============================================================================
# Helper Classes & Reference Implementation
# ==============================================================================
class SimpleCacheManager:
    """Minimal block allocator that emulates a paged KV-cache manager.

    Physical blocks are handed out in ascending order from ``free_blocks``.
    For every request the manager records the ordered list of blocks it owns
    and the number of tokens cached so far.
    """

    def __init__(self, num_blocks, block_size):
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.request_to_blocks = {}
        self.request_to_len = {}

    def allocate_slots(self, request_id, num_new_tokens):
        """Reserve cache slots for ``num_new_tokens`` more tokens of a request.

        Returns an int32 tensor with one flat slot id per new token, where
        ``slot = physical_block_id * block_size + offset_in_block``.

        Raises:
            IndexError: when the pool cannot satisfy the request. The manager
                is left untouched in that case.
        """
        start_pos = self.request_to_len.get(request_id, 0)
        owned_count = len(self.request_to_blocks.get(request_id, ()))
        new_total_len = start_pos + num_new_tokens
        needed_blocks = (new_total_len + self.block_size - 1) // self.block_size
        added_blocks = max(needed_blocks - owned_count, 0)

        # Check capacity before mutating anything so that a failed allocation
        # cannot leave blocks attached to a request whose length was not updated.
        if added_blocks > len(self.free_blocks):
            raise IndexError(
                f"cache pool exhausted: need {added_blocks} free blocks, "
                f"have {len(self.free_blocks)}"
            )

        blocks = self.request_to_blocks.setdefault(request_id, [])
        # Take all new blocks in one slice instead of repeated pop(0).
        blocks.extend(self.free_blocks[:added_blocks])
        del self.free_blocks[:added_blocks]

        # Translate each new logical position into its flat physical slot.
        slots = [
            blocks[pos // self.block_size] * self.block_size + pos % self.block_size
            for pos in range(start_pos, new_total_len)
        ]
        self.request_to_len[request_id] = new_total_len
        return torch.tensor(slots, dtype=torch.int32)
def ref_paged_caching(k_new, v_new, k_pool, v_pool, slots, block_size):
    """Reference paged caching: scatter new K/V rows into the pools in place.

    Row ``i`` of ``k_new`` / ``v_new`` is written to
    ``pool[slot // block_size, :, slot % block_size, :]`` where
    ``slot = slots[i]``. Both pools are modified in place and returned.
    """
    slot_values = slots.tolist()
    for token_idx in range(k_new.shape[0]):
        # Split the flat slot into (physical block, offset inside the block).
        block_id, offset = divmod(slot_values[token_idx], block_size)
        k_pool[block_id, :, offset, :] = k_new[token_idx]
        v_pool[block_id, :, offset, :] = v_new[token_idx]
    return k_pool, v_pool
# ==============================================================================
# Test Operator Implementation
# ==============================================================================
def test(
    handle,
    device,
    num_seqs,
    max_step_len,
    num_kv_heads,
    head_size,
    block_size,
    num_rounds,
    dtype=InfiniDtype.F16,
    sync=None,
):
    """Run the PagedCaching operator for several incremental rounds and validate it.

    Each round appends between 1 and ``max_step_len`` new tokens to each of the
    ``num_seqs`` sequences, writes them into the K/V cache pools both with the
    torch reference and with the library operator, and asserts that the two
    pools match.

    Args:
        handle: InfiniOP handle passed to descriptor creation.
        device: Device identifier (indexes ``InfiniDeviceNames``).
        num_seqs: Number of sequences receiving tokens in every round.
        max_step_len: Upper bound (inclusive) on new tokens per sequence per round.
        num_kv_heads: Number of KV heads.
        head_size: Size of each head.
        block_size: Number of token slots per cache block.
        num_rounds: Number of incremental caching rounds.
        dtype: Data type of the K/V tensors.
        sync: Optional callable invoked after the library call.

    Raises:
        AssertionError: If the library cache pools differ from the reference.
    """
    print(
        f"Testing PagedCaching on {InfiniDeviceNames[device]} with "
        f"seqs:{num_seqs}, max_step_len:{max_step_len}, num_kv_heads:{num_kv_heads}, head_size:{head_size}, "
        f"block_size:{block_size}, rounds:{num_rounds}, dtype:{InfiniDtypeNames[dtype]}"
    )

    # 1. Initialize Global Cache Pool
    num_blocks = 8192
    manager = SimpleCacheManager(num_blocks, block_size)
    pool_shape = (num_blocks, num_kv_heads, block_size, head_size)
    k_cache_pool = TestTensor(pool_shape, None, dtype, device)
    v_cache_pool = TestTensor(pool_shape, None, dtype, device)

    # Reference pools start as copies of the initial pool contents, so slots
    # that are never written compare equal.
    k_pool_ref = k_cache_pool.torch_tensor().clone()
    v_pool_ref = v_cache_pool.torch_tensor().clone()

    for r in range(num_rounds):
        # Prepare incremental data for this round
        round_ntok_list = torch.randint(
            1, max_step_len + 1, (num_seqs,), dtype=torch.int32
        )
        all_slots, all_k, all_v = [], [], []
        for i in range(num_seqs):
            n_new = round_ntok_list[i].item()
            all_slots.append(manager.allocate_slots(i, n_new))
            all_k.append(torch.randn(n_new, num_kv_heads, head_size))
            all_v.append(torch.randn(n_new, num_kv_heads, head_size))

        # Pack all sequences of this round along the token dimension.
        k_in_torch = torch.cat(all_k, dim=0)
        v_in_torch = torch.cat(all_v, dim=0)
        slots_torch = torch.cat(all_slots, dim=0)

        k_in = TestTensor.from_torch(k_in_torch, dtype, device)
        v_in = TestTensor.from_torch(v_in_torch, dtype, device)
        slot_mapping = TestTensor.from_torch(slots_torch, InfiniDtype.I64, device)

        # 2. Reference Calculation
        def torch_caching():
            # Writes into the reference pools in place (no rebinding needed).
            return ref_paged_caching(
                k_in.torch_tensor(),
                v_in.torch_tensor(),
                k_pool_ref,
                v_pool_ref,
                slots_torch,
                block_size,
            )

        torch_caching()

        # 3. Infiniop Operator Execution
        descriptor = infiniopOperatorDescriptor_t()
        check_error(
            LIBINFINIOP.infiniopCreatePagedCachingDescriptor(
                handle,
                ctypes.byref(descriptor),
                k_in.descriptor,
                v_in.descriptor,
                k_cache_pool.descriptor,
                v_cache_pool.descriptor,
                slot_mapping.descriptor,
            )
        )

        # From here on the descriptor exists; make sure it is destroyed even
        # when execution or validation below raises.
        try:
            workspace_size = c_uint64(0)
            check_error(
                LIBINFINIOP.infiniopGetPagedCachingWorkspaceSize(
                    descriptor, ctypes.byref(workspace_size)
                )
            )
            workspace = TestWorkspace(workspace_size.value, device)

            def lib_caching():
                check_error(
                    LIBINFINIOP.infiniopPagedCaching(
                        descriptor,
                        workspace.data(),
                        workspace_size.value,
                        k_in.data(),
                        v_in.data(),
                        k_cache_pool.data(),
                        v_cache_pool.data(),
                        slot_mapping.data(),
                        None,  # NOTE(review): presumably the stream argument - confirm
                    )
                )

            lib_caching()
            if sync:
                sync()

            # 4. Validation
            atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
            if DEBUG:
                # Report discrepancies over the whole K cache pool.
                debug(k_cache_pool.actual_tensor(), k_pool_ref, atol=atol, rtol=rtol)
            assert torch.allclose(
                k_cache_pool.actual_tensor(), k_pool_ref, atol=atol, rtol=rtol
            )
            assert torch.allclose(
                v_cache_pool.actual_tensor(), v_pool_ref, atol=atol, rtol=rtol
            )

            # 5. Profiling
            if PROFILE:
                # Re-running either function rewrites the same slots with the
                # same data, so profiling does not disturb later rounds.
                profile_operation(
                    f"Torch_R{r}",
                    torch_caching,
                    device,
                    NUM_PRERUN,
                    NUM_ITERATIONS,
                )
                profile_operation(
                    f"  Lib_R{r}", lib_caching, device, NUM_PRERUN, NUM_ITERATIONS
                )
        finally:
            check_error(LIBINFINIOP.infiniopDestroyPagedCachingDescriptor(descriptor))
# ==============================================================================
# Main Execution
# ==============================================================================
if __name__ == "__main__":
    # Parse the command-line options shared by the operator test scripts.
    args = get_args()

    # Copy the CLI switches into the module-level settings read by test().
    DEBUG, PROFILE = args.debug, args.profile
    NUM_PRERUN, NUM_ITERATIONS = args.num_prerun, args.num_iterations

    # Run the full case/dtype matrix once per selected device.
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    # Green banner, printed only if every device ran without raising.
    print("\033[92mTest passed!\033[0m")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment