Commit 67425576 authored by PanZezhong's avatar PanZezhong
Browse files

issue/1036 paged caching support strides

parent d8176086
...@@ -38,7 +38,11 @@ __device__ void pagedCachingKernel( ...@@ -38,7 +38,11 @@ __device__ void pagedCachingKernel(
const ptrdiff_t k_src_stride, // Stride between tokens in the source K tensor const ptrdiff_t k_src_stride, // Stride between tokens in the source K tensor
const ptrdiff_t v_src_stride, // Stride between tokens in the source V tensor const ptrdiff_t v_src_stride, // Stride between tokens in the source V tensor
const ptrdiff_t k_cache_block_stride, // Stride between blocks in the K cache pool const ptrdiff_t k_cache_block_stride, // Stride between blocks in the K cache pool
const ptrdiff_t v_cache_block_stride // Stride between blocks in the V cache pool const ptrdiff_t v_cache_block_stride, // Stride between blocks in the V cache pool
const ptrdiff_t k_cache_head_stride, // Stride between heads in the K cache pool
const ptrdiff_t v_cache_head_stride, // Stride between heads in the V cache pool
const ptrdiff_t k_cache_slot_stride, // Stride between block slots in the K cache pool
const ptrdiff_t v_cache_slot_stride // Stride between block slots in the V cache pool
) { ) {
//================================================================================ //================================================================================
// 1. Identify Work Unit & Calculate Addresses // 1. Identify Work Unit & Calculate Addresses
...@@ -66,13 +70,11 @@ __device__ void pagedCachingKernel( ...@@ -66,13 +70,11 @@ __device__ void pagedCachingKernel(
// Destination pointer calculation assumes a [num_blocks, block_size, num_heads, head_size] layout. // Destination pointer calculation assumes a [num_blocks, block_size, num_heads, head_size] layout.
// We point to the beginning of the memory region for this token's slot. // We point to the beginning of the memory region for this token's slot.
const ptrdiff_t cache_head_stride = block_size * head_size;
Tdata *k_cache_block_base_ptr = k_cache_ptr + physical_block_idx * k_cache_block_stride; Tdata *k_cache_block_base_ptr = k_cache_ptr + physical_block_idx * k_cache_block_stride;
Tdata *k_dst_head_ptr = k_cache_block_base_ptr + head_idx * cache_head_stride + block_offset * head_size; Tdata *k_dst_head_ptr = k_cache_block_base_ptr + head_idx * k_cache_head_stride + block_offset * k_cache_slot_stride;
Tdata *v_cache_block_base_ptr = v_cache_ptr + physical_block_idx * v_cache_block_stride; Tdata *v_cache_block_base_ptr = v_cache_ptr + physical_block_idx * v_cache_block_stride;
Tdata *v_dst_head_ptr = v_cache_block_base_ptr + head_idx * cache_head_stride + block_offset * head_size; Tdata *v_dst_head_ptr = v_cache_block_base_ptr + head_idx * v_cache_head_stride + block_offset * v_cache_slot_stride;
//================================================================================ //================================================================================
// 2. Perform Element-wise Data Copy (Safe, Non-Vectorized) // 2. Perform Element-wise Data Copy (Safe, Non-Vectorized)
......
...@@ -26,6 +26,10 @@ public: ...@@ -26,6 +26,10 @@ public:
ptrdiff_t v_src_stride; ptrdiff_t v_src_stride;
ptrdiff_t k_cache_block_stride; ptrdiff_t k_cache_block_stride;
ptrdiff_t v_cache_block_stride; ptrdiff_t v_cache_block_stride;
ptrdiff_t k_cache_head_stride;
ptrdiff_t v_cache_head_stride;
ptrdiff_t k_cache_slot_stride;
ptrdiff_t v_cache_slot_stride;
static utils::Result<PagedCachingInfo> create( static utils::Result<PagedCachingInfo> create(
infiniopTensorDescriptor_t k_cache_desc, infiniopTensorDescriptor_t k_cache_desc,
...@@ -63,6 +67,10 @@ public: ...@@ -63,6 +67,10 @@ public:
ptrdiff_t v_src_stride = v_desc->stride(0); ptrdiff_t v_src_stride = v_desc->stride(0);
ptrdiff_t k_cache_block_stride = k_cache_desc->stride(0); ptrdiff_t k_cache_block_stride = k_cache_desc->stride(0);
ptrdiff_t v_cache_block_stride = v_cache_desc->stride(0); ptrdiff_t v_cache_block_stride = v_cache_desc->stride(0);
ptrdiff_t k_cache_head_stride = k_cache_desc->stride(1);
ptrdiff_t v_cache_head_stride = v_cache_desc->stride(1);
ptrdiff_t k_cache_slot_stride = k_cache_desc->stride(2);
ptrdiff_t v_cache_slot_stride = v_cache_desc->stride(2);
return utils::Result<PagedCachingInfo>(PagedCachingInfo{ return utils::Result<PagedCachingInfo>(PagedCachingInfo{
dtype, dtype,
...@@ -73,7 +81,11 @@ public: ...@@ -73,7 +81,11 @@ public:
k_src_stride, k_src_stride,
v_src_stride, v_src_stride,
k_cache_block_stride, k_cache_block_stride,
v_cache_block_stride}); v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride});
} }
}; };
......
...@@ -10,10 +10,13 @@ INFINIOP_METAX_KERNEL pagedCaching( ...@@ -10,10 +10,13 @@ INFINIOP_METAX_KERNEL pagedCaching(
const int64_t *slot_mapping, const int64_t *slot_mapping,
const size_t head_size, const size_t block_size, const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride, const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) { const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride,
const ptrdiff_t k_cache_head_stride, const ptrdiff_t v_cache_head_stride,
const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_strid) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>( op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size, k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride); block_size, k_src_stride, v_src_stride,
k_cache_block_stride, v_cache_block_stride, k_cache_head_stride, v_cache_head_stride, k_cache_slot_stride, v_cache_slot_stride);
} }
namespace op::paged_caching::metax { namespace op::paged_caching::metax {
...@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size, size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride, ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride, ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
ptrdiff_t k_cache_head_stride, ptrdiff_t v_cache_head_stride,
ptrdiff_t k_cache_slot_stride, ptrdiff_t v_cache_slot_stride,
hcStream_t stream) { hcStream_t stream) {
// Grid dimension is 1D, with one block per token, as we decided. // Grid dimension is 1D, with one block per token, as we decided.
...@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride, k_src_stride,
v_src_stride, v_src_stride,
k_cache_block_stride, k_cache_block_stride,
v_cache_block_stride); v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_BF16) { } else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<cuda_bfloat16, NUM_THREADS> pagedCaching<cuda_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>( <<<grid, block, shared_mem_size, stream>>>(
...@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride, k_src_stride,
v_src_stride, v_src_stride,
k_cache_block_stride, k_cache_block_stride,
v_cache_block_stride); v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_F32) { } else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS> pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>( <<<grid, block, shared_mem_size, stream>>>(
...@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride, k_src_stride,
v_src_stride, v_src_stride,
k_cache_block_stride, k_cache_block_stride,
v_cache_block_stride); v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else { } else {
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
...@@ -138,6 +155,8 @@ infiniStatus_t Descriptor::calculate( ...@@ -138,6 +155,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size, _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride, _info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride, _info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream); stream);
} else if (max_threads >= METAX_BLOCK_SIZE_512) { } else if (max_threads >= METAX_BLOCK_SIZE_512) {
launchKernel<METAX_BLOCK_SIZE_512>( launchKernel<METAX_BLOCK_SIZE_512>(
...@@ -145,6 +164,8 @@ infiniStatus_t Descriptor::calculate( ...@@ -145,6 +164,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size, _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride, _info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride, _info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream); stream);
} else { } else {
// If the device supports fewer threads, return an error. // If the device supports fewer threads, return an error.
......
...@@ -10,10 +10,13 @@ INFINIOP_MOORE_KERNEL pagedCaching( ...@@ -10,10 +10,13 @@ INFINIOP_MOORE_KERNEL pagedCaching(
const int64_t *slot_mapping, const int64_t *slot_mapping,
const size_t head_size, const size_t block_size, const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride, const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) { const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride,
const ptrdiff_t k_cache_head_stride, const ptrdiff_t v_cache_head_stride,
const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>( op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size, k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride); block_size, k_src_stride, v_src_stride,
k_cache_block_stride, v_cache_block_stride, k_cache_head_stride, v_cache_head_stride, k_cache_slot_stride, v_cache_slot_stride);
} }
namespace op::paged_caching::moore { namespace op::paged_caching::moore {
...@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size, size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride, ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride, ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
ptrdiff_t k_cache_head_stride, ptrdiff_t v_cache_head_stride,
ptrdiff_t k_cache_slot_stride, ptrdiff_t v_cache_slot_stride,
musaStream_t stream) { musaStream_t stream) {
// Grid dimension is 1D, with one block per token, as we decided. // Grid dimension is 1D, with one block per token, as we decided.
...@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride, k_src_stride,
v_src_stride, v_src_stride,
k_cache_block_stride, k_cache_block_stride,
v_cache_block_stride); v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_BF16) { } else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<__mt_bfloat16, NUM_THREADS> pagedCaching<__mt_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>( <<<grid, block, shared_mem_size, stream>>>(
...@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride, k_src_stride,
v_src_stride, v_src_stride,
k_cache_block_stride, k_cache_block_stride,
v_cache_block_stride); v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_F32) { } else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS> pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>( <<<grid, block, shared_mem_size, stream>>>(
...@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride, k_src_stride,
v_src_stride, v_src_stride,
k_cache_block_stride, k_cache_block_stride,
v_cache_block_stride); v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else { } else {
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
...@@ -137,6 +154,8 @@ infiniStatus_t Descriptor::calculate( ...@@ -137,6 +154,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size, _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride, _info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride, _info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream); stream);
} else if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_512) { } else if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_512) {
launchKernel<MOORE_BLOCK_SIZE_512>( launchKernel<MOORE_BLOCK_SIZE_512>(
...@@ -144,6 +163,8 @@ infiniStatus_t Descriptor::calculate( ...@@ -144,6 +163,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size, _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride, _info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride, _info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream); stream);
} else { } else {
// If the GPU is older and supports fewer threads, return an error. // If the GPU is older and supports fewer threads, return an error.
......
...@@ -10,10 +10,13 @@ INFINIOP_CUDA_KERNEL pagedCaching( ...@@ -10,10 +10,13 @@ INFINIOP_CUDA_KERNEL pagedCaching(
const int64_t *slot_mapping, const int64_t *slot_mapping,
const size_t head_size, const size_t block_size, const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride, const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) { const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride,
const ptrdiff_t k_cache_head_stride, const ptrdiff_t v_cache_head_stride,
const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>( op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size, k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride); block_size, k_src_stride, v_src_stride,
k_cache_block_stride, v_cache_block_stride, k_cache_head_stride, v_cache_head_stride, k_cache_slot_stride, v_cache_slot_stride);
} }
namespace op::paged_caching::nvidia { namespace op::paged_caching::nvidia {
...@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size, size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride, ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride, ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
ptrdiff_t k_cache_head_stride, ptrdiff_t v_cache_head_stride,
ptrdiff_t k_cache_slot_stride, ptrdiff_t v_cache_slot_stride,
cudaStream_t stream) { cudaStream_t stream) {
// Grid dimension is 1D, with one block per token, as we decided. // Grid dimension is 1D, with one block per token, as we decided.
...@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride, k_src_stride,
v_src_stride, v_src_stride,
k_cache_block_stride, k_cache_block_stride,
v_cache_block_stride); v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_BF16) { } else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<__nv_bfloat16, NUM_THREADS> pagedCaching<__nv_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>( <<<grid, block, shared_mem_size, stream>>>(
...@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride, k_src_stride,
v_src_stride, v_src_stride,
k_cache_block_stride, k_cache_block_stride,
v_cache_block_stride); v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_F32) { } else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS> pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>( <<<grid, block, shared_mem_size, stream>>>(
...@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride, k_src_stride,
v_src_stride, v_src_stride,
k_cache_block_stride, k_cache_block_stride,
v_cache_block_stride); v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else { } else {
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
...@@ -137,6 +154,8 @@ infiniStatus_t Descriptor::calculate( ...@@ -137,6 +154,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size, _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride, _info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride, _info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream); stream);
} else if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_512) { } else if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_512) {
launchKernel<CUDA_BLOCK_SIZE_512>( launchKernel<CUDA_BLOCK_SIZE_512>(
...@@ -144,6 +163,8 @@ infiniStatus_t Descriptor::calculate( ...@@ -144,6 +163,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size, _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride, _info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride, _info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream); stream);
} else if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_4096) { } else if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_4096) {
launchKernel<CUDA_BLOCK_SIZE_4096>( launchKernel<CUDA_BLOCK_SIZE_4096>(
...@@ -151,6 +172,8 @@ infiniStatus_t Descriptor::calculate( ...@@ -151,6 +172,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size, _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride, _info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride, _info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream); stream);
} else { } else {
// If the GPU is older and supports fewer threads, return an error. // If the GPU is older and supports fewer threads, return an error.
......
...@@ -18,12 +18,16 @@ from framework import ( ...@@ -18,12 +18,16 @@ from framework import (
# Operator-specific configuration # Operator-specific configuration
# ============================================================================== # ==============================================================================
# Test cases format: (num_seqs, max_seq_len, num_kv_heads, head_size, block_size) # Test cases format: (num_seqs, max_seq_len, num_kv_heads, head_size, block_size, permute_dim_1_2)
_TEST_CASES_DATA = [ _TEST_CASES_DATA = [
(1, 128, 8, 128, 16), (1, 128, 8, 128, 16, False),
(5, 512, 40, 128, 16), (1, 128, 8, 128, 16, True),
(16, 1024, 8, 64, 32), (5, 512, 40, 256, 16, False),
(10, 1024, 40, 64, 32), (5, 512, 40, 256, 16, True),
(16, 1024, 8, 64, 32, False),
(16, 1024, 8, 64, 32, True),
(10, 1024, 40, 64, 32, False),
(10, 1024, 40, 64, 32, True),
] ]
# Tolerance configuration # Tolerance configuration
...@@ -40,7 +44,9 @@ _TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32] ...@@ -40,7 +44,9 @@ _TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# ============================================================================== # ==============================================================================
# Reference Implementation # Reference Implementation
# ============================================================================== # ==============================================================================
def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping): def ref_paged_caching(
key_cache_pool, value_cache_pool, key, value, slot_mapping, permute_dim_1_2
):
""" """
Reference implementation for paged_caching operator. Reference implementation for paged_caching operator.
...@@ -52,7 +58,7 @@ def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping ...@@ -52,7 +58,7 @@ def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping
slot_mapping (torch.Tensor): Slot mapping, shape [ntok] slot_mapping (torch.Tensor): Slot mapping, shape [ntok]
""" """
ntok = key.shape[0] ntok = key.shape[0]
block_size = key_cache_pool.shape[2] block_size = key_cache_pool.shape[1] if permute_dim_1_2 else key_cache_pool.shape[2]
# This reference implementation operates on a cloned cache to avoid modifying the original input tensor, # This reference implementation operates on a cloned cache to avoid modifying the original input tensor,
# mimicking the behavior where the custom operator writes to its output tensor. # mimicking the behavior where the custom operator writes to its output tensor.
...@@ -67,8 +73,12 @@ def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping ...@@ -67,8 +73,12 @@ def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping
key_token = key[i] key_token = key[i]
value_token = value[i] value_token = value[i]
k_cache_ref[block_idx, :, block_offset, :] = key_token if permute_dim_1_2:
v_cache_ref[block_idx, :, block_offset, :] = value_token k_cache_ref[block_idx, block_offset, :, :] = key_token
v_cache_ref[block_idx, block_offset, :, :] = value_token
else:
k_cache_ref[block_idx, :, block_offset, :] = key_token
v_cache_ref[block_idx, :, block_offset, :] = value_token
return k_cache_ref, v_cache_ref return k_cache_ref, v_cache_ref
...@@ -79,7 +89,14 @@ def parse_test_cases(): ...@@ -79,7 +89,14 @@ def parse_test_cases():
Each test case contains all necessary information for execution and validation. Each test case contains all necessary information for execution and validation.
""" """
test_cases = [] test_cases = []
for num_seqs, max_seq_len, num_kv_heads, head_size, block_size in _TEST_CASES_DATA: for (
num_seqs,
max_seq_len,
num_kv_heads,
head_size,
block_size,
permute_dim_1_2,
) in _TEST_CASES_DATA:
num_blocks = 4096 # A reasonably large cache pool for testing num_blocks = 4096 # A reasonably large cache pool for testing
# Create metadata: variable context lengths for each sequence in the batch # Create metadata: variable context lengths for each sequence in the batch
...@@ -111,6 +128,9 @@ def parse_test_cases(): ...@@ -111,6 +128,9 @@ def parse_test_cases():
v_shape = (ntok, num_kv_heads, head_size) v_shape = (ntok, num_kv_heads, head_size)
k_cache_shape = (num_blocks, num_kv_heads, block_size, head_size) k_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
v_cache_shape = (num_blocks, num_kv_heads, block_size, head_size) v_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
if permute_dim_1_2:
k_cache_shape = (num_blocks, block_size, num_kv_heads, head_size)
v_cache_shape = (num_blocks, block_size, num_kv_heads, head_size)
# Generate test cases for all data types # Generate test cases for all data types
for dtype in _TENSOR_DTYPES: for dtype in _TENSOR_DTYPES:
...@@ -142,7 +162,7 @@ def parse_test_cases(): ...@@ -142,7 +162,7 @@ def parse_test_cases():
v_spec, v_spec,
slot_mapping_spec, slot_mapping_spec,
], ],
kwargs=None, kwargs={"permute_dim_1_2": permute_dim_1_2},
output_spec=None, output_spec=None,
comparison_target=0, # Only compare k_cache comparison_target=0, # Only compare k_cache
tolerance=tolerance, tolerance=tolerance,
...@@ -162,13 +182,22 @@ class OpTest(BaseOperatorTest): ...@@ -162,13 +182,22 @@ class OpTest(BaseOperatorTest):
def get_test_cases(self): def get_test_cases(self):
return parse_test_cases() return parse_test_cases()
def torch_operator(self, *args, **kwargs): def torch_operator(
self, k_cache, v_cache, key, value, slot_mapping, permute_dim_1_2=False
):
"""PyTorch paged_caching implementation""" """PyTorch paged_caching implementation"""
return ref_paged_caching(*args, **kwargs) return ref_paged_caching(
k_cache, v_cache, key, value, slot_mapping, permute_dim_1_2
)
def infinicore_operator(self, *args, **kwargs): def infinicore_operator(
self, k_cache, v_cache, key, value, slot_mapping, permute_dim_1_2=False
):
"""InfiniCore paged_caching implementation""" """InfiniCore paged_caching implementation"""
return infinicore.paged_caching(*args, **kwargs) if permute_dim_1_2:
k_cache = k_cache.permute([0, 2, 1, 3])
v_cache = v_cache.permute([0, 2, 1, 3])
return infinicore.paged_caching(k_cache, v_cache, key, value, slot_mapping)
def main(): def main():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment