Unverified Commit 298feac2 authored by PanZezhong1725, committed by GitHub

Merge pull request #836 from InfiniTensor/issue/834

issue/834: add paged attention for nvidia gpu
parents 27777ee1 17299923
@@ -31,5 +31,7 @@
#include "infiniop/ops/topksoftmax.h"
#include "infiniop/ops/zeros.h"
#include "infiniop/tensor_descriptor.h"
#include "infiniop/ops/paged_attention.h"
#include "infiniop/ops/paged_caching.h"
#endif // __INFINIOP_API_H__
#ifndef __INFINIOP_PAGED_ATTENTION_API_H__
#define __INFINIOP_PAGED_ATTENTION_API_H__
#include "../operator_descriptor.h"
// Define an opaque handle for the Paged Attention descriptor.
typedef struct InfiniopDescriptor *infiniopPagedAttentionDescriptor_t;
/**
* @brief Creates a descriptor for the Paged Attention v1 operation.
*
* @param handle The library context handle.
* @param desc_ptr Pointer to the created descriptor.
* @param out_desc [Output] Shape: (num_seqs, num_heads, head_size).
* The output tensor for the attention mechanism.
* @param q_desc [Input] Shape: (num_seqs, num_heads, head_size).
* The query tensor.
* @param k_cache_desc [Input] Shape: (num_blocks, num_kv_heads, block_size, head_size).
* Paged key cache storing keys for all sequences.
* @param v_cache_desc [Input] Shape: (num_blocks, num_kv_heads, block_size, head_size).
* Paged value cache storing values for all sequences.
* @param block_tables_desc [Input] Shape: (num_seqs, max_num_blocks_per_seq).
* Maps each sequence to its physical block indices in the cache.
* Expected DType: int64_t (I64).
* @param seq_lens_desc [Input] Shape: (num_seqs,).
* The current logical length of each sequence.
* Expected DType: int64_t (I64).
* @param alibi_slopes_desc [Optional] Shape: (num_heads,).
* Slopes for ALiBi (Attention with Linear Biases). Can be NULL.
* @param scale The attention scaling factor (typically 1/sqrt(head_size)).
* @return infiniStatus_t Status code.
*/
__C __export infiniStatus_t infiniopCreatePagedAttentionDescriptor(
infiniopHandle_t handle,
infiniopPagedAttentionDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t alibi_slopes_desc,
float scale);
/**
* @brief Retrieves the workspace size required for the Paged Attention operation.
*
* @param desc The Paged Attention descriptor.
* @param size A pointer to store the required workspace size in bytes.
* @return infiniStatus_t Status code of the operation.
*/
__C __export infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
infiniopPagedAttentionDescriptor_t desc, size_t *size);
/**
* @brief Executes the Paged Attention v1 operation.
*
* @param desc The Paged Attention descriptor.
* @param workspace Pointer to the workspace memory.
* @param workspace_size The size of the workspace.
* @param out Pointer to the output tensor data.
* @param q Pointer to the query tensor data.
* @param k_cache Pointer to the key cache data.
* @param v_cache Pointer to the value cache data.
* @param block_tables Pointer to the block tables data.
* @param seq_lens Pointer to the sequence lengths data.
* @param alibi_slopes Pointer to the ALiBi slopes data. Can be NULL.
* @param stream The CUDA stream for the operation. Can be NULL.
* @return infiniStatus_t Status code of the operation.
*/
__C __export infiniStatus_t infiniopPagedAttention(
infiniopPagedAttentionDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
const void *block_tables,
const void *seq_lens,
const void *alibi_slopes,
void *stream);
/**
* @brief Destroys a Paged Attention descriptor.
*
* @param desc The descriptor to be destroyed.
* @return infiniStatus_t Status code of the operation.
*/
__C __export infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
infiniopPagedAttentionDescriptor_t desc);
#endif // __INFINIOP_PAGED_ATTENTION_API_H__
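// Illustrative call sequence for the API above (a minimal sketch, not part of the
// header). It assumes the handle, tensor descriptors, and device buffers have been
// created elsewhere, and that CHECK is a user-defined macro that handles any
// non-success infiniStatus_t:
//
//   infiniopPagedAttentionDescriptor_t desc = nullptr;
//   CHECK(infiniopCreatePagedAttentionDescriptor(
//       handle, &desc,
//       out_desc, q_desc, k_cache_desc, v_cache_desc,
//       block_tables_desc, seq_lens_desc,
//       /*alibi_slopes_desc=*/nullptr,                  // ALiBi disabled
//       /*scale=*/1.0f / sqrtf(float(head_size))));
//
//   size_t workspace_size = 0;
//   CHECK(infiniopGetPagedAttentionWorkspaceSize(desc, &workspace_size));
//   void *workspace = nullptr;                          // allocate workspace_size bytes on the device if > 0
//
//   CHECK(infiniopPagedAttention(
//       desc, workspace, workspace_size,
//       out, q, k_cache, v_cache, block_tables, seq_lens,
//       /*alibi_slopes=*/nullptr, /*stream=*/nullptr));  // NULL stream -> default stream
//
//   CHECK(infiniopDestroyPagedAttentionDescriptor(desc));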
#ifndef __INFINIOP_PAGED_CACHING_API_H__
#define __INFINIOP_PAGED_CACHING_API_H__
#include "../operator_descriptor.h"
// Define an opaque handle for the Paged Caching descriptor.
typedef struct InfiniopDescriptor *infiniopPagedCachingDescriptor_t;
/**
* @brief Creates a descriptor for the Paged Caching operation.
*
* This function initializes a descriptor that holds all the metadata needed
* to copy key/value vectors into their respective cache pools.
*
* @param handle The handle to the InfiniOP library context.
* @param desc_ptr A pointer to store the created descriptor.
* @param k_desc Descriptor for the source key tensor.
* @param v_desc Descriptor for the source value tensor.
* @param k_cache_desc Descriptor for the key cache pool tensor.
* @param v_cache_desc Descriptor for the value cache pool tensor.
* @param slot_mapping_desc Descriptor for the slot mapping tensor.
* @return infiniStatus_t Status code of the operation.
*/
__C __export infiniStatus_t infiniopCreatePagedCachingDescriptor(
infiniopHandle_t handle,
infiniopPagedCachingDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t slot_mapping_desc);
/**
* @brief Retrieves the workspace size required for the Paged Caching operation.
*
* @param desc The Paged Caching descriptor.
* @param size A pointer to store the required workspace size in bytes (typically 0).
* @return infiniStatus_t Status code of the operation.
*/
__C __export infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
infiniopPagedCachingDescriptor_t desc, size_t *size);
/**
* @brief Executes the Paged Caching operation.
*
* @param desc The Paged Caching descriptor.
* @param workspace Pointer to the workspace memory.
* @param workspace_size The size of the workspace.
* @param k Pointer to the source key tensor data.
* @param v Pointer to the source value tensor data.
* @param k_cache Pointer to the key cache pool data.
* @param v_cache Pointer to the value cache pool data.
* @param slot_mapping Pointer to the slot mapping data.
* @param stream The CUDA stream for the operation. Can be NULL.
* @return infiniStatus_t Status code of the operation.
*/
__C __export infiniStatus_t infiniopPagedCaching(
infiniopPagedCachingDescriptor_t desc,
void *workspace,
size_t workspace_size,
const void *k,
const void *v,
void *k_cache,
void *v_cache,
const void *slot_mapping,
void *stream);
/**
* @brief Destroys a Paged Caching descriptor.
*
* @param desc The descriptor to be destroyed.
* @return infiniStatus_t Status code of the operation.
*/
__C __export infiniStatus_t infiniopDestroyPagedCachingDescriptor(
infiniopPagedCachingDescriptor_t desc);
#endif // __INFINIOP_PAGED_CACHING_API_H__
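// Note on slot_mapping semantics (a sketch of what a scheduler is expected to
// produce; the scheduler itself is outside this API): slot_mapping holds one entry
// per source token, where slot = physical_block_index * block_size + offset_within_block.
// A negative slot marks a padding token, which the kernel skips. For example, with
// block_size = 16, a token placed in physical block 3 at offset 5 gets slot 3 * 16 + 5 = 53.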
#ifndef __PAGED_ATTENTION_KERNEL_CUH__
#define __PAGED_ATTENTION_KERNEL_CUH__
// This kernel adopts parallelism strategies from industry-standard implementations
// such as vLLM, and fixes functional and performance issues present in the original draft.
namespace op::paged_attention::cuda {
template <typename Tdata, typename Tcompute, size_t HEAD_SIZE, size_t NUM_THREADS>
__device__ void pagedAttentionKernel(
Tdata *out_,
const Tdata *q_,
const Tdata *k_cache_,
const Tdata *v_cache_,
const int64_t *block_tables_,
const int64_t *seq_lens_,
const float *alibi_slopes_,
const size_t num_kv_heads,
const float scale,
const size_t max_num_blocks_per_seq,
const size_t block_size,
const ptrdiff_t q_stride,
const ptrdiff_t kv_block_stride,
const ptrdiff_t kv_head_stride,
const ptrdiff_t o_stride) {
//================================================================================
// 1. Setup & Query Loading
//================================================================================
const int seq_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int num_heads = gridDim.x;
const int64_t seq_len = seq_lens_[seq_idx];
if (seq_len == 0) {
return;
}
const size_t num_queries_per_kv = num_heads / num_kv_heads;
const size_t kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
const int64_t *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;
const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE;
Tdata *out_ptr = out_ + seq_idx * o_stride + head_idx * HEAD_SIZE;
extern __shared__ char shared_mem_char[];
Tcompute *shared_mem = reinterpret_cast<Tcompute *>(shared_mem_char);
Tcompute *q_shared = shared_mem;
Tcompute *logits = shared_mem + HEAD_SIZE;
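// Dynamic shared memory layout (sized by the host launcher):
//   [0, HEAD_SIZE)                        -> q_shared, the query vector in Tcompute
//   [HEAD_SIZE, HEAD_SIZE + max_seq_len)  -> logits, one score per cached token,
// where max_seq_len = max_num_blocks_per_seq * block_size.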
for (size_t i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
q_shared[i] = static_cast<Tcompute>(q_ptr[i]);
}
__syncthreads();
//================================================================================
// 2. Compute QK Dot Product & Find Max Logit
//================================================================================
for (size_t token_idx = threadIdx.x; token_idx < seq_len; token_idx += NUM_THREADS) {
const int64_t block_idx = token_idx / block_size;
const int64_t token_in_block_idx = token_idx % block_size;
const int64_t physical_block_num = block_table[block_idx];
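// Worked example: with block_size = 16, token_idx = 37 resolves to
// block_table[2] (logical block 2) at offset 5 within that block.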
const Tdata *k_vec_ptr = k_cache_ + physical_block_num * kv_block_stride + kv_head_idx * kv_head_stride + token_in_block_idx * HEAD_SIZE;
Tcompute qk = 0.0f;
#pragma unroll
for (size_t i = 0; i < HEAD_SIZE / 8; ++i) {
const size_t offset = i * 8;
// Manually unrolled 8-way dot-product accumulation (HEAD_SIZE is a multiple of 8).
qk += q_shared[offset + 0] * static_cast<Tcompute>(k_vec_ptr[offset + 0]);
qk += q_shared[offset + 1] * static_cast<Tcompute>(k_vec_ptr[offset + 1]);
qk += q_shared[offset + 2] * static_cast<Tcompute>(k_vec_ptr[offset + 2]);
qk += q_shared[offset + 3] * static_cast<Tcompute>(k_vec_ptr[offset + 3]);
qk += q_shared[offset + 4] * static_cast<Tcompute>(k_vec_ptr[offset + 4]);
qk += q_shared[offset + 5] * static_cast<Tcompute>(k_vec_ptr[offset + 5]);
qk += q_shared[offset + 6] * static_cast<Tcompute>(k_vec_ptr[offset + 6]);
qk += q_shared[offset + 7] * static_cast<Tcompute>(k_vec_ptr[offset + 7]);
}
qk *= scale;
if (alibi_slope != 0.0f) {
qk += alibi_slope * static_cast<float>(static_cast<int64_t>(token_idx) - seq_len + 1);
}
logits[token_idx] = qk;
}
__syncthreads();
__shared__ Tcompute global_qk_max;
Tcompute global_qk_max_0 = op::common_cuda::reduce_op::max<NUM_THREADS, Tcompute>(logits, seq_len);
if (threadIdx.x == 0) {
global_qk_max = global_qk_max_0;
}
__syncthreads();
//================================================================================
// 3. Compute Softmax (stabilized with the global max)
//================================================================================
for (size_t i = threadIdx.x; i < seq_len; i += NUM_THREADS) {
Tcompute val = expf(logits[i] - global_qk_max); // subtract the global max for numerical stability
logits[i] = val;
}
__syncthreads();
__shared__ Tcompute inv_sum;
Tcompute exp_sum_0 = op::common_cuda::reduce_op::sum<NUM_THREADS, Tcompute, Tcompute>(logits, seq_len);
if (threadIdx.x == 0) {
inv_sum = 1.0f / (exp_sum_0 + 1e-6f);
}
__syncthreads();
for (size_t i = threadIdx.x; i < seq_len; i += NUM_THREADS) {
logits[i] *= inv_sum;
}
__syncthreads();
//================================================================================
// 4. Aggregate Values (V) weighted by probabilities
//================================================================================
for (size_t h_dim = threadIdx.x; h_dim < HEAD_SIZE; h_dim += NUM_THREADS) {
Tcompute acc = 0.0f;
for (size_t token_idx = 0; token_idx < seq_len; ++token_idx) {
const size_t block_idx = token_idx / block_size;
const size_t token_in_block_idx = token_idx % block_size;
const int64_t physical_block_num = block_table[block_idx];
const Tcompute prob = logits[token_idx];
const Tdata *v_vec_ptr = v_cache_
+ physical_block_num * kv_block_stride
+ kv_head_idx * kv_head_stride
+ token_in_block_idx * HEAD_SIZE;
const Tdata v_val = v_vec_ptr[h_dim];
acc += prob * static_cast<Tcompute>(v_val);
}
out_ptr[h_dim] = static_cast<Tdata>(acc);
}
}
} // namespace op::paged_attention::cuda
#endif // __PAGED_ATTENTION_KERNEL_CUH__
#ifndef __PAGED_ATTENTION_INFO_H__
#define __PAGED_ATTENTION_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <iostream>
#include <optional>
#include <vector>
namespace op::paged_attention {
class PagedAttentionInfo {
PagedAttentionInfo() = default;
public:
// --- Data Types and Scale ---
infiniDtype_t dtype;
float scale;
// --- Shape Dimensions ---
size_t num_seqs;
size_t num_heads;
size_t num_kv_heads;
size_t head_size;
size_t block_size;
size_t max_num_blocks_per_seq;
// --- Strides for Memory Layout ---
ptrdiff_t q_stride;
ptrdiff_t kv_block_stride;
ptrdiff_t kv_head_stride;
ptrdiff_t o_stride;
static utils::Result<PagedAttentionInfo> create(
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto dtype = q_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (q_desc->ndim() != 3 || k_cache_desc->ndim() < 4 || v_cache_desc->ndim() < 4 || block_tables_desc->ndim() != 2 || seq_lens_desc->ndim() != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (block_tables_desc->dtype() != INFINI_DTYPE_I64) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (seq_lens_desc->dtype() != INFINI_DTYPE_I64) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
// --- Extract shape dimensions ---
auto q_shape = q_desc->shape();
auto k_cache_shape = k_cache_desc->shape();
size_t num_seqs = q_shape[0];
size_t num_heads = q_shape[1];
size_t head_size = q_shape[2];
if (head_size != 128) {
// Report the unsupported configuration along with the offending value.
std::cerr << "[Error] Paged attention currently only supports head_size = 128, but got "
<< head_size << "." << std::endl;
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
size_t num_kv_heads = k_cache_shape[1];
size_t block_size = v_cache_desc->shape()[2]; // reading block_size from the V cache's third dimension is more reliable
size_t max_num_blocks_per_seq = block_tables_desc->shape()[1];
// Shared memory for the logits is sized by the launcher using
// max_num_blocks_per_seq * block_size as a safe upper bound on seq_len.
// --- Extract strides for memory access ---
ptrdiff_t q_stride = q_desc->stride(0);
ptrdiff_t kv_block_stride = k_cache_desc->stride(0);
ptrdiff_t kv_head_stride = k_cache_desc->stride(1);
ptrdiff_t o_stride = out_desc->stride(0);
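// For contiguous tensors with the shapes documented in the API header:
//   q_stride        = num_heads * head_size                (stride between sequences in q)
//   kv_block_stride = num_kv_heads * block_size * head_size
//   kv_head_stride  = block_size * head_size
//   o_stride        = num_heads * head_size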
return utils::Result<PagedAttentionInfo>(PagedAttentionInfo{
dtype,
scale,
num_seqs,
num_heads,
num_kv_heads,
head_size,
block_size,
max_num_blocks_per_seq,
q_stride,
kv_block_stride,
kv_head_stride,
o_stride});
}
};
} // namespace op::paged_attention
#endif // __PAGED_ATTENTION_INFO_H__
#include <cub/block/block_reduce.cuh>
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
#include "paged_attention_nvidia.cuh"
template <typename Tdata, typename Tcompute, size_t HEAD_SIZE, size_t NUM_THREADS>
INFINIOP_CUDA_KERNEL pagedAttention(
Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
const int64_t *block_tables, const int64_t *seq_lens, const float *alibi_slopes,
const size_t num_kv_heads, const float scale, const size_t max_num_blocks_per_seq,
const size_t block_size,
const ptrdiff_t q_stride,
const ptrdiff_t kv_block_stride,
const ptrdiff_t kv_head_stride,
const ptrdiff_t o_stride) {
op::paged_attention::cuda::pagedAttentionKernel<Tdata, Tcompute, HEAD_SIZE, NUM_THREADS>(
out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, num_kv_heads, scale,
max_num_blocks_per_seq, block_size, q_stride, kv_block_stride, kv_head_stride, o_stride);
}
namespace op::paged_attention::nvidia {
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto info = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, seq_lens_desc, alibi_slopes_desc, scale);
CHECK_RESULT(info);
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <size_t HEAD_SIZE, size_t NUM_THREADS>
infiniStatus_t launchKernel(void *out, const void *q, const void *k_cache, const void *v_cache,
infiniDtype_t dtype,
const void *block_tables, const void *seq_lens, const void *alibi_slopes,
size_t num_heads, size_t num_seqs,
size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t block_size,
ptrdiff_t q_stride, ptrdiff_t kv_block_stride, ptrdiff_t kv_head_stride, ptrdiff_t o_stride,
cudaStream_t stream) {
dim3 grid(uint64_t(num_heads), uint64_t(num_seqs), 1);
dim3 block(NUM_THREADS);
size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(float);
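// Example (assuming HEAD_SIZE = 128, max_num_blocks_per_seq = 64, block_size = 16):
// shared_mem_size = (128 + 64 * 16 + 2) * 4 bytes = 4616 bytes, well under the
// default 48 KB dynamic shared memory limit per block.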
if (dtype == INFINI_DTYPE_F16) {
pagedAttention<half, float, HEAD_SIZE, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(half *)out,
(const half *)q, (const half *)k_cache, (const half *)v_cache,
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
scale, max_num_blocks_per_seq, block_size,
q_stride, kv_block_stride, kv_head_stride, o_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
pagedAttention<__nv_bfloat16, float, HEAD_SIZE, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(__nv_bfloat16 *)out, (const __nv_bfloat16 *)q, (const __nv_bfloat16 *)k_cache, (const __nv_bfloat16 *)v_cache,
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
scale, max_num_blocks_per_seq, block_size,
q_stride, kv_block_stride, kv_head_stride, o_stride);
} else if (dtype == INFINI_DTYPE_F32) {
pagedAttention<float, float, HEAD_SIZE, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(float *)out, (const float *)q, (const float *)k_cache, (const float *)v_cache,
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
scale, max_num_blocks_per_seq, block_size,
q_stride, kv_block_stride, kv_head_stride, o_stride);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables, const void *seq_lens, const void *alibi_slopes,
void *stream_) const {
cudaStream_t stream = (cudaStream_t)stream_;
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
if (_info.head_size == 128) {
launchKernel<128, CUDA_BLOCK_SIZE_1024>(
out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes,
_info.num_heads, _info.num_seqs,
_info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size,
_info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride,
stream);
}
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
if (_info.head_size == 128) {
launchKernel<128, CUDA_BLOCK_SIZE_512>(
out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes,
_info.num_heads, _info.num_seqs,
_info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size,
_info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride,
stream);
}
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
if (_info.head_size == 128) {
launchKernel<128, CUDA_BLOCK_SIZE_4096>(
out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes,
_info.num_heads, _info.num_seqs,
_info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size,
_info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride,
stream);
}
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::paged_attention::nvidia
#ifndef __PAGED_ATTENTION_NVIDIA_H__
#define __PAGED_ATTENTION_NVIDIA_H__
#include "../paged_attention.h"
DESCRIPTOR(nvidia)
#endif // __PAGED_ATTENTION_NVIDIA_H__
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/paged_attention.h"
#ifdef ENABLE_NVIDIA_API
#include "nvidia/paged_attention_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/paged_attention_metax.h"
#endif
__C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
infiniopHandle_t handle,
infiniopPagedAttentionDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t alibi_slopes_desc,
float scale) {
std::optional<infiniopTensorDescriptor_t> alibi_opt;
if (alibi_slopes_desc != nullptr) {
alibi_opt = alibi_slopes_desc;
}
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::paged_attention::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::paged_attention::NAMESPACE::Descriptor **>(desc_ptr), \
out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, seq_lens_desc, alibi_opt, scale);
switch (handle->device) {
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax)
#endif
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
infiniopPagedAttentionDescriptor_t desc,
size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::paged_attention::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax)
#endif
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopPagedAttention(
infiniopPagedAttentionDescriptor_t desc,
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables, const void *seq_lens, const void *alibi_slopes,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<op::paged_attention::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size, out, q, k_cache, v_cache, block_tables, \
seq_lens, alibi_slopes, stream);
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
infiniopPagedAttentionDescriptor_t desc) {
#define DESTROY(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<op::paged_attention::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax)
#endif
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#ifndef PAGED_ATTENTION_H
#define PAGED_ATTENTION_H
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::paged_attention::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
PagedAttentionInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
PagedAttentionInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t out_desc, \
infiniopTensorDescriptor_t q_desc, \
infiniopTensorDescriptor_t k_cache_desc, \
infiniopTensorDescriptor_t v_cache_desc, \
infiniopTensorDescriptor_t block_tables_desc, \
infiniopTensorDescriptor_t seq_lens_desc, \
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc, \
float scale); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *out, const void *q, const void *k_cache, const void *v_cache, \
const void *block_tables, const void *seq_lens, \
const void *alibi_slopes, \
void *stream) const; \
}; \
}
#endif // PAGED_ATTENTION_H
#ifndef __PAGED_CACHING_KERNEL_CUH__
#define __PAGED_CACHING_KERNEL_CUH__
//================================================================================
// Paged Caching Operator CUDA Kernel
//
// This kernel implements the "paged_caching" operation, which copies Key and Value
// vectors from a contiguous source tensor into a paged, non-contiguous KV Cache.
//
// Design Principles:
// 1. Token/Head-Centric Parallelism: a 2D grid of (num_kv_heads, num_tokens) blocks
//    is launched, so each CUDA block caches one head of one token.
// 2. Coalesced Memory Access: threads within a block read a contiguous chunk of the
//    source tensors, maximizing memory bandwidth utilization.
// 3. Simplicity: the copy is performed element-wise (non-vectorized), which keeps it
//    correct for any head_size and data type.
//================================================================================
namespace op::paged_caching::cuda {
template <
typename Tdata, // Data type of the tensors (e.g., half, __nv_bfloat16)
int NUM_THREADS // Number of threads per block, configured at launch time
>
__device__ void pagedCachingKernel(
// ----- Output Tensors -----
Tdata *k_cache_ptr, // Pointer to the destination K cache pool [num_blocks, nkvh, block_size, dh]
Tdata *v_cache_ptr, // Pointer to the destination V cache pool [num_blocks, nkvh, block_size, dh]
// ----- Input Tensors -----
const Tdata *k_ptr, // Pointer to the source Keys, shape [ntok, nkvh, dh]
const Tdata *v_ptr, // Pointer to the source Values, shape [ntok, nkvh, dh]
const int64_t *slot_mapping_ptr, // Pointer to the slot mapping, shape [ntok]
// ----- Metadata -----
const size_t head_size, // Dimension of each head (dh)
const size_t block_size, // Number of tokens per block in the KV cache
// ----- Stride Information -----
const ptrdiff_t k_src_stride, // Stride between tokens in the source K tensor
const ptrdiff_t v_src_stride, // Stride between tokens in the source V tensor
const ptrdiff_t k_cache_block_stride, // Stride between blocks in the K cache pool
const ptrdiff_t v_cache_block_stride // Stride between blocks in the V cache pool
) {
//================================================================================
// 1. Identify Work Unit & Calculate Addresses
//================================================================================
// Each block caches one (kv head, token) pair.
const int token_idx = blockIdx.y;
const int head_idx = blockIdx.x;
// Retrieve the destination slot for the current token.
const int64_t slot_idx = slot_mapping_ptr[token_idx];
// Handle padding: if slot_idx is negative, this token is padding and should be ignored.
if (slot_idx < 0) {
return;
}
// Calculate the physical block index and the offset within that block.
const int64_t physical_block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
// Calculate base pointers for source and destination for this specific token.
const Tdata *k_src_head_ptr = k_ptr + token_idx * k_src_stride + head_idx * head_size;
const Tdata *v_src_head_ptr = v_ptr + token_idx * v_src_stride + head_idx * head_size;
// Destination pointer calculation assumes a [num_blocks, num_kv_heads, block_size, head_size] layout.
// We point to the beginning of the memory region for this token's slot within the given KV head.
const ptrdiff_t cache_head_stride = block_size * head_size;
Tdata *k_cache_block_base_ptr = k_cache_ptr + physical_block_idx * k_cache_block_stride;
Tdata *k_dst_head_ptr = k_cache_block_base_ptr + head_idx * cache_head_stride + block_offset * head_size;
Tdata *v_cache_block_base_ptr = v_cache_ptr + physical_block_idx * v_cache_block_stride;
Tdata *v_dst_head_ptr = v_cache_block_base_ptr + head_idx * cache_head_stride + block_offset * head_size;
//================================================================================
// 2. Perform Element-wise Data Copy (Safe, Non-Vectorized)
//================================================================================
for (int i = threadIdx.x; i < head_size; i += NUM_THREADS) {
k_dst_head_ptr[i] = k_src_head_ptr[i];
v_dst_head_ptr[i] = v_src_head_ptr[i];
}
}
} // namespace op::paged_caching::cuda
#endif // __PAGED_CACHING_KERNEL_CUH__
#ifndef __PAGED_CACHING_INFO_H__
#define __PAGED_CACHING_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <optional>
#include <vector>
namespace op::paged_caching {
class PagedCachingInfo {
PagedCachingInfo() = default;
public:
// --- Data Type ---
infiniDtype_t dtype;
// --- Shape Dimensions ---
size_t num_tokens;
size_t num_kv_heads;
size_t head_size;
size_t block_size;
// --- Strides for Memory Layout ---
ptrdiff_t k_src_stride;
ptrdiff_t v_src_stride;
ptrdiff_t k_cache_block_stride;
ptrdiff_t v_cache_block_stride;
static utils::Result<PagedCachingInfo> create(
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t slot_mapping_desc) {
auto dtype = k_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
if (v_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (slot_mapping_desc->dtype() != INFINI_DTYPE_I64) {
printf("slot_mapping must be int64_t.\n");
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (k_desc->ndim() != 3 || v_desc->ndim() != 3 || k_cache_desc->ndim() < 4 || v_cache_desc->ndim() < 4 || slot_mapping_desc->ndim() != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// --- Extract shape dimensions ---
auto k_shape = k_desc->shape();
auto k_cache_shape = k_cache_desc->shape();
size_t num_tokens = slot_mapping_desc->shape()[0];
size_t num_kv_heads = k_shape[1];
size_t head_size = k_shape[2];
size_t block_size = k_cache_shape[2]; // Assuming [num_blocks, num_heads, block_size, head_size] layout
// --- Extract strides for memory access ---
ptrdiff_t k_src_stride = k_desc->stride(0);
ptrdiff_t v_src_stride = v_desc->stride(0);
ptrdiff_t k_cache_block_stride = k_cache_desc->stride(0);
ptrdiff_t v_cache_block_stride = v_cache_desc->stride(0);
return utils::Result<PagedCachingInfo>(PagedCachingInfo{
dtype,
num_tokens,
num_kv_heads,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride});
}
};
} // namespace op::paged_caching
#endif // __PAGED_CACHING_INFO_H__
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../cuda/kernel.cuh"
#include "paged_caching_nvidia.cuh"
template <typename Tdata, int NUM_THREADS>
INFINIOP_CUDA_KERNEL pagedCaching(
Tdata *k_cache, Tdata *v_cache,
const Tdata *k, const Tdata *v,
const int64_t *slot_mapping,
const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
}
namespace op::paged_caching::nvidia {
// PIMPL struct definition
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
// Destructor implementation
Descriptor::~Descriptor() {
delete _opaque;
}
// Static factory method implementation
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t slot_mapping_desc) {
auto info = PagedCachingInfo::create(k_desc, v_desc, k_cache_desc, v_cache_desc, slot_mapping_desc);
CHECK_RESULT(info);
// Create and return the Descriptor instance.
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// The launchKernel function is a templated helper to encapsulate the CUDA kernel launch.
// It sets up grid/block dimensions and calls the device-side kernel.
template <int NUM_THREADS>
infiniStatus_t launchKernel(const PagedCachingInfo &info,
void *k_cache, void *v_cache,
infiniDtype_t dtype,
const void *k, const void *v,
const void *slot_mapping,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
cudaStream_t stream) {
// 2D grid: one block per (kv_head, token) pair.
dim3 grid(uint64_t(num_kv_heads), uint64_t(num_tokens), 1);
// Block dimension is 1D, using the number of threads specified at compile time.
dim3 block(NUM_THREADS);
// This kernel does not require dynamic shared memory.
size_t shared_mem_size = 0;
// Launch the device-side CUDA kernel.
if (dtype == INFINI_DTYPE_F16) {
pagedCaching<half, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(half *)k_cache,
(half *)v_cache,
(const half *)k,
(const half *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<__nv_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(__nv_bfloat16 *)k_cache,
(__nv_bfloat16 *)v_cache,
(const __nv_bfloat16 *)k,
(const __nv_bfloat16 *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(float *)k_cache,
(float *)v_cache,
(const float *)k,
(const float *)v,
(const int64_t *)slot_mapping,
head_size,
block_size,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
// Execution method implementation
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
const void *k, const void *v,
void *k_cache, void *v_cache,
const void *slot_mapping,
void *stream_) const {
cudaStream_t stream = (cudaStream_t)stream_;
// Dispatch logic based on the GPU's maximum threads per block.
// This allows selecting the largest, most efficient block size the hardware supports.
if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_1024) {
// Dispatch based on data type for a 1024-thread block.
launchKernel<CUDA_BLOCK_SIZE_1024>(
_info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
stream);
} else if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_512) {
launchKernel<CUDA_BLOCK_SIZE_512>(
_info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
stream);
} else if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_4096) {
launchKernel<CUDA_BLOCK_SIZE_4096>(
_info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
stream);
} else {
// If the GPU is older and supports fewer threads, return an error.
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::paged_caching::nvidia
#ifndef __PAGED_CACHING_NVIDIA_H__
#define __PAGED_CACHING_NVIDIA_H__
#include "../paged_caching.h"
DESCRIPTOR(nvidia)
#endif // __PAGED_CACHING_NVIDIA_H__
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/paged_caching.h"
#ifdef ENABLE_NVIDIA_API
#include "nvidia/paged_caching_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/paged_caching_metax.h"
#endif
__C infiniStatus_t infiniopCreatePagedCachingDescriptor(
infiniopHandle_t handle,
infiniopPagedCachingDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t slot_mapping_desc) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::paged_caching::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor **>(desc_ptr), \
k_desc, v_desc, k_cache_desc, v_cache_desc, slot_mapping_desc);
switch (handle->device) {
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax)
#endif
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
infiniopPagedCachingDescriptor_t desc,
size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax)
#endif
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopPagedCaching(
infiniopPagedCachingDescriptor_t desc,
void *workspace, size_t workspace_size,
const void *k, const void *v,
void *k_cache, void *v_cache,
const void *slot_mapping,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size, k, v, k_cache, v_cache, slot_mapping, stream);
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopDestroyPagedCachingDescriptor(
infiniopPagedCachingDescriptor_t desc) {
#define DESTROY(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax)
#endif
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#ifndef PAGED_CACHING_H
#define PAGED_CACHING_H
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::paged_caching::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
PagedCachingInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
PagedCachingInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t k_desc, \
infiniopTensorDescriptor_t v_desc, \
infiniopTensorDescriptor_t k_cache_desc, \
infiniopTensorDescriptor_t v_cache_desc, \
infiniopTensorDescriptor_t slot_mapping_desc); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
const void *k, const void *v, \
void *k_cache, void *v_cache, \
const void *slot_mapping, \
void *stream) const; \
}; \
}
#endif // PAGED_CACHING_H
@@ -977,3 +977,87 @@ def scaled_mm_int8_(lib):
lib.infiniopDestroyI8GemmDescriptor.argtypes = [
infiniopOperatorDescriptor_t,
]
@OpRegister.operator
def paged_attention_(lib):
lib.infiniopCreatePagedAttentionDescriptor.restype = c_int32
lib.infiniopCreatePagedAttentionDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopOperatorDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
c_void_p,
c_float,
]
lib.infiniopGetPagedAttentionWorkspaceSize.restype = c_int32
lib.infiniopGetPagedAttentionWorkspaceSize.argtypes = [
infiniopOperatorDescriptor_t,
POINTER(c_size_t),
]
lib.infiniopPagedAttention.restype = c_int32
lib.infiniopPagedAttention.argtypes = [
infiniopOperatorDescriptor_t,
c_void_p,
c_size_t,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroyPagedAttentionDescriptor.restype = c_int32
lib.infiniopDestroyPagedAttentionDescriptor.argtypes = [
infiniopOperatorDescriptor_t,
]
@OpRegister.operator
def paged_caching_(lib):
lib.infiniopCreatePagedCachingDescriptor.restype = c_int32
lib.infiniopCreatePagedCachingDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopOperatorDescriptor_t),
infiniopTensorDescriptor_t, # k_desc
infiniopTensorDescriptor_t, # v_desc
infiniopTensorDescriptor_t, # k_cache_desc
infiniopTensorDescriptor_t, # v_cache_desc
infiniopTensorDescriptor_t, # slot_mapping_desc
]
# infiniopGetPagedCachingWorkspaceSize
lib.infiniopGetPagedCachingWorkspaceSize.restype = c_int32
lib.infiniopGetPagedCachingWorkspaceSize.argtypes = [
infiniopOperatorDescriptor_t,
POINTER(c_size_t),
]
# infiniopPagedCaching
lib.infiniopPagedCaching.restype = c_int32
lib.infiniopPagedCaching.argtypes = [
infiniopOperatorDescriptor_t,
c_void_p, # workspace
c_size_t, # workspace_size
c_void_p, # k
c_void_p, # v
c_void_p, # k_cache
c_void_p, # v_cache
c_void_p, # slot_mapping
c_void_p, # stream
]
# infiniopDestroyPagedCachingDescriptor
lib.infiniopDestroyPagedCachingDescriptor.restype = c_int32
lib.infiniopDestroyPagedCachingDescriptor.argtypes = [
infiniopOperatorDescriptor_t,
]
import torch
import ctypes
from ctypes import c_uint64
import math
from libinfiniop import (
LIBINFINIOP,
TestTensor,
get_test_devices,
check_error,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
InfiniDtype,
InfiniDtypeNames,
InfiniDeviceNames,
infiniopOperatorDescriptor_t,
TestWorkspace,
)
# ==============================================================================
# Reference Implementation
# ==============================================================================
def get_alibi_slopes(n):
# Simplified ALiBi slope computation.
# Reference: https://github.com/ofirpress/attention_with_linear_biases/blob/master/fairseq/models/transformer.py#L742
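# e.g. get_alibi_slopes(4) -> [1/4, 1/16, 1/64, 1/256], one geometric slope per head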
closest_power_of_2 = 2 ** math.floor(math.log2(n))
base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
powers = [base**i for i in range(1, closest_power_of_2 + 1)]
if n > closest_power_of_2:
extra = [base ** (i * 2) for i in range(1, 2 * (n - closest_power_of_2) + 1, 2)]
powers += extra
return powers[:n]
def ref_masked_attention(query, key, value, scale, attn_mask=None):
# Reference implementation for a single masked attention head.
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
if attn_mask is not None:
attn_weights = attn_weights + attn_mask.float()
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(value.dtype)
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
return out
def ref_single_query_cached_kv_attention(
query, key_cache, value_cache, block_tables, seq_lens, scale, alibi_slopes
):
# Reference implementation for paged attention, iterating through each sequence.
output = torch.empty_like(query)
num_query_heads, num_kv_heads = query.shape[1], value_cache.shape[1]
num_queries_per_kv = num_query_heads // num_kv_heads
head_size, block_size = value_cache.shape[3], value_cache.shape[2]
num_seqs = query.shape[0]
for i in range(num_seqs):
q = query[i].unsqueeze(0)
seq_len = seq_lens[i].item()
block_table = block_tables[i]
keys_lst, values_lst = [], []
for j in range(seq_len):
block_num = block_table[j // block_size].item()
block_off = j % block_size
k = key_cache[block_num, :, block_off, :]
v = value_cache[block_num, :, block_off, :]
keys_lst.append(k)
values_lst.append(v)
keys = torch.stack(keys_lst, dim=0)
values = torch.stack(values_lst, dim=0)
if num_queries_per_kv > 1:
keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)
alibi_bias = None
if alibi_slopes is not None:
pos = torch.arange(seq_len, device=query.device).int()
alibi_bias = (pos - seq_len + 1).float()
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)
out = ref_masked_attention(q, keys, values, scale, alibi_bias)
output[i] = out.view(num_query_heads, head_size)
return output
# ==============================================================================
# Test Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
# (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_seq_len, use_alibi)
(1, 1, 1, 128, 16, 1024, False),
(4, 40, 40, 128, 16, 1024, False),
(6, 40, 40, 128, 16, 1024, False),
(3, 8, 8, 128, 16, 1024, False),
(8, 64, 8, 128, 16, 2048, False),
]
# Data types for testing
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32]
# Tolerance map for different data types
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2},
InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2},
InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5},
}
# Global flags for controlling test behavior
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
def test(
handle,
device,
num_seqs,
num_heads,
num_kv_heads,
head_size,
block_size,
max_seq_len,
use_alibi,
dtype=InfiniDtype.F16,
sync=None,
):
print(
f"Testing PagedAttention on {InfiniDeviceNames[device]} with "
f"num_seqs={num_seqs}, num_heads={num_heads}, head_size={head_size}, "
f"block_size={block_size}, dtype={InfiniDtypeNames[dtype]}, use_alibi={use_alibi}"
)
scale = 1.0 / (head_size**0.5)
max_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
num_blocks = num_seqs * max_blocks_per_seq # A reasonable number for testing
# Create input tensors
q = TestTensor((num_seqs, num_heads, head_size), None, dtype, device)
out = TestTensor((num_seqs, num_heads, head_size), None, dtype, device)
k_cache = TestTensor(
(num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
)
v_cache = TestTensor(
(num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
)
# Sequence lengths are bounded by this case's max_seq_len (block tables are sized accordingly).
seq_lens_torch = torch.randint(
1, max_seq_len + 1, (num_seqs,), dtype=torch.int64
)
seq_lens = TestTensor.from_torch(seq_lens_torch, InfiniDtype.I64, device)
block_tables_py = torch.arange(
0, num_seqs * max_blocks_per_seq, dtype=torch.int64
).view(num_seqs, max_blocks_per_seq)
block_tables = TestTensor.from_torch(block_tables_py, InfiniDtype.I64, device)
alibi_slopes_desc = ctypes.c_void_p(0)
alibi_slopes_data = ctypes.c_void_p(0)
alibi_slopes_torch = None
if use_alibi:
alibi_slopes = TestTensor((num_heads,), None, InfiniDtype.F32, device)
alibi_slopes_desc = alibi_slopes.descriptor
alibi_slopes_data = alibi_slopes.data()
alibi_slopes_torch = alibi_slopes.torch_tensor()
# Run reference implementation
ans = ref_single_query_cached_kv_attention(
q.torch_tensor(),
k_cache.torch_tensor(),
v_cache.torch_tensor(),
block_tables.torch_tensor(),
seq_lens.torch_tensor(),
scale,
alibi_slopes_torch,
)
if sync:
sync()
# Create operator descriptor
descriptor = infiniopOperatorDescriptor_t()
check_error(
LIBINFINIOP.infiniopCreatePagedAttentionDescriptor(
handle,
ctypes.byref(descriptor),
out.descriptor,
q.descriptor,
k_cache.descriptor,
v_cache.descriptor,
block_tables.descriptor,
seq_lens.descriptor,
alibi_slopes_desc,
scale,
)
)
# Get workspace size and allocate memory
workspace_size = c_uint64(0)
check_error(
LIBINFINIOP.infiniopGetPagedAttentionWorkspaceSize(
descriptor, ctypes.byref(workspace_size)
)
)
workspace = TestWorkspace(workspace_size.value, q.device)
# Invalidate descriptors to ensure kernel does not rely on them
q.destroy_desc()
out.destroy_desc()
k_cache.destroy_desc()
v_cache.destroy_desc()
block_tables.destroy_desc()
seq_lens.destroy_desc()
if use_alibi:
alibi_slopes.destroy_desc()
# Define the library call as a lambda for profiling
def lib_paged_attention():
check_error(
LIBINFINIOP.infiniopPagedAttention(
descriptor,
workspace.data(),
workspace_size.value,
out.data(),
q.data(),
k_cache.data(),
v_cache.data(),
block_tables.data(),
seq_lens.data(),
alibi_slopes_data,
None,
)
)
# Execute the custom operator
lib_paged_attention()
if sync:
sync()
# Verify correctness
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug(out.actual_tensor(), ans, atol=atol, rtol=rtol)
assert torch.allclose(out.actual_tensor(), ans, atol=atol, rtol=rtol)
# Profiling workflow
if PROFILE:
# fmt: off
profile_operation("PyTorch", lambda: ref_single_query_cached_kv_attention(
q.torch_tensor(), k_cache.torch_tensor(), v_cache.torch_tensor(),
block_tables.torch_tensor(), seq_lens.torch_tensor(),
scale, alibi_slopes_torch),
device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lib_paged_attention, device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
# Clean up resources
check_error(LIBINFINIOP.infiniopDestroyPagedAttentionDescriptor(descriptor))
if __name__ == "__main__":
args = get_args()
# Configure testing options from command line arguments
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
for device in get_test_devices(args):
test_operator(device, test, _TEST_CASES_, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
LIBINFINIOP,
TestTensor,
get_test_devices,
check_error,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
InfiniDtype,
InfiniDtypeNames,
InfiniDeviceNames,
infiniopOperatorDescriptor_t,
TestWorkspace,
)
# ==============================================================================
# Reference Implementation
# ==============================================================================
def ref_paged_caching(key, value, key_cache_pool, value_cache_pool, slot_mapping):
"""
Reference implementation for paged_caching operator.
Args:
key (torch.Tensor): Keys, shape [ntok, nkvh, dh]
value (torch.Tensor): Values, shape [ntok, nkvh, dh]
key_cache_pool (torch.Tensor): K cache pool, shape [num_blocks, nkvh, block_size, dh]
value_cache_pool (torch.Tensor): V cache pool, shape [num_blocks, nkvh, block_size, dh]
slot_mapping (torch.Tensor): Slot mapping, shape [ntok]
"""
ntok = key.shape[0]
block_size = key_cache_pool.shape[2]
# This reference implementation operates on a cloned cache to avoid modifying the original input tensor,
# mimicking the behavior where the custom operator writes to its output tensor.
k_cache_ref = key_cache_pool.clone()
v_cache_ref = value_cache_pool.clone()
for i in range(ntok):
slot = slot_mapping[i].item()
block_idx = slot // block_size
block_offset = slot % block_size
key_token = key[i]
value_token = value[i]
k_cache_ref[block_idx, :, block_offset, :] = key_token
v_cache_ref[block_idx, :, block_offset, :] = value_token
return k_cache_ref, v_cache_ref
# ==============================================================================
# Test Configuration (Internal Use Only)
# ==============================================================================
_TEST_CASES_ = [
# (num_seqs, max_seq_len, num_kv_heads, head_size, block_size)
(1, 128, 8, 128, 16),
(5, 512, 40, 128, 16),
(16, 1024, 8, 64, 32),
(10, 1024, 40, 64, 32),
]
# Data types for testing
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32]
# Tolerance map for different data types
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2},
InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2},
InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5},
}
# Global flags for controlling test behavior
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 100
def test(
handle,
device,
num_seqs, # nreq
max_seq_len,
num_kv_heads, # nkvh
head_size, # dh
block_size,
dtype=InfiniDtype.F16,
sync=None,
):
print(
f"Testing PagedCaching on {InfiniDeviceNames[device]} with "
f"num_seqs={num_seqs}, max_seq_len={max_seq_len}, num_kv_heads={num_kv_heads}, "
f"head_size={head_size}, block_size={block_size}, dtype={InfiniDtypeNames[dtype]}"
)
num_blocks = 4096 # A reasonably large cache pool for testing
# Create metadata: variable context lengths for each sequence in the batch
context_lens_torch = torch.randint(
1, max_seq_len + 1, (num_seqs,), dtype=torch.int64
)
ntok = torch.sum(context_lens_torch).item()
# If ntok is 0 (all sequences have length 0), skip the test
if ntok == 0:
print("Skipping test case with ntok=0")
return
# Simulate the scheduler's behavior to create the slot_mapping
slot_mapping_list = []
current_slot = 0
for length in context_lens_torch:
# Find a contiguous chunk of 'length' slots
start_slot = current_slot
slot_mapping_list.extend(range(start_slot, start_slot + length.item()))
current_slot += length.item()
# Ensure we don't exceed the total number of slots in the cache
assert (
current_slot <= num_blocks * block_size
), "Not enough blocks in the cache pool for this test case"
slot_mapping_torch = torch.tensor(slot_mapping_list, dtype=torch.int64)
# Create input tensors based on the calculated total tokens (ntok)
k = TestTensor((ntok, num_kv_heads, head_size), None, dtype, device)
v = TestTensor((ntok, num_kv_heads, head_size), None, dtype, device)
slot_mapping = TestTensor.from_torch(slot_mapping_torch, InfiniDtype.I64, device)
# The cache pools are the "output" tensors for this operator
k_cache_pool = TestTensor(
(num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
)
v_cache_pool = TestTensor(
(num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
)
# Run reference implementation
k_cache_ref, v_cache_ref = ref_paged_caching(
k.torch_tensor(),
v.torch_tensor(),
k_cache_pool.torch_tensor(),
v_cache_pool.torch_tensor(),
slot_mapping.torch_tensor(),
)
if sync:
sync()
# Create operator descriptor
descriptor = infiniopOperatorDescriptor_t()
check_error(
LIBINFINIOP.infiniopCreatePagedCachingDescriptor(
handle,
ctypes.byref(descriptor),
k.descriptor,
v.descriptor,
k_cache_pool.descriptor,
v_cache_pool.descriptor,
slot_mapping.descriptor,
)
)
# Get workspace size (likely 0 for this operator, but good practice to include)
workspace_size = c_uint64(0)
check_error(
LIBINFINIOP.infiniopGetPagedCachingWorkspaceSize(
descriptor, ctypes.byref(workspace_size)
)
)
workspace = TestWorkspace(workspace_size.value, device)
# Invalidate descriptors to ensure kernel does not rely on them
k.destroy_desc()
v.destroy_desc()
k_cache_pool.destroy_desc()
v_cache_pool.destroy_desc()
slot_mapping.destroy_desc()
# Define the library call as a lambda for profiling
def lib_paged_caching():
check_error(
LIBINFINIOP.infiniopPagedCaching(
descriptor,
workspace.data(),
workspace_size.value,
k.data(),
v.data(),
k_cache_pool.data(),
v_cache_pool.data(),
slot_mapping.data(),
None,
)
)
# Execute the custom operator
lib_paged_caching()
if sync:
sync()
# Verify correctness
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
print("Verifying K cache...")
debug(k_cache_pool.actual_tensor(), k_cache_ref, atol=atol, rtol=rtol)
print("Verifying V cache...")
debug(v_cache_pool.actual_tensor(), v_cache_ref, atol=atol, rtol=rtol)
assert torch.allclose(
k_cache_pool.actual_tensor(), k_cache_ref, atol=atol, rtol=rtol
)
assert torch.allclose(
v_cache_pool.actual_tensor(), v_cache_ref, atol=atol, rtol=rtol
)
# Profiling workflow
if PROFILE:
# fmt: off
profile_operation("PyTorch", lambda: ref_paged_caching(
k.torch_tensor(), v.torch_tensor(),
k_cache_pool.torch_tensor(), v_cache_pool.torch_tensor(),
slot_mapping.torch_tensor()),
device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lib_paged_caching, device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
# Clean up resources
check_error(LIBINFINIOP.infiniopDestroyPagedCachingDescriptor(descriptor))
if __name__ == "__main__":
args = get_args()
# Configure testing options from command line arguments
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
for device in get_test_devices(args):
test_operator(device, test, _TEST_CASES_, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")