Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
67425576
Commit
67425576
authored
Feb 28, 2026
by
PanZezhong
Browse files
issue/1036 paged caching support strides
parent
d8176086
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
144 additions
and
36 deletions
+144
-36
src/infiniop/ops/paged_caching/cuda/kernel.cuh
src/infiniop/ops/paged_caching/cuda/kernel.cuh
+7
-5
src/infiniop/ops/paged_caching/info.h
src/infiniop/ops/paged_caching/info.h
+13
-1
src/infiniop/ops/paged_caching/metax/paged_caching_metax.maca
...infiniop/ops/paged_caching/metax/paged_caching_metax.maca
+26
-5
src/infiniop/ops/paged_caching/moore/paged_caching_moore.mu
src/infiniop/ops/paged_caching/moore/paged_caching_moore.mu
+26
-5
src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu
...infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu
+28
-5
test/infinicore/ops/paged_caching.py
test/infinicore/ops/paged_caching.py
+44
-15
No files found.
src/infiniop/ops/paged_caching/cuda/kernel.cuh
View file @
67425576
...
@@ -38,7 +38,11 @@ __device__ void pagedCachingKernel(
...
@@ -38,7 +38,11 @@ __device__ void pagedCachingKernel(
const
ptrdiff_t
k_src_stride
,
// Stride between tokens in the source K tensor
const
ptrdiff_t
k_src_stride
,
// Stride between tokens in the source K tensor
const
ptrdiff_t
v_src_stride
,
// Stride between tokens in the source V tensor
const
ptrdiff_t
v_src_stride
,
// Stride between tokens in the source V tensor
const
ptrdiff_t
k_cache_block_stride
,
// Stride between blocks in the K cache pool
const
ptrdiff_t
k_cache_block_stride
,
// Stride between blocks in the K cache pool
const
ptrdiff_t
v_cache_block_stride
// Stride between blocks in the V cache pool
const
ptrdiff_t
v_cache_block_stride
,
// Stride between blocks in the V cache pool
const
ptrdiff_t
k_cache_head_stride
,
// Stride between heads in the K cache pool
const
ptrdiff_t
v_cache_head_stride
,
// Stride between heads in the V cache pool
const
ptrdiff_t
k_cache_slot_stride
,
// Stride between block slots in the K cache pool
const
ptrdiff_t
v_cache_slot_stride
// Stride between block slots in the V cache pool
)
{
)
{
//================================================================================
//================================================================================
// 1. Identify Work Unit & Calculate Addresses
// 1. Identify Work Unit & Calculate Addresses
...
@@ -66,13 +70,11 @@ __device__ void pagedCachingKernel(
...
@@ -66,13 +70,11 @@ __device__ void pagedCachingKernel(
// Destination pointer calculation assumes a logical [num_blocks, num_heads, block_size, head_size] layout.
// Destination pointer calculation assumes a logical [num_blocks, num_heads, block_size, head_size] layout.
// We point to the beginning of the memory region for this token's slot.
// We point to the beginning of the memory region for this token's slot.
const
ptrdiff_t
cache_head_stride
=
block_size
*
head_size
;
Tdata
*
k_cache_block_base_ptr
=
k_cache_ptr
+
physical_block_idx
*
k_cache_block_stride
;
Tdata
*
k_cache_block_base_ptr
=
k_cache_ptr
+
physical_block_idx
*
k_cache_block_stride
;
Tdata
*
k_dst_head_ptr
=
k_cache_block_base_ptr
+
head_idx
*
cache_head_stride
+
block_offset
*
head_siz
e
;
Tdata
*
k_dst_head_ptr
=
k_cache_block_base_ptr
+
head_idx
*
k_
cache_head_stride
+
block_offset
*
k_cache_slot_strid
e
;
Tdata
*
v_cache_block_base_ptr
=
v_cache_ptr
+
physical_block_idx
*
v_cache_block_stride
;
Tdata
*
v_cache_block_base_ptr
=
v_cache_ptr
+
physical_block_idx
*
v_cache_block_stride
;
Tdata
*
v_dst_head_ptr
=
v_cache_block_base_ptr
+
head_idx
*
cache_head_stride
+
block_offset
*
head_siz
e
;
Tdata
*
v_dst_head_ptr
=
v_cache_block_base_ptr
+
head_idx
*
v_
cache_head_stride
+
block_offset
*
v_cache_slot_strid
e
;
//================================================================================
//================================================================================
// 2. Perform Element-wise Data Copy (Safe, Non-Vectorized)
// 2. Perform Element-wise Data Copy (Safe, Non-Vectorized)
...
...
src/infiniop/ops/paged_caching/info.h
View file @
67425576
...
@@ -26,6 +26,10 @@ public:
...
@@ -26,6 +26,10 @@ public:
ptrdiff_t
v_src_stride
;
ptrdiff_t
v_src_stride
;
ptrdiff_t
k_cache_block_stride
;
ptrdiff_t
k_cache_block_stride
;
ptrdiff_t
v_cache_block_stride
;
ptrdiff_t
v_cache_block_stride
;
ptrdiff_t
k_cache_head_stride
;
ptrdiff_t
v_cache_head_stride
;
ptrdiff_t
k_cache_slot_stride
;
ptrdiff_t
v_cache_slot_stride
;
static
utils
::
Result
<
PagedCachingInfo
>
create
(
static
utils
::
Result
<
PagedCachingInfo
>
create
(
infiniopTensorDescriptor_t
k_cache_desc
,
infiniopTensorDescriptor_t
k_cache_desc
,
...
@@ -63,6 +67,10 @@ public:
...
@@ -63,6 +67,10 @@ public:
ptrdiff_t
v_src_stride
=
v_desc
->
stride
(
0
);
ptrdiff_t
v_src_stride
=
v_desc
->
stride
(
0
);
ptrdiff_t
k_cache_block_stride
=
k_cache_desc
->
stride
(
0
);
ptrdiff_t
k_cache_block_stride
=
k_cache_desc
->
stride
(
0
);
ptrdiff_t
v_cache_block_stride
=
v_cache_desc
->
stride
(
0
);
ptrdiff_t
v_cache_block_stride
=
v_cache_desc
->
stride
(
0
);
ptrdiff_t
k_cache_head_stride
=
k_cache_desc
->
stride
(
1
);
ptrdiff_t
v_cache_head_stride
=
v_cache_desc
->
stride
(
1
);
ptrdiff_t
k_cache_slot_stride
=
k_cache_desc
->
stride
(
2
);
ptrdiff_t
v_cache_slot_stride
=
v_cache_desc
->
stride
(
2
);
return
utils
::
Result
<
PagedCachingInfo
>
(
PagedCachingInfo
{
return
utils
::
Result
<
PagedCachingInfo
>
(
PagedCachingInfo
{
dtype
,
dtype
,
...
@@ -73,7 +81,11 @@ public:
...
@@ -73,7 +81,11 @@ public:
k_src_stride
,
k_src_stride
,
v_src_stride
,
v_src_stride
,
k_cache_block_stride
,
k_cache_block_stride
,
v_cache_block_stride
});
v_cache_block_stride
,
k_cache_head_stride
,
v_cache_head_stride
,
k_cache_slot_stride
,
v_cache_slot_stride
});
}
}
};
};
...
...
src/infiniop/ops/paged_caching/metax/paged_caching_metax.maca
View file @
67425576
...
@@ -10,10 +10,13 @@ INFINIOP_METAX_KERNEL pagedCaching(
...
@@ -10,10 +10,13 @@ INFINIOP_METAX_KERNEL pagedCaching(
const int64_t *slot_mapping,
const int64_t *slot_mapping,
const size_t head_size, const size_t block_size,
const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride,
const ptrdiff_t k_cache_head_stride, const ptrdiff_t v_cache_head_stride,
const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size,
k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
block_size, k_src_stride, v_src_stride,
k_cache_block_stride, v_cache_block_stride, k_cache_head_stride, v_cache_head_stride, k_cache_slot_stride, v_cache_slot_stride);
}
}
namespace op::paged_caching::metax {
namespace op::paged_caching::metax {
...
@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
...
@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
ptrdiff_t k_cache_head_stride, ptrdiff_t v_cache_head_stride,
ptrdiff_t k_cache_slot_stride, ptrdiff_t v_cache_slot_stride,
hcStream_t stream) {
hcStream_t stream) {
// Grid dimension is 1D, with one block per token, as we decided.
// Grid dimension is 1D, with one block per token, as we decided.
...
@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
...
@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
k_src_stride,
v_src_stride,
v_src_stride,
k_cache_block_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
} else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<cuda_bfloat16, NUM_THREADS>
pagedCaching<cuda_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
<<<grid, block, shared_mem_size, stream>>>(
...
@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
...
@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
k_src_stride,
v_src_stride,
v_src_stride,
k_cache_block_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_F32) {
} else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS>
pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
<<<grid, block, shared_mem_size, stream>>>(
...
@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
...
@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
k_src_stride,
v_src_stride,
v_src_stride,
k_cache_block_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else {
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
...
@@ -138,6 +155,8 @@ infiniStatus_t Descriptor::calculate(
...
@@ -138,6 +155,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream);
stream);
} else if (max_threads >= METAX_BLOCK_SIZE_512) {
} else if (max_threads >= METAX_BLOCK_SIZE_512) {
launchKernel<METAX_BLOCK_SIZE_512>(
launchKernel<METAX_BLOCK_SIZE_512>(
...
@@ -145,6 +164,8 @@ infiniStatus_t Descriptor::calculate(
...
@@ -145,6 +164,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream);
stream);
} else {
} else {
// If the device supports fewer threads, return an error.
// If the device supports fewer threads, return an error.
...
...
src/infiniop/ops/paged_caching/moore/paged_caching_moore.mu
View file @
67425576
...
@@ -10,10 +10,13 @@ INFINIOP_MOORE_KERNEL pagedCaching(
...
@@ -10,10 +10,13 @@ INFINIOP_MOORE_KERNEL pagedCaching(
const int64_t *slot_mapping,
const int64_t *slot_mapping,
const size_t head_size, const size_t block_size,
const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride,
const ptrdiff_t k_cache_head_stride, const ptrdiff_t v_cache_head_stride,
const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size,
k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
block_size, k_src_stride, v_src_stride,
k_cache_block_stride, v_cache_block_stride, k_cache_head_stride, v_cache_head_stride, k_cache_slot_stride, v_cache_slot_stride);
}
}
namespace op::paged_caching::moore {
namespace op::paged_caching::moore {
...
@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
...
@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
ptrdiff_t k_cache_head_stride, ptrdiff_t v_cache_head_stride,
ptrdiff_t k_cache_slot_stride, ptrdiff_t v_cache_slot_stride,
musaStream_t stream) {
musaStream_t stream) {
// Grid dimension is 1D, with one block per token, as we decided.
// Grid dimension is 1D, with one block per token, as we decided.
...
@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
...
@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
k_src_stride,
v_src_stride,
v_src_stride,
k_cache_block_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
} else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<__mt_bfloat16, NUM_THREADS>
pagedCaching<__mt_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
<<<grid, block, shared_mem_size, stream>>>(
...
@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
...
@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
k_src_stride,
v_src_stride,
v_src_stride,
k_cache_block_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_F32) {
} else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS>
pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
<<<grid, block, shared_mem_size, stream>>>(
...
@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
...
@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
k_src_stride,
v_src_stride,
v_src_stride,
k_cache_block_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else {
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
...
@@ -137,6 +154,8 @@ infiniStatus_t Descriptor::calculate(
...
@@ -137,6 +154,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream);
stream);
} else if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_512) {
} else if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_512) {
launchKernel<MOORE_BLOCK_SIZE_512>(
launchKernel<MOORE_BLOCK_SIZE_512>(
...
@@ -144,6 +163,8 @@ infiniStatus_t Descriptor::calculate(
...
@@ -144,6 +163,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream);
stream);
} else {
} else {
// If the GPU is older and supports fewer threads, return an error.
// If the GPU is older and supports fewer threads, return an error.
...
...
src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu
View file @
67425576
...
@@ -10,10 +10,13 @@ INFINIOP_CUDA_KERNEL pagedCaching(
...
@@ -10,10 +10,13 @@ INFINIOP_CUDA_KERNEL pagedCaching(
const
int64_t
*
slot_mapping
,
const
int64_t
*
slot_mapping
,
const
size_t
head_size
,
const
size_t
block_size
,
const
size_t
head_size
,
const
size_t
block_size
,
const
ptrdiff_t
k_src_stride
,
const
ptrdiff_t
v_src_stride
,
const
ptrdiff_t
k_src_stride
,
const
ptrdiff_t
v_src_stride
,
const
ptrdiff_t
k_cache_block_stride
,
const
ptrdiff_t
v_cache_block_stride
)
{
const
ptrdiff_t
k_cache_block_stride
,
const
ptrdiff_t
v_cache_block_stride
,
const
ptrdiff_t
k_cache_head_stride
,
const
ptrdiff_t
v_cache_head_stride
,
const
ptrdiff_t
k_cache_slot_stride
,
const
ptrdiff_t
v_cache_slot_stride
)
{
op
::
paged_caching
::
cuda
::
pagedCachingKernel
<
Tdata
,
NUM_THREADS
>
(
op
::
paged_caching
::
cuda
::
pagedCachingKernel
<
Tdata
,
NUM_THREADS
>
(
k_cache
,
v_cache
,
k
,
v
,
slot_mapping
,
head_size
,
k_cache
,
v_cache
,
k
,
v
,
slot_mapping
,
head_size
,
block_size
,
k_src_stride
,
v_src_stride
,
k_cache_block_stride
,
v_cache_block_stride
);
block_size
,
k_src_stride
,
v_src_stride
,
k_cache_block_stride
,
v_cache_block_stride
,
k_cache_head_stride
,
v_cache_head_stride
,
k_cache_slot_stride
,
v_cache_slot_stride
);
}
}
namespace
op
::
paged_caching
::
nvidia
{
namespace
op
::
paged_caching
::
nvidia
{
...
@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
...
@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
size_t
num_tokens
,
size_t
num_kv_heads
,
size_t
head_size
,
size_t
block_size
,
size_t
num_tokens
,
size_t
num_kv_heads
,
size_t
head_size
,
size_t
block_size
,
ptrdiff_t
k_src_stride
,
ptrdiff_t
v_src_stride
,
ptrdiff_t
k_src_stride
,
ptrdiff_t
v_src_stride
,
ptrdiff_t
k_cache_block_stride
,
ptrdiff_t
v_cache_block_stride
,
ptrdiff_t
k_cache_block_stride
,
ptrdiff_t
v_cache_block_stride
,
ptrdiff_t
k_cache_head_stride
,
ptrdiff_t
v_cache_head_stride
,
ptrdiff_t
k_cache_slot_stride
,
ptrdiff_t
v_cache_slot_stride
,
cudaStream_t
stream
)
{
cudaStream_t
stream
)
{
// Grid dimension is 1D, with one block per token, as we decided.
// Grid dimension is 1D, with one block per token, as we decided.
...
@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
...
@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride
,
k_src_stride
,
v_src_stride
,
v_src_stride
,
k_cache_block_stride
,
k_cache_block_stride
,
v_cache_block_stride
);
v_cache_block_stride
,
k_cache_head_stride
,
v_cache_head_stride
,
k_cache_slot_stride
,
v_cache_slot_stride
);
}
else
if
(
dtype
==
INFINI_DTYPE_BF16
)
{
}
else
if
(
dtype
==
INFINI_DTYPE_BF16
)
{
pagedCaching
<
__nv_bfloat16
,
NUM_THREADS
>
pagedCaching
<
__nv_bfloat16
,
NUM_THREADS
>
<<<
grid
,
block
,
shared_mem_size
,
stream
>>>
(
<<<
grid
,
block
,
shared_mem_size
,
stream
>>>
(
...
@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
...
@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride
,
k_src_stride
,
v_src_stride
,
v_src_stride
,
k_cache_block_stride
,
k_cache_block_stride
,
v_cache_block_stride
);
v_cache_block_stride
,
k_cache_head_stride
,
v_cache_head_stride
,
k_cache_slot_stride
,
v_cache_slot_stride
);
}
else
if
(
dtype
==
INFINI_DTYPE_F32
)
{
}
else
if
(
dtype
==
INFINI_DTYPE_F32
)
{
pagedCaching
<
float
,
NUM_THREADS
>
pagedCaching
<
float
,
NUM_THREADS
>
<<<
grid
,
block
,
shared_mem_size
,
stream
>>>
(
<<<
grid
,
block
,
shared_mem_size
,
stream
>>>
(
...
@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
...
@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride
,
k_src_stride
,
v_src_stride
,
v_src_stride
,
k_cache_block_stride
,
k_cache_block_stride
,
v_cache_block_stride
);
v_cache_block_stride
,
k_cache_head_stride
,
v_cache_head_stride
,
k_cache_slot_stride
,
v_cache_slot_stride
);
}
else
{
}
else
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
}
...
@@ -137,6 +154,8 @@ infiniStatus_t Descriptor::calculate(
...
@@ -137,6 +154,8 @@ infiniStatus_t Descriptor::calculate(
_info
.
num_tokens
,
_info
.
num_kv_heads
,
_info
.
head_size
,
_info
.
block_size
,
_info
.
num_tokens
,
_info
.
num_kv_heads
,
_info
.
head_size
,
_info
.
block_size
,
_info
.
k_src_stride
,
_info
.
v_src_stride
,
_info
.
k_src_stride
,
_info
.
v_src_stride
,
_info
.
k_cache_block_stride
,
_info
.
v_cache_block_stride
,
_info
.
k_cache_block_stride
,
_info
.
v_cache_block_stride
,
_info
.
k_cache_head_stride
,
_info
.
v_cache_head_stride
,
_info
.
k_cache_slot_stride
,
_info
.
v_cache_slot_stride
,
stream
);
stream
);
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
>=
CUDA_BLOCK_SIZE_512
)
{
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
>=
CUDA_BLOCK_SIZE_512
)
{
launchKernel
<
CUDA_BLOCK_SIZE_512
>
(
launchKernel
<
CUDA_BLOCK_SIZE_512
>
(
...
@@ -144,6 +163,8 @@ infiniStatus_t Descriptor::calculate(
...
@@ -144,6 +163,8 @@ infiniStatus_t Descriptor::calculate(
_info
.
num_tokens
,
_info
.
num_kv_heads
,
_info
.
head_size
,
_info
.
block_size
,
_info
.
num_tokens
,
_info
.
num_kv_heads
,
_info
.
head_size
,
_info
.
block_size
,
_info
.
k_src_stride
,
_info
.
v_src_stride
,
_info
.
k_src_stride
,
_info
.
v_src_stride
,
_info
.
k_cache_block_stride
,
_info
.
v_cache_block_stride
,
_info
.
k_cache_block_stride
,
_info
.
v_cache_block_stride
,
_info
.
k_cache_head_stride
,
_info
.
v_cache_head_stride
,
_info
.
k_cache_slot_stride
,
_info
.
v_cache_slot_stride
,
stream
);
stream
);
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
>=
CUDA_BLOCK_SIZE_4096
)
{
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
>=
CUDA_BLOCK_SIZE_4096
)
{
launchKernel
<
CUDA_BLOCK_SIZE_4096
>
(
launchKernel
<
CUDA_BLOCK_SIZE_4096
>
(
...
@@ -151,6 +172,8 @@ infiniStatus_t Descriptor::calculate(
...
@@ -151,6 +172,8 @@ infiniStatus_t Descriptor::calculate(
_info
.
num_tokens
,
_info
.
num_kv_heads
,
_info
.
head_size
,
_info
.
block_size
,
_info
.
num_tokens
,
_info
.
num_kv_heads
,
_info
.
head_size
,
_info
.
block_size
,
_info
.
k_src_stride
,
_info
.
v_src_stride
,
_info
.
k_src_stride
,
_info
.
v_src_stride
,
_info
.
k_cache_block_stride
,
_info
.
v_cache_block_stride
,
_info
.
k_cache_block_stride
,
_info
.
v_cache_block_stride
,
_info
.
k_cache_head_stride
,
_info
.
v_cache_head_stride
,
_info
.
k_cache_slot_stride
,
_info
.
v_cache_slot_stride
,
stream
);
stream
);
}
else
{
}
else
{
// If the GPU is older and supports fewer threads, return an error.
// If the GPU is older and supports fewer threads, return an error.
...
...
test/infinicore/ops/paged_caching.py
View file @
67425576
...
@@ -18,12 +18,16 @@ from framework import (
...
@@ -18,12 +18,16 @@ from framework import (
# Operator-specific configuration
# Operator-specific configuration
# ==============================================================================
# ==============================================================================
# Test cases format: (num_seqs, max_seq_len, num_kv_heads, head_size, block_size)
# Test cases format: (num_seqs, max_seq_len, num_kv_heads, head_size, block_size
, permute_dim_1_2
)
_TEST_CASES_DATA
=
[
_TEST_CASES_DATA
=
[
(
1
,
128
,
8
,
128
,
16
),
(
1
,
128
,
8
,
128
,
16
,
False
),
(
5
,
512
,
40
,
128
,
16
),
(
1
,
128
,
8
,
128
,
16
,
True
),
(
16
,
1024
,
8
,
64
,
32
),
(
5
,
512
,
40
,
256
,
16
,
False
),
(
10
,
1024
,
40
,
64
,
32
),
(
5
,
512
,
40
,
256
,
16
,
True
),
(
16
,
1024
,
8
,
64
,
32
,
False
),
(
16
,
1024
,
8
,
64
,
32
,
True
),
(
10
,
1024
,
40
,
64
,
32
,
False
),
(
10
,
1024
,
40
,
64
,
32
,
True
),
]
]
# Tolerance configuration
# Tolerance configuration
...
@@ -40,7 +44,9 @@ _TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
...
@@ -40,7 +44,9 @@ _TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# ==============================================================================
# ==============================================================================
# Reference Implementation
# Reference Implementation
# ==============================================================================
# ==============================================================================
def
ref_paged_caching
(
key_cache_pool
,
value_cache_pool
,
key
,
value
,
slot_mapping
):
def
ref_paged_caching
(
key_cache_pool
,
value_cache_pool
,
key
,
value
,
slot_mapping
,
permute_dim_1_2
):
"""
"""
Reference implementation for paged_caching operator.
Reference implementation for paged_caching operator.
...
@@ -52,7 +58,7 @@ def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping
...
@@ -52,7 +58,7 @@ def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping
slot_mapping (torch.Tensor): Slot mapping, shape [ntok]
slot_mapping (torch.Tensor): Slot mapping, shape [ntok]
"""
"""
ntok
=
key
.
shape
[
0
]
ntok
=
key
.
shape
[
0
]
block_size
=
key_cache_pool
.
shape
[
2
]
block_size
=
key_cache_pool
.
shape
[
1
]
if
permute_dim_1_2
else
key_cache_pool
.
shape
[
2
]
# This reference implementation operates on a cloned cache to avoid modifying the original input tensor,
# This reference implementation operates on a cloned cache to avoid modifying the original input tensor,
# mimicking the behavior where the custom operator writes to its output tensor.
# mimicking the behavior where the custom operator writes to its output tensor.
...
@@ -67,8 +73,12 @@ def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping
...
@@ -67,8 +73,12 @@ def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping
key_token
=
key
[
i
]
key_token
=
key
[
i
]
value_token
=
value
[
i
]
value_token
=
value
[
i
]
k_cache_ref
[
block_idx
,
:,
block_offset
,
:]
=
key_token
if
permute_dim_1_2
:
v_cache_ref
[
block_idx
,
:,
block_offset
,
:]
=
value_token
k_cache_ref
[
block_idx
,
block_offset
,
:,
:]
=
key_token
v_cache_ref
[
block_idx
,
block_offset
,
:,
:]
=
value_token
else
:
k_cache_ref
[
block_idx
,
:,
block_offset
,
:]
=
key_token
v_cache_ref
[
block_idx
,
:,
block_offset
,
:]
=
value_token
return
k_cache_ref
,
v_cache_ref
return
k_cache_ref
,
v_cache_ref
...
@@ -79,7 +89,14 @@ def parse_test_cases():
...
@@ -79,7 +89,14 @@ def parse_test_cases():
Each test case contains all necessary information for execution and validation.
Each test case contains all necessary information for execution and validation.
"""
"""
test_cases
=
[]
test_cases
=
[]
for
num_seqs
,
max_seq_len
,
num_kv_heads
,
head_size
,
block_size
in
_TEST_CASES_DATA
:
for
(
num_seqs
,
max_seq_len
,
num_kv_heads
,
head_size
,
block_size
,
permute_dim_1_2
,
)
in
_TEST_CASES_DATA
:
num_blocks
=
4096
# A reasonably large cache pool for testing
num_blocks
=
4096
# A reasonably large cache pool for testing
# Create metadata: variable context lengths for each sequence in the batch
# Create metadata: variable context lengths for each sequence in the batch
...
@@ -111,6 +128,9 @@ def parse_test_cases():
...
@@ -111,6 +128,9 @@ def parse_test_cases():
v_shape
=
(
ntok
,
num_kv_heads
,
head_size
)
v_shape
=
(
ntok
,
num_kv_heads
,
head_size
)
k_cache_shape
=
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
)
k_cache_shape
=
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
)
v_cache_shape
=
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
)
v_cache_shape
=
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
)
if
permute_dim_1_2
:
k_cache_shape
=
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
v_cache_shape
=
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
# Generate test cases for all data types
# Generate test cases for all data types
for
dtype
in
_TENSOR_DTYPES
:
for
dtype
in
_TENSOR_DTYPES
:
...
@@ -142,7 +162,7 @@ def parse_test_cases():
...
@@ -142,7 +162,7 @@ def parse_test_cases():
v_spec
,
v_spec
,
slot_mapping_spec
,
slot_mapping_spec
,
],
],
kwargs
=
None
,
kwargs
=
{
"permute_dim_1_2"
:
permute_dim_1_2
}
,
output_spec
=
None
,
output_spec
=
None
,
comparison_target
=
0
,
# Only compare k_cache
comparison_target
=
0
,
# Only compare k_cache
tolerance
=
tolerance
,
tolerance
=
tolerance
,
...
@@ -162,13 +182,22 @@ class OpTest(BaseOperatorTest):
...
@@ -162,13 +182,22 @@ class OpTest(BaseOperatorTest):
def
get_test_cases
(
self
):
def
get_test_cases
(
self
):
return
parse_test_cases
()
return
parse_test_cases
()
def
torch_operator
(
self
,
*
args
,
**
kwargs
):
def
torch_operator
(
self
,
k_cache
,
v_cache
,
key
,
value
,
slot_mapping
,
permute_dim_1_2
=
False
):
"""PyTorch paged_caching implementation"""
"""PyTorch paged_caching implementation"""
return
ref_paged_caching
(
*
args
,
**
kwargs
)
return
ref_paged_caching
(
k_cache
,
v_cache
,
key
,
value
,
slot_mapping
,
permute_dim_1_2
)
def
infinicore_operator
(
self
,
*
args
,
**
kwargs
):
def
infinicore_operator
(
self
,
k_cache
,
v_cache
,
key
,
value
,
slot_mapping
,
permute_dim_1_2
=
False
):
"""InfiniCore paged_caching implementation"""
"""InfiniCore paged_caching implementation"""
return
infinicore
.
paged_caching
(
*
args
,
**
kwargs
)
if
permute_dim_1_2
:
k_cache
=
k_cache
.
permute
([
0
,
2
,
1
,
3
])
v_cache
=
v_cache
.
permute
([
0
,
2
,
1
,
3
])
return
infinicore
.
paged_caching
(
k_cache
,
v_cache
,
key
,
value
,
slot_mapping
)
def
main
():
def
main
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment