Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
96551cb7
"git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "aa43220fe5d34f7283f3a55e11dda5e3f25e5632"
Commit
96551cb7
authored
Jan 12, 2026
by
PanZezhong
Browse files
issue/867 fix page caching api, paged attn support more head dims
parent
01a4a0c8
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
171 additions
and
149 deletions
+171
-149
include/infinicore/ops/paged_caching.hpp
include/infinicore/ops/paged_caching.hpp
+2
-2
include/infiniop/ops/paged_caching.h
include/infiniop/ops/paged_caching.h
+8
-8
python/infinicore/ops/paged_caching.py
python/infinicore/ops/paged_caching.py
+4
-4
src/infinicore/ops/paged_caching/paged_caching.cc
src/infinicore/ops/paged_caching/paged_caching.cc
+6
-6
src/infinicore/ops/paged_caching/paged_caching_infiniop.cc
src/infinicore/ops/paged_caching/paged_caching_infiniop.cc
+4
-4
src/infinicore/pybind11/ops/paged_caching.hpp
src/infinicore/pybind11/ops/paged_caching.hpp
+2
-2
src/infiniop/ops/paged_attention/info.h
src/infiniop/ops/paged_attention/info.h
+2
-4
src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu
...niop/ops/paged_attention/nvidia/paged_attention_nvidia.cu
+36
-24
src/infiniop/ops/paged_attention/operator.cc
src/infiniop/ops/paged_attention/operator.cc
+23
-19
src/infiniop/ops/paged_caching/info.h
src/infiniop/ops/paged_caching/info.h
+2
-2
src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu
...infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu
+4
-4
src/infiniop/ops/paged_caching/operator.cc
src/infiniop/ops/paged_caching/operator.cc
+28
-24
src/infiniop/ops/paged_caching/paged_caching.h
src/infiniop/ops/paged_caching/paged_caching.h
+3
-3
test/infinicore/ops/paged_attention.py
test/infinicore/ops/paged_attention.py
+1
-2
test/infinicore/ops/paged_caching.py
test/infinicore/ops/paged_caching.py
+19
-15
test/infiniop/libinfiniop/op_register.py
test/infiniop/libinfiniop/op_register.py
+4
-4
test/infiniop/paged_attention.py
test/infiniop/paged_attention.py
+1
-0
test/infiniop/paged_caching.py
test/infiniop/paged_caching.py
+15
-15
test/infiniop/paged_caching_prefill.py
test/infiniop/paged_caching_prefill.py
+7
-7
No files found.
include/infinicore/ops/paged_caching.hpp
View file @
96551cb7
...
@@ -8,10 +8,10 @@ namespace infinicore::op {
...
@@ -8,10 +8,10 @@ namespace infinicore::op {
class
PagedCaching
{
class
PagedCaching
{
public:
public:
using
schema
=
void
(
*
)(
Tensor
,
Tensor
,
Tensor
,
Tensor
,
Tensor
);
using
schema
=
void
(
*
)(
Tensor
,
Tensor
,
Tensor
,
Tensor
,
Tensor
);
static
void
execute
(
Tensor
k
,
Tensor
v
,
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
slot_mapping
);
static
void
execute
(
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
k
,
Tensor
v
,
Tensor
slot_mapping
);
static
common
::
OpDispatcher
<
schema
>
&
dispatcher
();
static
common
::
OpDispatcher
<
schema
>
&
dispatcher
();
};
};
void
paged_caching_
(
Tensor
k
,
Tensor
v
,
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
slot_mapping
);
void
paged_caching_
(
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
k
,
Tensor
v
,
Tensor
slot_mapping
);
}
// namespace infinicore::op
}
// namespace infinicore::op
include/infiniop/ops/paged_caching.h
View file @
96551cb7
...
@@ -14,20 +14,20 @@ typedef struct InfiniopDescriptor *infiniopPagedCachingDescriptor_t;
...
@@ -14,20 +14,20 @@ typedef struct InfiniopDescriptor *infiniopPagedCachingDescriptor_t;
*
*
* @param handle The handle to the InfiniOP library context.
* @param handle The handle to the InfiniOP library context.
* @param desc_ptr A pointer to store the created descriptor.
* @param desc_ptr A pointer to store the created descriptor.
* @param k_desc Descriptor for the source key tensor.
* @param v_desc Descriptor for the source value tensor.
* @param k_cache_desc Descriptor for the key cache pool tensor.
* @param k_cache_desc Descriptor for the key cache pool tensor.
* @param v_cache_desc Descriptor for the value cache pool tensor.
* @param v_cache_desc Descriptor for the value cache pool tensor.
* @param k_desc Descriptor for the source key tensor.
* @param v_desc Descriptor for the source value tensor.
* @param slot_mapping_desc Descriptor for the slot mapping tensor.
* @param slot_mapping_desc Descriptor for the slot mapping tensor.
* @return infiniStatus_t Status code of the operation.
* @return infiniStatus_t Status code of the operation.
*/
*/
__C
__export
infiniStatus_t
infiniopCreatePagedCachingDescriptor
(
__C
__export
infiniStatus_t
infiniopCreatePagedCachingDescriptor
(
infiniopHandle_t
handle
,
infiniopHandle_t
handle
,
infiniopPagedCachingDescriptor_t
*
desc_ptr
,
infiniopPagedCachingDescriptor_t
*
desc_ptr
,
infiniopTensorDescriptor_t
k_desc
,
infiniopTensorDescriptor_t
v_desc
,
infiniopTensorDescriptor_t
k_cache_desc
,
infiniopTensorDescriptor_t
k_cache_desc
,
infiniopTensorDescriptor_t
v_cache_desc
,
infiniopTensorDescriptor_t
v_cache_desc
,
infiniopTensorDescriptor_t
k_desc
,
infiniopTensorDescriptor_t
v_desc
,
infiniopTensorDescriptor_t
slot_mapping_desc
);
infiniopTensorDescriptor_t
slot_mapping_desc
);
/**
/**
...
@@ -46,10 +46,10 @@ __C __export infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
...
@@ -46,10 +46,10 @@ __C __export infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
* @param desc The Paged Caching descriptor.
* @param desc The Paged Caching descriptor.
* @param workspace Pointer to the workspace memory.
* @param workspace Pointer to the workspace memory.
* @param workspace_size The size of the workspace.
* @param workspace_size The size of the workspace.
* @param k Pointer to the source key tensor data.
* @param v Pointer to the source value tensor data.
* @param k_cache Pointer to the key cache pool data.
* @param k_cache Pointer to the key cache pool data.
* @param v_cache Pointer to the value cache pool data.
* @param v_cache Pointer to the value cache pool data.
* @param k Pointer to the source key tensor data.
* @param v Pointer to the source value tensor data.
* @param slot_mapping Pointer to the slot mapping data.
* @param slot_mapping Pointer to the slot mapping data.
* @param stream The CUDA stream for the operation. Can be NULL.
* @param stream The CUDA stream for the operation. Can be NULL.
* @return infiniStatus_t Status code of the operation.
* @return infiniStatus_t Status code of the operation.
...
@@ -58,10 +58,10 @@ __C __export infiniStatus_t infiniopPagedCaching(
...
@@ -58,10 +58,10 @@ __C __export infiniStatus_t infiniopPagedCaching(
infiniopPagedCachingDescriptor_t
desc
,
infiniopPagedCachingDescriptor_t
desc
,
void
*
workspace
,
void
*
workspace
,
size_t
workspace_size
,
size_t
workspace_size
,
const
void
*
k
,
const
void
*
v
,
void
*
k_cache
,
void
*
k_cache
,
void
*
v_cache
,
void
*
v_cache
,
const
void
*
k
,
const
void
*
v
,
const
void
*
slot_mapping
,
const
void
*
slot_mapping
,
void
*
stream
);
void
*
stream
);
...
...
python/infinicore/ops/paged_caching.py
View file @
96551cb7
...
@@ -3,18 +3,18 @@ from infinicore.tensor import Tensor
...
@@ -3,18 +3,18 @@ from infinicore.tensor import Tensor
def
paged_caching
(
def
paged_caching
(
k
:
Tensor
,
v
:
Tensor
,
k_cache
:
Tensor
,
k_cache
:
Tensor
,
v_cache
:
Tensor
,
v_cache
:
Tensor
,
k
:
Tensor
,
v
:
Tensor
,
slot_mapping
:
Tensor
,
slot_mapping
:
Tensor
,
):
):
Tensor
(
Tensor
(
_infinicore
.
paged_caching_
(
_infinicore
.
paged_caching_
(
k
.
_underlying
,
v
.
_underlying
,
k_cache
.
_underlying
,
k_cache
.
_underlying
,
v_cache
.
_underlying
,
v_cache
.
_underlying
,
k
.
_underlying
,
v
.
_underlying
,
slot_mapping
.
_underlying
,
slot_mapping
.
_underlying
,
)
)
)
)
...
...
src/infinicore/ops/paged_caching/paged_caching.cc
View file @
96551cb7
...
@@ -9,14 +9,14 @@ common::OpDispatcher<PagedCaching::schema> &PagedCaching::dispatcher() {
...
@@ -9,14 +9,14 @@ common::OpDispatcher<PagedCaching::schema> &PagedCaching::dispatcher() {
return
dispatcher_
;
return
dispatcher_
;
};
};
void
PagedCaching
::
execute
(
Tensor
k
,
Tensor
v
,
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
slot_mapping
)
{
void
PagedCaching
::
execute
(
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
k
,
Tensor
v
,
Tensor
slot_mapping
)
{
INFINICORE_ASSERT_TENSORS_SAME_DEVICE
(
k
,
v
,
k_cache
,
v_cache
,
slot_mapping
);
INFINICORE_ASSERT_TENSORS_SAME_DEVICE
(
k_cache
,
v_cache
,
k
,
v
,
slot_mapping
);
infinicore
::
context
::
setDevice
(
k
->
device
());
infinicore
::
context
::
setDevice
(
k
_cache
->
device
());
dispatcher
().
lookup
(
k
->
device
().
getType
())(
k
,
v
,
k_cache
,
v_cache
,
slot_mapping
);
dispatcher
().
lookup
(
k
_cache
->
device
().
getType
())(
k_cache
,
v_cache
,
k
,
v
,
slot_mapping
);
}
}
void
paged_caching_
(
Tensor
k
,
Tensor
v
,
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
slot_mapping
)
{
void
paged_caching_
(
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
k
,
Tensor
v
,
Tensor
slot_mapping
)
{
PagedCaching
::
execute
(
k
,
v
,
k_cache
,
v_cache
,
slot_mapping
);
PagedCaching
::
execute
(
k_cache
,
v_cache
,
k
,
v
,
slot_mapping
);
}
}
}
// namespace infinicore::op
}
// namespace infinicore::op
src/infinicore/ops/paged_caching/paged_caching_infiniop.cc
View file @
96551cb7
...
@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopPagedCachingDescriptor_t> caches(
...
@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopPagedCachingDescriptor_t> caches(
}
}
});
});
void
calculate
(
Tensor
k
,
Tensor
v
,
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
slot_mapping
)
{
void
calculate
(
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
k
,
Tensor
v
,
Tensor
slot_mapping
)
{
size_t
seed
=
hash_combine
(
k
,
v
,
k_cache
,
v_cache
,
slot_mapping
);
size_t
seed
=
hash_combine
(
k_cache
,
v_cache
,
k
,
v
,
slot_mapping
);
auto
device
=
context
::
getDevice
();
auto
device
=
context
::
getDevice
();
auto
&
cache
=
caches
.
getCache
(
device
);
auto
&
cache
=
caches
.
getCache
(
device
);
...
@@ -27,7 +27,7 @@ void calculate(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_m
...
@@ -27,7 +27,7 @@ void calculate(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_m
if
(
!
desc_opt
)
{
if
(
!
desc_opt
)
{
INFINICORE_CHECK_ERROR
(
infiniopCreatePagedCachingDescriptor
(
INFINICORE_CHECK_ERROR
(
infiniopCreatePagedCachingDescriptor
(
context
::
getInfiniopHandle
(
device
),
&
desc
,
context
::
getInfiniopHandle
(
device
),
&
desc
,
k
->
desc
(),
v
->
desc
(),
k_cache
->
desc
(),
v_cache
->
desc
(),
slot_mapping
->
desc
()));
k_cache
->
desc
(),
v_cache
->
desc
(),
k
->
desc
(),
v
->
desc
(),
slot_mapping
->
desc
()));
cache
.
put
(
seed
,
desc
);
cache
.
put
(
seed
,
desc
);
}
else
{
}
else
{
desc
=
*
desc_opt
;
desc
=
*
desc_opt
;
...
@@ -39,7 +39,7 @@ void calculate(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_m
...
@@ -39,7 +39,7 @@ void calculate(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_m
INFINICORE_CHECK_ERROR
(
infiniopPagedCaching
(
INFINICORE_CHECK_ERROR
(
infiniopPagedCaching
(
desc
,
workspace
->
data
(),
workspace_size
,
desc
,
workspace
->
data
(),
workspace_size
,
k
->
data
(),
v
->
data
(),
k_cache
->
data
(),
v_cache
->
data
(),
slot_mapping
->
data
(),
context
::
getStream
()));
k_cache
->
data
(),
v_cache
->
data
(),
k
->
data
(),
v
->
data
(),
slot_mapping
->
data
(),
context
::
getStream
()));
}
}
static
bool
registered
=
[]()
{
static
bool
registered
=
[]()
{
...
...
src/infinicore/pybind11/ops/paged_caching.hpp
View file @
96551cb7
...
@@ -11,10 +11,10 @@ namespace infinicore::ops {
...
@@ -11,10 +11,10 @@ namespace infinicore::ops {
inline
void
bind_paged_caching
(
py
::
module
&
m
)
{
inline
void
bind_paged_caching
(
py
::
module
&
m
)
{
m
.
def
(
"paged_caching_"
,
m
.
def
(
"paged_caching_"
,
&
op
::
paged_caching_
,
&
op
::
paged_caching_
,
py
::
arg
(
"k"
),
py
::
arg
(
"v"
),
py
::
arg
(
"k_cache"
),
py
::
arg
(
"k_cache"
),
py
::
arg
(
"v_cache"
),
py
::
arg
(
"v_cache"
),
py
::
arg
(
"k"
),
py
::
arg
(
"v"
),
py
::
arg
(
"slot_mapping"
),
py
::
arg
(
"slot_mapping"
),
R"doc(Paged caching of key and value tensors.)doc"
);
R"doc(Paged caching of key and value tensors.)doc"
);
}
}
...
...
src/infiniop/ops/paged_attention/info.h
View file @
96551cb7
...
@@ -67,11 +67,9 @@ public:
...
@@ -67,11 +67,9 @@ public:
size_t
num_heads
=
q_shape
[
1
];
size_t
num_heads
=
q_shape
[
1
];
size_t
head_size
=
q_shape
[
2
];
size_t
head_size
=
q_shape
[
2
];
if
(
head_size
!=
128
)
{
if
(
head_size
!=
16
&&
head_size
!=
32
&&
head_size
!=
64
&&
head_size
!=
128
&&
head_size
!=
256
)
{
// 输出具体的错误原因和当前的参数值
std
::
cerr
<<
"[Error] Now only supports head_size = 16/32/64/128/256, but got "
std
::
cerr
<<
"[Error] Now only supports head_size = 128, but got "
<<
head_size
<<
"."
<<
std
::
endl
;
<<
head_size
<<
"."
<<
std
::
endl
;
// 建议返回 SHAPE 相关的错误码
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
}
}
...
...
src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu
View file @
96551cb7
...
@@ -98,37 +98,49 @@ infiniStatus_t Descriptor::calculate(
...
@@ -98,37 +98,49 @@ infiniStatus_t Descriptor::calculate(
const
void
*
block_tables
,
const
void
*
seq_lens
,
const
void
*
alibi_slopes
,
const
void
*
block_tables
,
const
void
*
seq_lens
,
const
void
*
alibi_slopes
,
void
*
stream_
)
const
{
void
*
stream_
)
const
{
cudaStream_t
stream
=
(
cudaStream_t
)
stream_
;
cudaStream_t
stream
=
(
cudaStream_t
)
stream_
;
#define LAUNCH_HEADSIZE_BLOCKSIZE(__H_SIZE, __B_SIZE) \
launchKernel<__H_SIZE, __B_SIZE>( \
out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes, \
_info.num_heads, _info.num_seqs, \
_info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size, \
_info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, \
stream);
#define SWITCH_HEAD_SIZE(__B_SIZE) \
switch (_info.head_size) { \
case 16: \
LAUNCH_HEADSIZE_BLOCKSIZE(16, __B_SIZE) \
break; \
case 32: \
LAUNCH_HEADSIZE_BLOCKSIZE(32, __B_SIZE) \
break; \
case 64: \
LAUNCH_HEADSIZE_BLOCKSIZE(64, __B_SIZE) \
break; \
case 128: \
LAUNCH_HEADSIZE_BLOCKSIZE(128, __B_SIZE) \
break; \
case 256: \
LAUNCH_HEADSIZE_BLOCKSIZE(256, __B_SIZE) \
break; \
default: \
return INFINI_STATUS_BAD_TENSOR_SHAPE; \
}
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_1024
)
{
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_1024
)
{
if
(
_info
.
head_size
==
128
)
{
SWITCH_HEAD_SIZE
(
CUDA_BLOCK_SIZE_1024
)
launchKernel
<
128
,
CUDA_BLOCK_SIZE_1024
>
(
out
,
q
,
k_cache
,
v_cache
,
_info
.
dtype
,
block_tables
,
seq_lens
,
alibi_slopes
,
_info
.
num_heads
,
_info
.
num_seqs
,
_info
.
num_kv_heads
,
_info
.
scale
,
_info
.
max_num_blocks_per_seq
,
_info
.
block_size
,
_info
.
q_stride
,
_info
.
kv_block_stride
,
_info
.
kv_head_stride
,
_info
.
o_stride
,
stream
);
}
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_512
)
{
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_512
)
{
if
(
_info
.
head_size
==
128
)
{
SWITCH_HEAD_SIZE
(
CUDA_BLOCK_SIZE_512
)
launchKernel
<
128
,
CUDA_BLOCK_SIZE_512
>
(
out
,
q
,
k_cache
,
v_cache
,
_info
.
dtype
,
block_tables
,
seq_lens
,
alibi_slopes
,
_info
.
num_heads
,
_info
.
num_seqs
,
_info
.
num_kv_heads
,
_info
.
scale
,
_info
.
max_num_blocks_per_seq
,
_info
.
block_size
,
_info
.
q_stride
,
_info
.
kv_block_stride
,
_info
.
kv_head_stride
,
_info
.
o_stride
,
stream
);
}
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_4096
)
{
}
else
if
(
_opaque
->
internal
->
maxThreadsPerBlock
()
==
CUDA_BLOCK_SIZE_4096
)
{
if
(
_info
.
head_size
==
128
)
{
SWITCH_HEAD_SIZE
(
CUDA_BLOCK_SIZE_4096
)
launchKernel
<
128
,
CUDA_BLOCK_SIZE_4096
>
(
out
,
q
,
k_cache
,
v_cache
,
_info
.
dtype
,
block_tables
,
seq_lens
,
alibi_slopes
,
_info
.
num_heads
,
_info
.
num_seqs
,
_info
.
num_kv_heads
,
_info
.
scale
,
_info
.
max_num_blocks_per_seq
,
_info
.
block_size
,
_info
.
q_stride
,
_info
.
kv_block_stride
,
_info
.
kv_head_stride
,
_info
.
o_stride
,
stream
);
}
}
else
{
}
else
{
return
INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED
;
}
}
#undef LAUNCH_HEADSIZE_BLOCKSIZE
#undef SWITCH_HEAD_SIZE
return
INFINI_STATUS_SUCCESS
;
return
INFINI_STATUS_SUCCESS
;
}
}
...
...
src/infiniop/ops/paged_attention/operator.cc
View file @
96551cb7
...
@@ -5,9 +5,9 @@
...
@@ -5,9 +5,9 @@
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
#include "nvidia/paged_attention_nvidia.cuh"
#include "nvidia/paged_attention_nvidia.cuh"
#endif
#endif
#ifdef ENABLE_METAX_API
//
#ifdef ENABLE_METAX_API
#include "metax/paged_attention_metax.h"
//
#include "metax/paged_attention_metax.h"
#endif
//
#endif
__C
infiniStatus_t
infiniopCreatePagedAttentionDescriptor
(
__C
infiniStatus_t
infiniopCreatePagedAttentionDescriptor
(
infiniopHandle_t
handle
,
infiniopHandle_t
handle
,
...
@@ -34,11 +34,12 @@ __C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
...
@@ -34,11 +34,12 @@ __C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
CREATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
CREATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
#endif
#endif
#ifdef ENABLE_METAX_API
// #ifdef ENABLE_METAX_API
CREATE
(
INFINI_DEVICE_METAX
,
metax
)
// CREATE(INFINI_DEVICE_METAX, metax)
#endif
// #endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
__C
infiniStatus_t
infiniopGetPagedAttentionWorkspaceSize
(
__C
infiniStatus_t
infiniopGetPagedAttentionWorkspaceSize
(
...
@@ -54,11 +55,12 @@ __C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
...
@@ -54,11 +55,12 @@ __C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
GET
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
GET
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
#endif
#endif
#ifdef ENABLE_METAX_API
// #ifdef ENABLE_METAX_API
GET
(
INFINI_DEVICE_METAX
,
metax
)
// GET(INFINI_DEVICE_METAX, metax)
#endif
// #endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
__C
infiniStatus_t
infiniopPagedAttention
(
__C
infiniStatus_t
infiniopPagedAttention
(
...
@@ -78,11 +80,12 @@ __C infiniStatus_t infiniopPagedAttention(
...
@@ -78,11 +80,12 @@ __C infiniStatus_t infiniopPagedAttention(
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
#endif
#endif
#ifdef ENABLE_METAX_API
// #ifdef ENABLE_METAX_API
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
)
// CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
// #endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
__C
infiniStatus_t
infiniopDestroyPagedAttentionDescriptor
(
__C
infiniStatus_t
infiniopDestroyPagedAttentionDescriptor
(
...
@@ -97,9 +100,10 @@ __C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
...
@@ -97,9 +100,10 @@ __C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
DESTROY
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
DESTROY
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
#endif
#endif
#ifdef ENABLE_METAX_API
// #ifdef ENABLE_METAX_API
DESTROY
(
INFINI_DEVICE_METAX
,
metax
)
// DESTROY(INFINI_DEVICE_METAX, metax)
#endif
// #endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
src/infiniop/ops/paged_caching/info.h
View file @
96551cb7
...
@@ -28,10 +28,10 @@ public:
...
@@ -28,10 +28,10 @@ public:
ptrdiff_t
v_cache_block_stride
;
ptrdiff_t
v_cache_block_stride
;
static
utils
::
Result
<
PagedCachingInfo
>
create
(
static
utils
::
Result
<
PagedCachingInfo
>
create
(
infiniopTensorDescriptor_t
k_desc
,
infiniopTensorDescriptor_t
v_desc
,
infiniopTensorDescriptor_t
k_cache_desc
,
infiniopTensorDescriptor_t
k_cache_desc
,
infiniopTensorDescriptor_t
v_cache_desc
,
infiniopTensorDescriptor_t
v_cache_desc
,
infiniopTensorDescriptor_t
k_desc
,
infiniopTensorDescriptor_t
v_desc
,
infiniopTensorDescriptor_t
slot_mapping_desc
)
{
infiniopTensorDescriptor_t
slot_mapping_desc
)
{
auto
dtype
=
k_desc
->
dtype
();
auto
dtype
=
k_desc
->
dtype
();
...
...
src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu
View file @
96551cb7
...
@@ -31,13 +31,13 @@ Descriptor::~Descriptor() {
...
@@ -31,13 +31,13 @@ Descriptor::~Descriptor() {
infiniStatus_t
Descriptor
::
create
(
infiniStatus_t
Descriptor
::
create
(
infiniopHandle_t
handle
,
infiniopHandle_t
handle
,
Descriptor
**
desc_ptr
,
Descriptor
**
desc_ptr
,
infiniopTensorDescriptor_t
k_desc
,
infiniopTensorDescriptor_t
v_desc
,
infiniopTensorDescriptor_t
k_cache_desc
,
infiniopTensorDescriptor_t
k_cache_desc
,
infiniopTensorDescriptor_t
v_cache_desc
,
infiniopTensorDescriptor_t
v_cache_desc
,
infiniopTensorDescriptor_t
k_desc
,
infiniopTensorDescriptor_t
v_desc
,
infiniopTensorDescriptor_t
slot_mapping_desc
)
{
infiniopTensorDescriptor_t
slot_mapping_desc
)
{
auto
info
=
PagedCachingInfo
::
create
(
k_desc
,
v_desc
,
k_cache_desc
,
v_cache_desc
,
slot_mapping_desc
);
auto
info
=
PagedCachingInfo
::
create
(
k_cache_desc
,
v_cache_desc
,
k_desc
,
v_desc
,
slot_mapping_desc
);
CHECK_RESULT
(
info
);
CHECK_RESULT
(
info
);
// Create and return the Descriptor instance.
// Create and return the Descriptor instance.
...
@@ -121,8 +121,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
...
@@ -121,8 +121,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
// Execution method implementation
// Execution method implementation
infiniStatus_t
Descriptor
::
calculate
(
infiniStatus_t
Descriptor
::
calculate
(
void
*
workspace
,
size_t
workspace_size
,
void
*
workspace
,
size_t
workspace_size
,
const
void
*
k
,
const
void
*
v
,
void
*
k_cache
,
void
*
v_cache
,
void
*
k_cache
,
void
*
v_cache
,
const
void
*
k
,
const
void
*
v
,
const
void
*
slot_mapping
,
const
void
*
slot_mapping
,
void
*
stream_
)
const
{
void
*
stream_
)
const
{
...
...
src/infiniop/ops/paged_caching/operator.cc
View file @
96551cb7
...
@@ -5,17 +5,17 @@
...
@@ -5,17 +5,17 @@
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
#include "nvidia/paged_caching_nvidia.cuh"
#include "nvidia/paged_caching_nvidia.cuh"
#endif
#endif
#ifdef ENABLE_METAX_API
//
#ifdef ENABLE_METAX_API
#include "metax/paged_caching_metax.h"
//
#include "metax/paged_caching_metax.h"
#endif
//
#endif
__C
infiniStatus_t
infiniopCreatePagedCachingDescriptor
(
__C
infiniStatus_t
infiniopCreatePagedCachingDescriptor
(
infiniopHandle_t
handle
,
infiniopHandle_t
handle
,
infiniopPagedCachingDescriptor_t
*
desc_ptr
,
infiniopPagedCachingDescriptor_t
*
desc_ptr
,
infiniopTensorDescriptor_t
k_desc
,
infiniopTensorDescriptor_t
v_desc
,
infiniopTensorDescriptor_t
k_cache_desc
,
infiniopTensorDescriptor_t
k_cache_desc
,
infiniopTensorDescriptor_t
v_cache_desc
,
infiniopTensorDescriptor_t
v_cache_desc
,
infiniopTensorDescriptor_t
k_desc
,
infiniopTensorDescriptor_t
v_desc
,
infiniopTensorDescriptor_t
slot_mapping_desc
)
{
infiniopTensorDescriptor_t
slot_mapping_desc
)
{
#define CREATE(CASE, NAMESPACE) \
#define CREATE(CASE, NAMESPACE) \
...
@@ -23,17 +23,18 @@ __C infiniStatus_t infiniopCreatePagedCachingDescriptor(
...
@@ -23,17 +23,18 @@ __C infiniStatus_t infiniopCreatePagedCachingDescriptor(
return op::paged_caching::NAMESPACE::Descriptor::create( \
return op::paged_caching::NAMESPACE::Descriptor::create( \
handle, \
handle, \
reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor **>(desc_ptr), \
reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor **>(desc_ptr), \
k_desc, v_desc,
k_cache_desc, v_cache_desc, slot_mapping_desc);
k_cache_desc, v_cache_desc,
k_desc, v_desc,
slot_mapping_desc);
switch
(
handle
->
device
)
{
switch
(
handle
->
device
)
{
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
CREATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
CREATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
#endif
#endif
#ifdef ENABLE_METAX_API
// #ifdef ENABLE_METAX_API
CREATE
(
INFINI_DEVICE_METAX
,
metax
)
// CREATE(INFINI_DEVICE_METAX, metax)
#endif
// #endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
__C
infiniStatus_t
infiniopGetPagedCachingWorkspaceSize
(
__C
infiniStatus_t
infiniopGetPagedCachingWorkspaceSize
(
...
@@ -49,35 +50,37 @@ __C infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
...
@@ -49,35 +50,37 @@ __C infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
GET
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
GET
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
#endif
#endif
#ifdef ENABLE_METAX_API
// #ifdef ENABLE_METAX_API
GET
(
INFINI_DEVICE_METAX
,
metax
)
// GET(INFINI_DEVICE_METAX, metax)
#endif
// #endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
__C
infiniStatus_t
infiniopPagedCaching
(
__C
infiniStatus_t
infiniopPagedCaching
(
infiniopPagedCachingDescriptor_t
desc
,
infiniopPagedCachingDescriptor_t
desc
,
void
*
workspace
,
size_t
workspace_size
,
void
*
workspace
,
size_t
workspace_size
,
const
void
*
k
,
const
void
*
v
,
void
*
k_cache
,
void
*
v_cache
,
void
*
k_cache
,
void
*
v_cache
,
const
void
*
k
,
const
void
*
v
,
const
void
*
slot_mapping
,
const
void
*
slot_mapping
,
void
*
stream
)
{
void
*
stream
)
{
#define CALCULATE(CASE, NAMESPACE) \
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
case CASE: \
return reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor *>(desc)->calculate( \
return reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size,
k, v,
k_cache, v_cache, slot_mapping, stream);
workspace, workspace_size, k_cache, v_cache,
k, v,
slot_mapping, stream);
switch
(
desc
->
device_type
)
{
switch
(
desc
->
device_type
)
{
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
#endif
#endif
#ifdef ENABLE_METAX_API
// #ifdef ENABLE_METAX_API
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
)
// CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
// #endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
__C
infiniStatus_t
infiniopDestroyPagedCachingDescriptor
(
__C
infiniStatus_t
infiniopDestroyPagedCachingDescriptor
(
...
@@ -92,9 +95,10 @@ __C infiniStatus_t infiniopDestroyPagedCachingDescriptor(
...
@@ -92,9 +95,10 @@ __C infiniStatus_t infiniopDestroyPagedCachingDescriptor(
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
DESTROY
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
DESTROY
(
INFINI_DEVICE_NVIDIA
,
nvidia
)
#endif
#endif
#ifdef ENABLE_METAX_API
// #ifdef ENABLE_METAX_API
DESTROY
(
INFINI_DEVICE_METAX
,
metax
)
// DESTROY(INFINI_DEVICE_METAX, metax)
#endif
// #endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
src/infiniop/ops/paged_caching/paged_caching.h
View file @
96551cb7
...
@@ -32,16 +32,16 @@
...
@@ -32,16 +32,16 @@
static infiniStatus_t create( \
static infiniStatus_t create( \
infiniopHandle_t handle, \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t k_desc, \
infiniopTensorDescriptor_t v_desc, \
infiniopTensorDescriptor_t k_cache_desc, \
infiniopTensorDescriptor_t k_cache_desc, \
infiniopTensorDescriptor_t v_cache_desc, \
infiniopTensorDescriptor_t v_cache_desc, \
infiniopTensorDescriptor_t k_desc, \
infiniopTensorDescriptor_t v_desc, \
infiniopTensorDescriptor_t slot_mapping_desc); \
infiniopTensorDescriptor_t slot_mapping_desc); \
\
\
infiniStatus_t calculate( \
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *workspace, size_t workspace_size, \
const void *k, const void *v, \
void *k_cache, void *v_cache, \
void *k_cache, void *v_cache, \
const void *k, const void *v, \
const void *slot_mapping, \
const void *slot_mapping, \
void *stream) const; \
void *stream) const; \
}; \
}; \
...
...
test/infinicore/ops/paged_attention.py
View file @
96551cb7
...
@@ -25,6 +25,7 @@ _TEST_CASES_DATA = [
...
@@ -25,6 +25,7 @@ _TEST_CASES_DATA = [
(
4
,
40
,
40
,
128
,
16
,
1024
,
False
),
(
4
,
40
,
40
,
128
,
16
,
1024
,
False
),
(
6
,
40
,
40
,
128
,
16
,
1024
,
False
),
(
6
,
40
,
40
,
128
,
16
,
1024
,
False
),
(
3
,
8
,
8
,
128
,
16
,
1024
,
False
),
(
3
,
8
,
8
,
128
,
16
,
1024
,
False
),
(
3
,
8
,
8
,
64
,
16
,
1024
,
False
),
(
8
,
64
,
8
,
128
,
16
,
2048
,
False
),
(
8
,
64
,
8
,
128
,
16
,
2048
,
False
),
]
]
...
@@ -68,8 +69,6 @@ def parse_test_cases():
...
@@ -68,8 +69,6 @@ def parse_test_cases():
0
,
num_seqs
*
max_blocks_per_seq
,
dtype
=
torch
.
int64
0
,
num_seqs
*
max_blocks_per_seq
,
dtype
=
torch
.
int64
).
view
(
num_seqs
,
max_blocks_per_seq
)
).
view
(
num_seqs
,
max_blocks_per_seq
)
print
(
"block_tables.shape"
,
block_tables
.
shape
,
block_tables
)
q_shape
=
(
num_seqs
,
num_heads
,
head_size
)
q_shape
=
(
num_seqs
,
num_heads
,
head_size
)
k_cache_shape
=
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
)
k_cache_shape
=
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
)
v_cache_shape
=
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
)
v_cache_shape
=
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
)
...
...
test/infinicore/ops/paged_caching.py
View file @
96551cb7
...
@@ -28,9 +28,9 @@ _TEST_CASES_DATA = [
...
@@ -28,9 +28,9 @@ _TEST_CASES_DATA = [
# Tolerance configuration
# Tolerance configuration
_TOLERANCE_MAP
=
{
_TOLERANCE_MAP
=
{
infinicore
.
float16
:
{
"atol"
:
0
,
"rtol"
:
1e-
2
},
infinicore
.
float16
:
{
"atol"
:
0
,
"rtol"
:
1e-
5
},
infinicore
.
float32
:
{
"atol"
:
1e-4
,
"rtol"
:
1e-
3
},
infinicore
.
float32
:
{
"atol"
:
0
,
"rtol"
:
1e-
5
},
infinicore
.
bfloat16
:
{
"atol"
:
0
,
"rtol"
:
5
e-
2
},
infinicore
.
bfloat16
:
{
"atol"
:
0
,
"rtol"
:
1
e-
5
},
}
}
# Data types to test
# Data types to test
...
@@ -40,15 +40,15 @@ _TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
...
@@ -40,15 +40,15 @@ _TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# ==============================================================================
# ==============================================================================
# Reference Implementation
# Reference Implementation
# ==============================================================================
# ==============================================================================
def
ref_paged_caching
(
key
,
value
,
key_cache_pool
,
value_cache_pool
,
slot_mapping
):
def
ref_paged_caching
(
key_cache_pool
,
value_cache_pool
,
key
,
value
,
slot_mapping
):
"""
"""
Reference implementation for paged_caching operator.
Reference implementation for paged_caching operator.
Args:
Args:
key (torch.Tensor): Keys, shape [ntok, nkvh, dh]
value (torch.Tensor): Values, shape [ntok, nkvh, dh]
key_cache_pool (torch.Tensor): K cache pool, shape [num_blocks, nkvh, block_size, dh]
key_cache_pool (torch.Tensor): K cache pool, shape [num_blocks, nkvh, block_size, dh]
value_cache_pool (torch.Tensor): V cache pool, shape [num_blocks, nkvh, block_size, dh]
value_cache_pool (torch.Tensor): V cache pool, shape [num_blocks, nkvh, block_size, dh]
key (torch.Tensor): Keys, shape [ntok, nkvh, dh]
value (torch.Tensor): Values, shape [ntok, nkvh, dh]
slot_mapping (torch.Tensor): Slot mapping, shape [ntok]
slot_mapping (torch.Tensor): Slot mapping, shape [ntok]
"""
"""
ntok
=
key
.
shape
[
0
]
ntok
=
key
.
shape
[
0
]
...
@@ -56,8 +56,8 @@ def ref_paged_caching(key, value, key_cache_pool, value_cache_pool, slot_mapping
...
@@ -56,8 +56,8 @@ def ref_paged_caching(key, value, key_cache_pool, value_cache_pool, slot_mapping
# This reference implementation operates on a cloned cache to avoid modifying the original input tensor,
# This reference implementation operates on a cloned cache to avoid modifying the original input tensor,
# mimicking the behavior where the custom operator writes to its output tensor.
# mimicking the behavior where the custom operator writes to its output tensor.
k_cache_ref
=
key_cache_pool
.
clone
()
k_cache_ref
=
key_cache_pool
v_cache_ref
=
value_cache_pool
.
clone
()
v_cache_ref
=
value_cache_pool
for
i
in
range
(
ntok
):
for
i
in
range
(
ntok
):
slot
=
slot_mapping
[
i
].
item
()
slot
=
slot_mapping
[
i
].
item
()
...
@@ -98,9 +98,9 @@ def parse_test_cases():
...
@@ -98,9 +98,9 @@ def parse_test_cases():
current_slot
+=
length
.
item
()
current_slot
+=
length
.
item
()
# Ensure we don't exceed the total number of slots in the cache
# Ensure we don't exceed the total number of slots in the cache
assert
(
assert
current_slot
<=
num_blocks
*
block_size
,
(
current_slot
<=
num_blocks
*
block_size
"Not enough blocks in the cache pool for this test case"
)
,
"Not enough blocks in the cache pool for this test case"
)
slot_mapping
=
torch
.
tensor
(
slot_mapping_list
,
dtype
=
torch
.
int64
)
slot_mapping
=
torch
.
tensor
(
slot_mapping_list
,
dtype
=
torch
.
int64
)
...
@@ -119,8 +119,12 @@ def parse_test_cases():
...
@@ -119,8 +119,12 @@ def parse_test_cases():
# Create typed tensor specs
# Create typed tensor specs
k_spec
=
TensorSpec
.
from_tensor
(
k_shape
,
None
,
dtype
)
k_spec
=
TensorSpec
.
from_tensor
(
k_shape
,
None
,
dtype
)
v_spec
=
TensorSpec
.
from_tensor
(
v_shape
,
None
,
dtype
)
v_spec
=
TensorSpec
.
from_tensor
(
v_shape
,
None
,
dtype
)
k_cache_spec
=
TensorSpec
.
from_tensor
(
k_cache_shape
,
None
,
dtype
)
k_cache_spec
=
TensorSpec
.
from_tensor
(
v_cache_spec
=
TensorSpec
.
from_tensor
(
v_cache_shape
,
None
,
dtype
)
k_cache_shape
,
None
,
dtype
,
init_mode
=
TensorInitializer
.
ZEROS
)
v_cache_spec
=
TensorSpec
.
from_tensor
(
v_cache_shape
,
None
,
dtype
,
init_mode
=
TensorInitializer
.
ZEROS
)
slot_mapping_spec
=
TensorSpec
.
from_tensor
(
slot_mapping_spec
=
TensorSpec
.
from_tensor
(
slot_mapping_shape
,
slot_mapping_shape
,
init_mode
=
TensorInitializer
.
MANUAL
,
init_mode
=
TensorInitializer
.
MANUAL
,
...
@@ -132,10 +136,10 @@ def parse_test_cases():
...
@@ -132,10 +136,10 @@ def parse_test_cases():
test_cases
.
append
(
test_cases
.
append
(
TestCase
(
TestCase
(
inputs
=
[
inputs
=
[
k_spec
,
v_spec
,
k_cache_spec
,
k_cache_spec
,
v_cache_spec
,
v_cache_spec
,
k_spec
,
v_spec
,
slot_mapping_spec
,
slot_mapping_spec
,
],
],
kwargs
=
None
,
kwargs
=
None
,
...
...
test/infiniop/libinfiniop/op_register.py
View file @
96551cb7
...
@@ -1066,10 +1066,10 @@ def paged_caching_(lib):
...
@@ -1066,10 +1066,10 @@ def paged_caching_(lib):
lib
.
infiniopCreatePagedCachingDescriptor
.
argtypes
=
[
lib
.
infiniopCreatePagedCachingDescriptor
.
argtypes
=
[
infiniopHandle_t
,
infiniopHandle_t
,
POINTER
(
infiniopOperatorDescriptor_t
),
POINTER
(
infiniopOperatorDescriptor_t
),
infiniopTensorDescriptor_t
,
# k_desc
infiniopTensorDescriptor_t
,
# v_desc
infiniopTensorDescriptor_t
,
# k_cache_desc
infiniopTensorDescriptor_t
,
# k_cache_desc
infiniopTensorDescriptor_t
,
# v_cache_desc
infiniopTensorDescriptor_t
,
# v_cache_desc
infiniopTensorDescriptor_t
,
# k_desc
infiniopTensorDescriptor_t
,
# v_desc
infiniopTensorDescriptor_t
,
# slot_mapping_desc
infiniopTensorDescriptor_t
,
# slot_mapping_desc
]
]
...
@@ -1086,10 +1086,10 @@ def paged_caching_(lib):
...
@@ -1086,10 +1086,10 @@ def paged_caching_(lib):
infiniopOperatorDescriptor_t
,
infiniopOperatorDescriptor_t
,
c_void_p
,
# workspace
c_void_p
,
# workspace
c_size_t
,
# workspace_size
c_size_t
,
# workspace_size
c_void_p
,
# k
c_void_p
,
# v
c_void_p
,
# k_cache
c_void_p
,
# k_cache
c_void_p
,
# v_cache
c_void_p
,
# v_cache
c_void_p
,
# k
c_void_p
,
# v
c_void_p
,
# slot_mapping
c_void_p
,
# slot_mapping
c_void_p
,
# stream
c_void_p
,
# stream
]
]
...
...
test/infiniop/paged_attention.py
View file @
96551cb7
...
@@ -95,6 +95,7 @@ _TEST_CASES_ = [
...
@@ -95,6 +95,7 @@ _TEST_CASES_ = [
(
4
,
40
,
40
,
128
,
16
,
1024
,
False
),
(
4
,
40
,
40
,
128
,
16
,
1024
,
False
),
(
6
,
40
,
40
,
128
,
16
,
1024
,
False
),
(
6
,
40
,
40
,
128
,
16
,
1024
,
False
),
(
3
,
8
,
8
,
128
,
16
,
1024
,
False
),
(
3
,
8
,
8
,
128
,
16
,
1024
,
False
),
(
3
,
8
,
8
,
64
,
16
,
1024
,
False
),
(
8
,
64
,
8
,
128
,
16
,
2048
,
False
),
(
8
,
64
,
8
,
128
,
16
,
2048
,
False
),
]
]
...
...
test/infiniop/paged_caching.py
View file @
96551cb7
...
@@ -22,15 +22,15 @@ from libinfiniop import (
...
@@ -22,15 +22,15 @@ from libinfiniop import (
# ==============================================================================
# ==============================================================================
# Reference Implementation
# Reference Implementation
# ==============================================================================
# ==============================================================================
def
ref_paged_caching
(
key
,
value
,
key_cache_pool
,
value_cache_pool
,
slot_mapping
):
def
ref_paged_caching
(
key_cache_pool
,
value_cache_pool
,
key
,
value
,
slot_mapping
):
"""
"""
Reference implementation for paged_caching operator.
Reference implementation for paged_caching operator.
Args:
Args:
key (torch.Tensor): Keys, shape [ntok, nkvh, dh]
value (torch.Tensor): Values, shape [ntok, nkvh, dh]
key_cache_pool (torch.Tensor): K cache pool, shape [num_blocks, nkvh, block_size, dh]
key_cache_pool (torch.Tensor): K cache pool, shape [num_blocks, nkvh, block_size, dh]
value_cache_pool (torch.Tensor): V cache pool, shape [num_blocks, nkvh, block_size, dh]
value_cache_pool (torch.Tensor): V cache pool, shape [num_blocks, nkvh, block_size, dh]
key (torch.Tensor): Keys, shape [ntok, nkvh, dh]
value (torch.Tensor): Values, shape [ntok, nkvh, dh]
slot_mapping (torch.Tensor): Slot mapping, shape [ntok]
slot_mapping (torch.Tensor): Slot mapping, shape [ntok]
"""
"""
ntok
=
key
.
shape
[
0
]
ntok
=
key
.
shape
[
0
]
...
@@ -71,9 +71,9 @@ _TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32]
...
@@ -71,9 +71,9 @@ _TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32]
# Tolerance map for different data types
# Tolerance map for different data types
_TOLERANCE_MAP
=
{
_TOLERANCE_MAP
=
{
InfiniDtype
.
F16
:
{
"atol"
:
1e-3
,
"rtol"
:
1e-
2
},
InfiniDtype
.
F16
:
{
"atol"
:
0
,
"rtol"
:
1e-
5
},
InfiniDtype
.
BF16
:
{
"atol"
:
5e-3
,
"rtol"
:
5
e-
2
},
InfiniDtype
.
BF16
:
{
"atol"
:
0
,
"rtol"
:
1
e-
5
},
InfiniDtype
.
F32
:
{
"atol"
:
1e-5
,
"rtol"
:
1e-5
},
InfiniDtype
.
F32
:
{
"atol"
:
0
,
"rtol"
:
1e-5
},
}
}
# Global flags for controlling test behavior
# Global flags for controlling test behavior
...
@@ -123,9 +123,9 @@ def test(
...
@@ -123,9 +123,9 @@ def test(
current_slot
+=
length
.
item
()
current_slot
+=
length
.
item
()
# Ensure we don't exceed the total number of slots in the cache
# Ensure we don't exceed the total number of slots in the cache
assert
(
assert
current_slot
<=
num_blocks
*
block_size
,
(
current_slot
<=
num_blocks
*
block_size
"Not enough blocks in the cache pool for this test case"
)
,
"Not enough blocks in the cache pool for this test case"
)
slot_mapping_torch
=
torch
.
tensor
(
slot_mapping_list
,
dtype
=
torch
.
int64
)
slot_mapping_torch
=
torch
.
tensor
(
slot_mapping_list
,
dtype
=
torch
.
int64
)
...
@@ -144,10 +144,10 @@ def test(
...
@@ -144,10 +144,10 @@ def test(
# Run reference implementation
# Run reference implementation
k_cache_ref
,
v_cache_ref
=
ref_paged_caching
(
k_cache_ref
,
v_cache_ref
=
ref_paged_caching
(
k
.
torch_tensor
(),
v
.
torch_tensor
(),
k_cache_pool
.
torch_tensor
(),
k_cache_pool
.
torch_tensor
(),
v_cache_pool
.
torch_tensor
(),
v_cache_pool
.
torch_tensor
(),
k
.
torch_tensor
(),
v
.
torch_tensor
(),
slot_mapping
.
torch_tensor
(),
slot_mapping
.
torch_tensor
(),
)
)
...
@@ -160,10 +160,10 @@ def test(
...
@@ -160,10 +160,10 @@ def test(
LIBINFINIOP
.
infiniopCreatePagedCachingDescriptor
(
LIBINFINIOP
.
infiniopCreatePagedCachingDescriptor
(
handle
,
handle
,
ctypes
.
byref
(
descriptor
),
ctypes
.
byref
(
descriptor
),
k
.
descriptor
,
v
.
descriptor
,
k_cache_pool
.
descriptor
,
k_cache_pool
.
descriptor
,
v_cache_pool
.
descriptor
,
v_cache_pool
.
descriptor
,
k
.
descriptor
,
v
.
descriptor
,
slot_mapping
.
descriptor
,
slot_mapping
.
descriptor
,
)
)
)
)
...
@@ -191,10 +191,10 @@ def test(
...
@@ -191,10 +191,10 @@ def test(
descriptor
,
descriptor
,
workspace
.
data
(),
workspace
.
data
(),
workspace_size
.
value
,
workspace_size
.
value
,
k
.
data
(),
v
.
data
(),
k_cache_pool
.
data
(),
k_cache_pool
.
data
(),
v_cache_pool
.
data
(),
v_cache_pool
.
data
(),
k
.
data
(),
v
.
data
(),
slot_mapping
.
data
(),
slot_mapping
.
data
(),
None
,
None
,
)
)
...
...
test/infiniop/paged_caching_prefill.py
View file @
96551cb7
...
@@ -80,7 +80,7 @@ class SimpleCacheManager:
...
@@ -80,7 +80,7 @@ class SimpleCacheManager:
return
torch
.
tensor
(
slots
,
dtype
=
torch
.
int32
)
return
torch
.
tensor
(
slots
,
dtype
=
torch
.
int32
)
def
ref_paged_caching
(
k_
new
,
v_new
,
k_pool
,
v_pool
,
slots
,
block_size
):
def
ref_paged_caching
(
k_
pool
,
v_pool
,
k_new
,
v_new
,
slots
,
block_size
):
"""Reference implementation for incremental caching."""
"""Reference implementation for incremental caching."""
for
i
in
range
(
k_new
.
shape
[
0
]):
for
i
in
range
(
k_new
.
shape
[
0
]):
slot
=
slots
[
i
].
item
()
slot
=
slots
[
i
].
item
()
...
@@ -152,10 +152,10 @@ def test(
...
@@ -152,10 +152,10 @@ def test(
def
torch_caching
():
def
torch_caching
():
nonlocal
k_pool_ref
,
v_pool_ref
nonlocal
k_pool_ref
,
v_pool_ref
return
ref_paged_caching
(
return
ref_paged_caching
(
k_in
.
torch_tensor
(),
v_in
.
torch_tensor
(),
k_pool_ref
,
k_pool_ref
,
v_pool_ref
,
v_pool_ref
,
k_in
.
torch_tensor
(),
v_in
.
torch_tensor
(),
slots_torch
,
slots_torch
,
block_size
,
block_size
,
)
)
...
@@ -168,10 +168,10 @@ def test(
...
@@ -168,10 +168,10 @@ def test(
LIBINFINIOP
.
infiniopCreatePagedCachingDescriptor
(
LIBINFINIOP
.
infiniopCreatePagedCachingDescriptor
(
handle
,
handle
,
ctypes
.
byref
(
descriptor
),
ctypes
.
byref
(
descriptor
),
k_in
.
descriptor
,
v_in
.
descriptor
,
k_cache_pool
.
descriptor
,
k_cache_pool
.
descriptor
,
v_cache_pool
.
descriptor
,
v_cache_pool
.
descriptor
,
k_in
.
descriptor
,
v_in
.
descriptor
,
slot_mapping
.
descriptor
,
slot_mapping
.
descriptor
,
)
)
)
)
...
@@ -190,10 +190,10 @@ def test(
...
@@ -190,10 +190,10 @@ def test(
descriptor
,
descriptor
,
workspace
.
data
(),
workspace
.
data
(),
workspace_size
.
value
,
workspace_size
.
value
,
k_in
.
data
(),
v_in
.
data
(),
k_cache_pool
.
data
(),
k_cache_pool
.
data
(),
v_cache_pool
.
data
(),
v_cache_pool
.
data
(),
k_in
.
data
(),
v_in
.
data
(),
slot_mapping
.
data
(),
slot_mapping
.
data
(),
None
,
None
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment