Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
90cb1b54
Unverified
Commit
90cb1b54
authored
Mar 02, 2026
by
thatPepe
Committed by
GitHub
Mar 02, 2026
Browse files
Merge pull request #1037 from InfiniTensor/issue/1036
Issue/1036 paged caching support strides
parents
d8176086
3e1ef507
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
222 additions
and
86 deletions
+222
-86
src/infiniop/ops/causal_softmax/operator.cc
src/infiniop/ops/causal_softmax/operator.cc
+12
-4
src/infiniop/ops/clip/operator.cc
src/infiniop/ops/clip/operator.cc
+2
-2
src/infiniop/ops/logsoftmax/operator.cc
src/infiniop/ops/logsoftmax/operator.cc
+8
-4
src/infiniop/ops/paged_caching/cuda/kernel.cuh
src/infiniop/ops/paged_caching/cuda/kernel.cuh
+7
-5
src/infiniop/ops/paged_caching/info.h
src/infiniop/ops/paged_caching/info.h
+13
-1
src/infiniop/ops/paged_caching/metax/paged_caching_metax.maca
...infiniop/ops/paged_caching/metax/paged_caching_metax.maca
+26
-5
src/infiniop/ops/paged_caching/moore/paged_caching_moore.mu
src/infiniop/ops/paged_caching/moore/paged_caching_moore.mu
+26
-5
src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu
...infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu
+28
-5
src/infiniop/ops/rms_norm/operator.cc
src/infiniop/ops/rms_norm/operator.cc
+8
-8
src/infiniop/ops/rope/operator.cc
src/infiniop/ops/rope/operator.cc
+8
-8
src/infiniop/ops/silu_and_mul/operator.cc
src/infiniop/ops/silu_and_mul/operator.cc
+16
-4
src/infiniop/ops/softmax/operator.cc
src/infiniop/ops/softmax/operator.cc
+8
-4
src/infiniop/ops/topkrouter/operator.cc
src/infiniop/ops/topkrouter/operator.cc
+8
-8
src/infiniop/ops/topksoftmax/operator.cc
src/infiniop/ops/topksoftmax/operator.cc
+8
-8
test/infinicore/ops/paged_caching.py
test/infinicore/ops/paged_caching.py
+44
-15
No files found.
src/infiniop/ops/causal_softmax/operator.cc
View file @
90cb1b54
...
...
@@ -72,8 +72,10 @@ __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
#ifdef ENABLE_MOORE_API
CREATE
(
INFINI_DEVICE_MOORE
,
moore
)
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
#undef CREATE
}
__C
infiniStatus_t
infiniopGetCausalSoftmaxWorkspaceSize
(
infiniopCausalSoftmaxDescriptor_t
desc
,
size_t
*
size
)
{
...
...
@@ -117,8 +119,10 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
#ifdef ENABLE_MOORE_API
GET
(
INFINI_DEVICE_MOORE
,
moore
)
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
#undef GET
}
__C
infiniStatus_t
infiniopCausalSoftmax
(
...
...
@@ -167,8 +171,10 @@ __C infiniStatus_t infiniopCausalSoftmax(
#ifdef ENABLE_MOORE_API
CALCULATE
(
INFINI_DEVICE_MOORE
,
moore
)
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
#undef CALCULATE
}
__C
infiniStatus_t
infiniopDestroyCausalSoftmaxDescriptor
(
infiniopCausalSoftmaxDescriptor_t
desc
)
{
...
...
@@ -212,6 +218,8 @@ __C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxD
#ifdef ENABLE_MOORE_API
DESTROY
(
INFINI_DEVICE_MOORE
,
moore
)
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
#undef DESTROY
}
src/infiniop/ops/clip/operator.cc
View file @
90cb1b54
...
...
@@ -91,11 +91,11 @@ __C infiniStatus_t infiniopGetClipWorkspaceSize(infiniopClipDescriptor_t desc, s
#ifdef ENABLE_KUNLUN_API
GET
(
INFINI_DEVICE_KUNLUN
,
kunlun
)
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef GET
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
__C
infiniStatus_t
infiniopClip
(
...
...
src/infiniop/ops/logsoftmax/operator.cc
View file @
90cb1b54
...
...
@@ -51,8 +51,9 @@ __C infiniStatus_t infiniopCreateLogSoftmaxDescriptor(
#ifdef ENABLE_ASCEND_API
// CREATE(INFINI_DEVICE_ASCEND, ascend)
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
__C
infiniStatus_t
infiniopGetLogSoftmaxWorkspaceSize
(
infiniopLogSoftmaxDescriptor_t
desc
,
size_t
*
size
)
{
...
...
@@ -84,8 +85,9 @@ __C infiniStatus_t infiniopGetLogSoftmaxWorkspaceSize(infiniopLogSoftmaxDescript
#ifdef ENABLE_ASCEND_API
// GET(INFINI_DEVICE_ASCEND, ascend)
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
__C
infiniStatus_t
infiniopLogSoftmax
(
...
...
@@ -122,8 +124,9 @@ __C infiniStatus_t infiniopLogSoftmax(
#ifdef ENABLE_ASCEND_API
// CALCULATE(INFINI_DEVICE_ASCEND, ascend)
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
__C
infiniStatus_t
infiniopDestroyLogSoftmaxDescriptor
(
infiniopLogSoftmaxDescriptor_t
desc
)
{
...
...
@@ -155,6 +158,7 @@ __C infiniStatus_t infiniopDestroyLogSoftmaxDescriptor(infiniopLogSoftmaxDescrip
#ifdef ENABLE_ASCEND_API
// DESTROY(INFINI_DEVICE_ASCEND, ascend)
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
src/infiniop/ops/paged_caching/cuda/kernel.cuh
View file @
90cb1b54
...
...
@@ -38,7 +38,11 @@ __device__ void pagedCachingKernel(
const
ptrdiff_t
k_src_stride
,
// Stride between tokens in the source K tensor
const
ptrdiff_t
v_src_stride
,
// Stride between tokens in the source V tensor
const
ptrdiff_t
k_cache_block_stride
,
// Stride between blocks in the K cache pool
const
ptrdiff_t
v_cache_block_stride
// Stride between blocks in the V cache pool
const
ptrdiff_t
v_cache_block_stride
,
// Stride between blocks in the V cache pool
const
ptrdiff_t
k_cache_head_stride
,
// Stride between heads in the K cache pool
const
ptrdiff_t
v_cache_head_stride
,
// Stride between heads in the V cache pool
const
ptrdiff_t
k_cache_slot_stride
,
// Stride between block slots in the K cache pool
const
ptrdiff_t
v_cache_slot_stride
// Stride between block slots in the V cache pool
)
{
//================================================================================
// 1. Identify Work Unit & Calculate Addresses
...
...
@@ -66,13 +70,11 @@ __device__ void pagedCachingKernel(
// Destination pointer calculation assumes a [num_blocks, block_size, num_heads, head_size] layout.
// We point to the beginning of the memory region for this token's slot.
const
ptrdiff_t
cache_head_stride
=
block_size
*
head_size
;
Tdata
*
k_cache_block_base_ptr
=
k_cache_ptr
+
physical_block_idx
*
k_cache_block_stride
;
Tdata
*
k_dst_head_ptr
=
k_cache_block_base_ptr
+
head_idx
*
cache_head_stride
+
block_offset
*
head_siz
e
;
Tdata
*
k_dst_head_ptr
=
k_cache_block_base_ptr
+
head_idx
*
k_
cache_head_stride
+
block_offset
*
k_cache_slot_strid
e
;
Tdata
*
v_cache_block_base_ptr
=
v_cache_ptr
+
physical_block_idx
*
v_cache_block_stride
;
Tdata
*
v_dst_head_ptr
=
v_cache_block_base_ptr
+
head_idx
*
cache_head_stride
+
block_offset
*
head_siz
e
;
Tdata
*
v_dst_head_ptr
=
v_cache_block_base_ptr
+
head_idx
*
v_
cache_head_stride
+
block_offset
*
v_cache_slot_strid
e
;
//================================================================================
// 2. Perform Element-wise Data Copy (Safe, Non-Vectorized)
...
...
src/infiniop/ops/paged_caching/info.h
View file @
90cb1b54
...
...
@@ -26,6 +26,10 @@ public:
ptrdiff_t
v_src_stride
;
ptrdiff_t
k_cache_block_stride
;
ptrdiff_t
v_cache_block_stride
;
ptrdiff_t
k_cache_head_stride
;
ptrdiff_t
v_cache_head_stride
;
ptrdiff_t
k_cache_slot_stride
;
ptrdiff_t
v_cache_slot_stride
;
static
utils
::
Result
<
PagedCachingInfo
>
create
(
infiniopTensorDescriptor_t
k_cache_desc
,
...
...
@@ -63,6 +67,10 @@ public:
ptrdiff_t
v_src_stride
=
v_desc
->
stride
(
0
);
ptrdiff_t
k_cache_block_stride
=
k_cache_desc
->
stride
(
0
);
ptrdiff_t
v_cache_block_stride
=
v_cache_desc
->
stride
(
0
);
ptrdiff_t
k_cache_head_stride
=
k_cache_desc
->
stride
(
1
);
ptrdiff_t
v_cache_head_stride
=
v_cache_desc
->
stride
(
1
);
ptrdiff_t
k_cache_slot_stride
=
k_cache_desc
->
stride
(
2
);
ptrdiff_t
v_cache_slot_stride
=
v_cache_desc
->
stride
(
2
);
return
utils
::
Result
<
PagedCachingInfo
>
(
PagedCachingInfo
{
dtype
,
...
...
@@ -73,7 +81,11 @@ public:
k_src_stride
,
v_src_stride
,
k_cache_block_stride
,
v_cache_block_stride
});
v_cache_block_stride
,
k_cache_head_stride
,
v_cache_head_stride
,
k_cache_slot_stride
,
v_cache_slot_stride
});
}
};
...
...
src/infiniop/ops/paged_caching/metax/paged_caching_metax.maca
View file @
90cb1b54
...
...
@@ -10,10 +10,13 @@ INFINIOP_METAX_KERNEL pagedCaching(
const int64_t *slot_mapping,
const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride,
const ptrdiff_t k_cache_head_stride, const ptrdiff_t v_cache_head_stride,
const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_strid) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
block_size, k_src_stride, v_src_stride,
k_cache_block_stride, v_cache_block_stride, k_cache_head_stride, v_cache_head_stride, k_cache_slot_stride, v_cache_slot_stride);
}
namespace op::paged_caching::metax {
...
...
@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
ptrdiff_t k_cache_head_stride, ptrdiff_t v_cache_head_stride,
ptrdiff_t k_cache_slot_stride, ptrdiff_t v_cache_slot_stride,
hcStream_t stream) {
// Grid dimension is 1D, with one block per token, as we decided.
...
...
@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<cuda_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
...
...
@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
...
...
@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
...
...
@@ -138,6 +155,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream);
} else if (max_threads >= METAX_BLOCK_SIZE_512) {
launchKernel<METAX_BLOCK_SIZE_512>(
...
...
@@ -145,6 +164,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream);
} else {
// If the device supports fewer threads, return an error.
...
...
src/infiniop/ops/paged_caching/moore/paged_caching_moore.mu
View file @
90cb1b54
...
...
@@ -10,10 +10,13 @@ INFINIOP_MOORE_KERNEL pagedCaching(
const int64_t *slot_mapping,
const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride,
const ptrdiff_t k_cache_head_stride, const ptrdiff_t v_cache_head_stride,
const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
block_size, k_src_stride, v_src_stride,
k_cache_block_stride, v_cache_block_stride, k_cache_head_stride, v_cache_head_stride, k_cache_slot_stride, v_cache_slot_stride);
}
namespace op::paged_caching::moore {
...
...
@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
ptrdiff_t k_cache_head_stride, ptrdiff_t v_cache_head_stride,
ptrdiff_t k_cache_slot_stride, ptrdiff_t v_cache_slot_stride,
musaStream_t stream) {
// Grid dimension is 1D, with one block per token, as we decided.
...
...
@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<__mt_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
...
...
@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
...
...
@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
...
...
@@ -137,6 +154,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream);
} else if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_512) {
launchKernel<MOORE_BLOCK_SIZE_512>(
...
...
@@ -144,6 +163,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream);
} else {
// If the GPU is older and supports fewer threads, return an error.
...
...
src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu
View file @
90cb1b54
...
...
@@ -10,10 +10,13 @@ INFINIOP_CUDA_KERNEL pagedCaching(
const
int64_t
*
slot_mapping
,
const
size_t
head_size
,
const
size_t
block_size
,
const
ptrdiff_t
k_src_stride
,
const
ptrdiff_t
v_src_stride
,
const
ptrdiff_t
k_cache_block_stride
,
const
ptrdiff_t
v_cache_block_stride
)
{
const
ptrdiff_t
k_cache_block_stride
,
const
ptrdiff_t
v_cache_block_stride
,
const
ptrdiff_t
k_cache_head_stride
,
const
ptrdiff_t
v_cache_head_stride
,
const
ptrdiff_t
k_cache_slot_stride
,
const
ptrdiff_t
v_cache_slot_stride
)
{
op
::
paged_caching
::
cuda
::
pagedCachingKernel
<
Tdata
,
NUM_THREADS
>
(
k_cache
,
v_cache
,
k
,
v
,
slot_mapping
,
head_size
,
block_size
,
k_src_stride
,
v_src_stride
,
k_cache_block_stride
,
v_cache_block_stride
);
block_size
,
k_src_stride
,
v_src_stride
,
k_cache_block_stride
,
v_cache_block_stride
,
k_cache_head_stride
,
v_cache_head_stride
,
k_cache_slot_stride
,
v_cache_slot_stride
);
}
namespace
op
::
paged_caching
::
nvidia
{
...
...
@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
size_t
num_tokens
,
size_t
num_kv_heads
,
size_t
head_size
,
size_t
block_size
,
ptrdiff_t
k_src_stride
,
ptrdiff_t
v_src_stride
,
ptrdiff_t
k_cache_block_stride
,
ptrdiff_t
v_cache_block_stride
,
ptrdiff_t
k_cache_head_stride
,
ptrdiff_t
v_cache_head_stride
,
ptrdiff_t
k_cache_slot_stride
,
ptrdiff_t
v_cache_slot_stride
,
cudaStream_t
stream
)
{
// Grid dimension is 1D, with one block per token, as we decided.
...
...
@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride
,
v_src_stride
,
k_cache_block_stride
,
v_cache_block_stride
);
v_cache_block_stride
,
k_cache_head_stride
,
v_cache_head_stride
,
k_cache_slot_stride
,
v_cache_slot_stride
);
}
else
if
(
dtype
==
INFINI_DTYPE_BF16
)
{
pagedCaching
<
__nv_bfloat16
,
NUM_THREADS
>
<<<
grid
,
block
,
shared_mem_size
,
stream
>>>
(
...
...
@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride
,
v_src_stride
,
k_cache_block_stride
,
v_cache_block_stride
);
v_cache_block_stride
,
k_cache_head_stride
,
v_cache_head_stride
,
k_cache_slot_stride
,
v_cache_slot_stride
);
}
else
if
(
dtype
==
INFINI_DTYPE_F32
)
{
pagedCaching
<
float
,
NUM_THREADS
>
<<<
grid
,
block
,
shared_mem_size
,
stream
>>>
(
...
...
@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride
,
v_src_stride
,
k_cache_block_stride
,
v_cache_block_stride
);
v_cache_block_stride
,
k_cache_head_stride
,
v_cache_head_stride
,
k_cache_slot_stride
,
v_cache_slot_stride
);
}
else
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
...
...
@@ -137,6 +154,8 @@ infiniStatus_t Descriptor::calculate(
_info
.
num_tokens
,
_info
.
num_kv_heads
,
_info
.
head_size
,
_info
.
block_size
,
_info
.
k_src_stride
,
_info
.
v_src_stride
,
_info
.
k_cache_block_stride
,
_info
.
v_cache_block_stride
,
_info
.
k_cache_head_stride
,
_info
.
v_cache_head_stride
,
_info
.
k_cache_slot_stride
,
_info
.
v_cache_slot_stride
,
stream
);
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
>=
CUDA_BLOCK_SIZE_512
)
{
launchKernel
<
CUDA_BLOCK_SIZE_512
>
(
...
...
@@ -144,6 +163,8 @@ infiniStatus_t Descriptor::calculate(
_info
.
num_tokens
,
_info
.
num_kv_heads
,
_info
.
head_size
,
_info
.
block_size
,
_info
.
k_src_stride
,
_info
.
v_src_stride
,
_info
.
k_cache_block_stride
,
_info
.
v_cache_block_stride
,
_info
.
k_cache_head_stride
,
_info
.
v_cache_head_stride
,
_info
.
k_cache_slot_stride
,
_info
.
v_cache_slot_stride
,
stream
);
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
>=
CUDA_BLOCK_SIZE_4096
)
{
launchKernel
<
CUDA_BLOCK_SIZE_4096
>
(
...
...
@@ -151,6 +172,8 @@ infiniStatus_t Descriptor::calculate(
_info
.
num_tokens
,
_info
.
num_kv_heads
,
_info
.
head_size
,
_info
.
block_size
,
_info
.
k_src_stride
,
_info
.
v_src_stride
,
_info
.
k_cache_block_stride
,
_info
.
v_cache_block_stride
,
_info
.
k_cache_head_stride
,
_info
.
v_cache_head_stride
,
_info
.
k_cache_slot_stride
,
_info
.
v_cache_slot_stride
,
stream
);
}
else
{
// If the GPU is older and supports fewer threads, return an error.
...
...
src/infiniop/ops/rms_norm/operator.cc
View file @
90cb1b54
...
...
@@ -76,11 +76,11 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
#ifdef ENABLE_MOORE_API
CREATE
(
INFINI_DEVICE_MOORE
,
moore
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef CREATE
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
__C
infiniStatus_t
infiniopGetRMSNormWorkspaceSize
(
infiniopRMSNormDescriptor_t
desc
,
size_t
*
size
)
{
...
...
@@ -124,11 +124,11 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
#ifdef ENABLE_MOORE_API
GET
(
INFINI_DEVICE_MOORE
,
moore
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef GET
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
__C
infiniStatus_t
infiniopRMSNorm
(
infiniopRMSNormDescriptor_t
desc
,
void
*
workspace
,
size_t
workspace_size
,
...
...
@@ -173,11 +173,11 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
#ifdef ENABLE_MOORE_API
CALCULATE
(
INFINI_DEVICE_MOORE
,
moore
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef CALCULATE
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
__C
infiniStatus_t
infiniopDestroyRMSNormDescriptor
(
infiniopRMSNormDescriptor_t
desc
)
{
...
...
@@ -221,9 +221,9 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
#ifdef ENABLE_MOORE_API
DESTROY
(
INFINI_DEVICE_MOORE
,
moore
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef DESTROY
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
src/infiniop/ops/rope/operator.cc
View file @
90cb1b54
...
...
@@ -80,11 +80,11 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor(
#ifdef ENABLE_CAMBRICON_API
CREATE
(
INFINI_DEVICE_CAMBRICON
,
bang
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef CREATE
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
__C
infiniStatus_t
infiniopGetRoPEWorkspaceSize
(
infiniopRoPEDescriptor_t
desc
,
...
...
@@ -128,11 +128,11 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
#ifdef ENABLE_ASCEND_API
GET
(
INFINI_DEVICE_ASCEND
,
ascend
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef GET
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
__C
infiniStatus_t
infiniopRoPE
(
...
...
@@ -185,11 +185,11 @@ __C infiniStatus_t infiniopRoPE(
#ifdef ENABLE_ASCEND_API
CALCULATE
(
INFINI_DEVICE_ASCEND
,
ascend
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef CALCULATE
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
__C
infiniStatus_t
...
...
@@ -234,9 +234,9 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) {
#ifdef ENABLE_ASCEND_API
DELETE
(
INFINI_DEVICE_ASCEND
,
ascend
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef DELETE
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
src/infiniop/ops/silu_and_mul/operator.cc
View file @
90cb1b54
...
...
@@ -24,8 +24,11 @@ __C infiniStatus_t infiniopCreateSiluAndMulDescriptor(
#ifdef ENABLE_MOORE_API
CREATE
(
INFINI_DEVICE_MOORE
,
moore
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
#undef CREATE
}
__C
infiniStatus_t
infiniopGetSiluAndMulWorkspaceSize
(
infiniopSiluAndMulDescriptor_t
desc
,
size_t
*
size
)
{
...
...
@@ -39,8 +42,11 @@ __C infiniStatus_t infiniopGetSiluAndMulWorkspaceSize(infiniopSiluAndMulDescript
#ifdef ENABLE_MOORE_API
GET
(
INFINI_DEVICE_MOORE
,
moore
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
#undef GET
}
__C
infiniStatus_t
infiniopSiluAndMul
(
...
...
@@ -59,8 +65,11 @@ __C infiniStatus_t infiniopSiluAndMul(
#ifdef ENABLE_MOORE_API
CALCULATE
(
INFINI_DEVICE_MOORE
,
moore
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
#undef CALCULATE
}
__C
infiniStatus_t
infiniopDestroySiluAndMulDescriptor
(
infiniopSiluAndMulDescriptor_t
desc
)
{
...
...
@@ -74,6 +83,9 @@ __C infiniStatus_t infiniopDestroySiluAndMulDescriptor(infiniopSiluAndMulDescrip
#ifdef ENABLE_MOORE_API
DESTROY
(
INFINI_DEVICE_MOORE
,
moore
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
#undef DESTROY
}
src/infiniop/ops/softmax/operator.cc
View file @
90cb1b54
...
...
@@ -37,8 +37,9 @@ __C infiniStatus_t infiniopCreateSoftmaxDescriptor(
#ifdef ENABLE_ALI_API
CREATE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
__C
infiniStatus_t
infiniopGetSoftmaxWorkspaceSize
(
infiniopSoftmaxDescriptor_t
desc
,
size_t
*
size
)
{
...
...
@@ -64,8 +65,9 @@ __C infiniStatus_t infiniopGetSoftmaxWorkspaceSize(infiniopSoftmaxDescriptor_t d
#ifdef ENABLE_ALI_API
GET
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
__C
infiniStatus_t
infiniopSoftmax
(
...
...
@@ -96,8 +98,9 @@ __C infiniStatus_t infiniopSoftmax(
#ifdef ENABLE_ALI_API
CALCULATE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
__C
infiniStatus_t
infiniopDestroySoftmaxDescriptor
(
infiniopSoftmaxDescriptor_t
desc
)
{
...
...
@@ -123,6 +126,7 @@ __C infiniStatus_t infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t
#ifdef ENABLE_ALI_API
DESTROY
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
src/infiniop/ops/topkrouter/operator.cc
View file @
90cb1b54
...
...
@@ -42,11 +42,11 @@ __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, i
#ifdef ENABLE_ALI_API
CREATE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef CREATE
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
__C
infiniStatus_t
infiniopGetTopkrouterWorkspaceSize
(
infiniopTopkrouterDescriptor_t
desc
,
size_t
*
size
)
{
...
...
@@ -74,11 +74,11 @@ __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescript
#ifdef ENABLE_ALI_API
GET
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef GET
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
__C
infiniStatus_t
infiniopTopkrouter
(
infiniopTopkrouterDescriptor_t
desc
,
void
*
workspace
,
size_t
workspace_size
,
...
...
@@ -109,11 +109,11 @@ __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void
#ifdef ENABLE_ALI_API
CALCULATE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef CALCULATE
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
__C
infiniStatus_t
infiniopDestroyTopkrouterDescriptor
(
infiniopTopkrouterDescriptor_t
desc
)
{
...
...
@@ -141,9 +141,9 @@ __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescrip
#ifdef ENABLE_ALI_API
DESTROY
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef DESTROY
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
src/infiniop/ops/topksoftmax/operator.cc
View file @
90cb1b54
...
...
@@ -40,11 +40,11 @@ __C infiniStatus_t infiniopCreateTopksoftmaxDescriptor(infiniopHandle_t handle,
#ifdef ENABLE_ILUVATAR_API
CREATE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef CREATE
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
__C
infiniStatus_t
infiniopGetTopksoftmaxWorkspaceSize
(
infiniopTopksoftmaxDescriptor_t
desc
,
size_t
*
size
)
{
...
...
@@ -73,11 +73,11 @@ __C infiniStatus_t infiniopGetTopksoftmaxWorkspaceSize(infiniopTopksoftmaxDescri
#ifdef ENABLE_ILUVATAR_API
GET
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef GET
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
__C
infiniStatus_t
infiniopTopksoftmax
(
infiniopTopksoftmaxDescriptor_t
desc
,
void
*
workspace
,
size_t
workspace_size
,
...
...
@@ -111,11 +111,11 @@ __C infiniStatus_t infiniopTopksoftmax(infiniopTopksoftmaxDescriptor_t desc, voi
#ifdef ENABLE_ILUVATAR_API
CALCULATE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef CALCULATE
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
__C
infiniStatus_t
infiniopDestroyTopksoftmaxDescriptor
(
infiniopTopksoftmaxDescriptor_t
desc
)
{
...
...
@@ -144,9 +144,9 @@ __C infiniStatus_t infiniopDestroyTopksoftmaxDescriptor(infiniopTopksoftmaxDescr
#ifdef ENABLE_ILUVATAR_API
DESTROY
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef DESTROY
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
test/infinicore/ops/paged_caching.py
View file @
90cb1b54
...
...
@@ -18,12 +18,16 @@ from framework import (
# Operator-specific configuration
# ==============================================================================
# Test cases format: (num_seqs, max_seq_len, num_kv_heads, head_size, block_size)
# Test cases format: (num_seqs, max_seq_len, num_kv_heads, head_size, block_size
, permute_dim_1_2
)
_TEST_CASES_DATA
=
[
(
1
,
128
,
8
,
128
,
16
),
(
5
,
512
,
40
,
128
,
16
),
(
16
,
1024
,
8
,
64
,
32
),
(
10
,
1024
,
40
,
64
,
32
),
(
1
,
128
,
8
,
128
,
16
,
False
),
(
1
,
128
,
8
,
128
,
16
,
True
),
(
5
,
512
,
40
,
256
,
16
,
False
),
(
5
,
512
,
40
,
256
,
16
,
True
),
(
16
,
1024
,
8
,
64
,
32
,
False
),
(
16
,
1024
,
8
,
64
,
32
,
True
),
(
10
,
1024
,
40
,
64
,
32
,
False
),
(
10
,
1024
,
40
,
64
,
32
,
True
),
]
# Tolerance configuration
...
...
@@ -40,7 +44,9 @@ _TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# ==============================================================================
# Reference Implementation
# ==============================================================================
def
ref_paged_caching
(
key_cache_pool
,
value_cache_pool
,
key
,
value
,
slot_mapping
):
def
ref_paged_caching
(
key_cache_pool
,
value_cache_pool
,
key
,
value
,
slot_mapping
,
permute_dim_1_2
):
"""
Reference implementation for paged_caching operator.
...
...
@@ -52,7 +58,7 @@ def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping
slot_mapping (torch.Tensor): Slot mapping, shape [ntok]
"""
ntok
=
key
.
shape
[
0
]
block_size
=
key_cache_pool
.
shape
[
2
]
block_size
=
key_cache_pool
.
shape
[
1
]
if
permute_dim_1_2
else
key_cache_pool
.
shape
[
2
]
# This reference implementation operates on a cloned cache to avoid modifying the original input tensor,
# mimicking the behavior where the custom operator writes to its output tensor.
...
...
@@ -67,8 +73,12 @@ def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping
key_token
=
key
[
i
]
value_token
=
value
[
i
]
k_cache_ref
[
block_idx
,
:,
block_offset
,
:]
=
key_token
v_cache_ref
[
block_idx
,
:,
block_offset
,
:]
=
value_token
if
permute_dim_1_2
:
k_cache_ref
[
block_idx
,
block_offset
,
:,
:]
=
key_token
v_cache_ref
[
block_idx
,
block_offset
,
:,
:]
=
value_token
else
:
k_cache_ref
[
block_idx
,
:,
block_offset
,
:]
=
key_token
v_cache_ref
[
block_idx
,
:,
block_offset
,
:]
=
value_token
return
k_cache_ref
,
v_cache_ref
...
...
@@ -79,7 +89,14 @@ def parse_test_cases():
Each test case contains all necessary information for execution and validation.
"""
test_cases
=
[]
for
num_seqs
,
max_seq_len
,
num_kv_heads
,
head_size
,
block_size
in
_TEST_CASES_DATA
:
for
(
num_seqs
,
max_seq_len
,
num_kv_heads
,
head_size
,
block_size
,
permute_dim_1_2
,
)
in
_TEST_CASES_DATA
:
num_blocks
=
4096
# A reasonably large cache pool for testing
# Create metadata: variable context lengths for each sequence in the batch
...
...
@@ -111,6 +128,9 @@ def parse_test_cases():
v_shape
=
(
ntok
,
num_kv_heads
,
head_size
)
k_cache_shape
=
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
)
v_cache_shape
=
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
)
if
permute_dim_1_2
:
k_cache_shape
=
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
v_cache_shape
=
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
# Generate test cases for all data types
for
dtype
in
_TENSOR_DTYPES
:
...
...
@@ -142,7 +162,7 @@ def parse_test_cases():
v_spec
,
slot_mapping_spec
,
],
kwargs
=
None
,
kwargs
=
{
"permute_dim_1_2"
:
permute_dim_1_2
}
,
output_spec
=
None
,
comparison_target
=
0
,
# Only compare k_cache
tolerance
=
tolerance
,
...
...
@@ -162,13 +182,22 @@ class OpTest(BaseOperatorTest):
def
get_test_cases
(
self
):
return
parse_test_cases
()
def
torch_operator
(
self
,
*
args
,
**
kwargs
):
def
torch_operator
(
self
,
k_cache
,
v_cache
,
key
,
value
,
slot_mapping
,
permute_dim_1_2
=
False
):
"""PyTorch paged_caching implementation"""
return
ref_paged_caching
(
*
args
,
**
kwargs
)
return
ref_paged_caching
(
k_cache
,
v_cache
,
key
,
value
,
slot_mapping
,
permute_dim_1_2
)
def
infinicore_operator
(
self
,
*
args
,
**
kwargs
):
def
infinicore_operator
(
self
,
k_cache
,
v_cache
,
key
,
value
,
slot_mapping
,
permute_dim_1_2
=
False
):
"""InfiniCore paged_caching implementation"""
return
infinicore
.
paged_caching
(
*
args
,
**
kwargs
)
if
permute_dim_1_2
:
k_cache
=
k_cache
.
permute
([
0
,
2
,
1
,
3
])
v_cache
=
v_cache
.
permute
([
0
,
2
,
1
,
3
])
return
infinicore
.
paged_caching
(
k_cache
,
v_cache
,
key
,
value
,
slot_mapping
)
def
main
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment