Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
c70805c9
Commit
c70805c9
authored
Mar 02, 2026
by
xgqdut2016
Committed by
wooway777
Mar 03, 2026
Browse files
issue/1035: kv caching on nvidia
parent
abd45713
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
662 additions
and
0 deletions
+662
-0
src/infiniop/ops/kv_caching/cuda/kernel.cuh
src/infiniop/ops/kv_caching/cuda/kernel.cuh
+63
-0
src/infiniop/ops/kv_caching/info.h
src/infiniop/ops/kv_caching/info.h
+105
-0
src/infiniop/ops/kv_caching/kv_caching.h
src/infiniop/ops/kv_caching/kv_caching.h
+49
-0
src/infiniop/ops/kv_caching/nvidia/kv_caching_nvidia.cu
src/infiniop/ops/kv_caching/nvidia/kv_caching_nvidia.cu
+159
-0
src/infiniop/ops/kv_caching/nvidia/kv_caching_nvidia.cuh
src/infiniop/ops/kv_caching/nvidia/kv_caching_nvidia.cuh
+7
-0
src/infiniop/ops/kv_caching/operator.cc
src/infiniop/ops/kv_caching/operator.cc
+32
-0
test/infiniop/kv_caching.py
test/infiniop/kv_caching.py
+205
-0
test/infiniop/libinfiniop/op_register.py
test/infiniop/libinfiniop/op_register.py
+42
-0
No files found.
src/infiniop/ops/kv_caching/cuda/kernel.cuh
0 → 100644
View file @
c70805c9
#ifndef __KV_CACHING_KERNEL_CUH__
#define __KV_CACHING_KERNEL_CUH__

/// Grid-stride copy of the current step's K/V projections into the KV cache.
///
/// Each element of k/v — assumed laid out as [batch_size, num_kv_heads,
/// seq_len, hidden_dim] per the stride parameters — is written into
/// k_cache/v_cache at sequence slot past_kv_lengths[b] + s.
/// All strides are element strides, one per tensor dimension (dim 0..3).
/// max_seq_len is kept for interface symmetry; bounds are the caller's
/// responsibility (past_kv_lengths[b] + seq_len must not exceed it).
template <typename Tdata>
__device__ void kvCachingKernel(
    Tdata *__restrict__ k_cache,
    Tdata *__restrict__ v_cache,
    const Tdata *__restrict__ k,
    const Tdata *__restrict__ v,
    const int64_t *__restrict__ past_kv_lengths,
    int batch_size,
    int num_kv_heads,
    int max_seq_len,
    int seq_len,
    int hidden_dim,
    ptrdiff_t k_cache_strides_0,
    ptrdiff_t k_cache_strides_1,
    ptrdiff_t k_cache_strides_2,
    ptrdiff_t k_cache_strides_3,
    ptrdiff_t v_cache_strides_0,
    ptrdiff_t v_cache_strides_1,
    ptrdiff_t v_cache_strides_2,
    ptrdiff_t v_cache_strides_3,
    ptrdiff_t k_strides_0,
    ptrdiff_t k_strides_1,
    ptrdiff_t k_strides_2,
    ptrdiff_t k_strides_3,
    ptrdiff_t v_strides_0,
    ptrdiff_t v_strides_1,
    ptrdiff_t v_strides_2,
    ptrdiff_t v_strides_3) {
    // Total elements to copy = B * H * seq_len * D.
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    int total = batch_size * num_kv_heads * seq_len * hidden_dim;
    const int grid_size = blockDim.x * gridDim.x;
    for (int idx = tid; idx < total; idx += grid_size) {
        // Decompose the flat index into (b, h, s, d) on a scratch copy.
        // BUGFIX: the original divided `idx` itself, so `idx += grid_size`
        // strode from the destroyed value — wrong whenever a thread has to
        // process more than one element.
        int rem = idx;
        int d = rem % hidden_dim;
        rem /= hidden_dim;
        int s = rem % seq_len;
        rem /= seq_len;
        int h = rem % num_kv_heads;
        int b = rem / num_kv_heads;

        int past_len = static_cast<int>(past_kv_lengths[b]);
        // Destination slot: append after the already-cached prefix.
        int cache_s = past_len + s;

        // Keep offsets in ptrdiff_t: the original cast every stride to int,
        // which overflows once a cache exceeds 2^31 elements.
        ptrdiff_t k_cache_offset = d * k_cache_strides_3 + cache_s * k_cache_strides_2
                                 + h * k_cache_strides_1 + b * k_cache_strides_0;
        ptrdiff_t v_cache_offset = d * v_cache_strides_3 + cache_s * v_cache_strides_2
                                 + h * v_cache_strides_1 + b * v_cache_strides_0;
        ptrdiff_t k_src_offset = d * k_strides_3 + s * k_strides_2
                               + h * k_strides_1 + b * k_strides_0;
        ptrdiff_t v_src_offset = d * v_strides_3 + s * v_strides_2
                               + h * v_strides_1 + b * v_strides_0;

        k_cache[k_cache_offset] = k[k_src_offset];
        v_cache[v_cache_offset] = v[v_src_offset];
    }
}
#endif // __KV_CACHING_KERNEL_CUH__
src/infiniop/ops/kv_caching/info.h
0 → 100644
View file @
c70805c9
#ifndef __KV_CACHING_INFO_H__
#define __KV_CACHING_INFO_H__

#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"

namespace op::kv_caching {

/// Validated metadata for the KV-caching operator: dtype, the five logical
/// dimensions, and element strides for all four tensors. Built once at
/// descriptor-creation time so the kernel launch needs no further checks.
class KVCachingInfo {
private:
    KVCachingInfo() = default;

public:
    infiniDtype_t dtype;
    // Cache shape is [batch_size, num_kv_heads, max_seq_len, hidden_dim];
    // the incoming k/v share all dims except dim 2, which is seq_len.
    size_t batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim;
    ptrdiff_t k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3;
    ptrdiff_t v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3;
    ptrdiff_t k_strides_0, k_strides_1, k_strides_2, k_strides_3;
    ptrdiff_t v_strides_0, v_strides_1, v_strides_2, v_strides_3;

    /// Validates the descriptors and captures shapes/strides.
    /// Returns INFINI_STATUS_NULL_POINTER / BAD_TENSOR_DTYPE /
    /// BAD_TENSOR_SHAPE on the corresponding violation.
    static utils::Result<KVCachingInfo> createKVCachingInfo(
        infiniopTensorDescriptor_t k_cache,
        infiniopTensorDescriptor_t v_cache,
        infiniopTensorDescriptor_t k,
        infiniopTensorDescriptor_t v,
        infiniopTensorDescriptor_t past_kv_lengths) {
        CHECK_OR_RETURN(k_cache != nullptr && v_cache != nullptr && k != nullptr
                            && v != nullptr && past_kv_lengths != nullptr,
                        INFINI_STATUS_NULL_POINTER);
        const infiniDtype_t dtype = k_cache->dtype();
        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
        CHECK_OR_RETURN(k_cache->ndim() == 4 && v_cache->ndim() == 4
                            && k->ndim() == 4 && v->ndim() == 4,
                        INFINI_STATUS_BAD_TENSOR_SHAPE);
        auto shape = k_cache->shape();
        CHECK_SAME_SHAPE(shape, v_cache->shape());
        CHECK_SAME_SHAPE(k->shape(), v->shape());
        size_t batch_size = shape[0];
        size_t num_kv_heads = shape[1];
        size_t max_seq_len = shape[2];
        size_t hidden_dim = shape[3];
        size_t seq_len = k->shape()[2];
        // k must match the cache on batch, heads AND hidden dim (only its
        // seq_len may differ). BUGFIX: the original joined these with ||,
        // accepting k as long as any single dimension happened to match.
        CHECK_OR_RETURN(batch_size == k->dim(0)
                            && num_kv_heads == k->dim(1)
                            && hidden_dim == k->dim(3),
                        INFINI_STATUS_BAD_TENSOR_SHAPE);
        // A step longer than the cache can never fit, regardless of past_len.
        CHECK_OR_RETURN(seq_len <= max_seq_len, INFINI_STATUS_BAD_TENSOR_SHAPE);
        ptrdiff_t k_cache_strides_0 = k_cache->strides()[0];
        ptrdiff_t k_cache_strides_1 = k_cache->strides()[1];
        ptrdiff_t k_cache_strides_2 = k_cache->strides()[2];
        ptrdiff_t k_cache_strides_3 = k_cache->strides()[3];
        ptrdiff_t v_cache_strides_0 = v_cache->strides()[0];
        ptrdiff_t v_cache_strides_1 = v_cache->strides()[1];
        ptrdiff_t v_cache_strides_2 = v_cache->strides()[2];
        ptrdiff_t v_cache_strides_3 = v_cache->strides()[3];
        ptrdiff_t k_strides_0 = k->strides()[0];
        ptrdiff_t k_strides_1 = k->strides()[1];
        ptrdiff_t k_strides_2 = k->strides()[2];
        ptrdiff_t k_strides_3 = k->strides()[3];
        ptrdiff_t v_strides_0 = v->strides()[0];
        ptrdiff_t v_strides_1 = v->strides()[1];
        ptrdiff_t v_strides_2 = v->strides()[2];
        ptrdiff_t v_strides_3 = v->strides()[3];
        // Aggregate order must match the member declaration order above.
        return utils::Result<KVCachingInfo>(KVCachingInfo{
            dtype,
            batch_size,
            num_kv_heads,
            max_seq_len,
            seq_len,
            hidden_dim,
            k_cache_strides_0,
            k_cache_strides_1,
            k_cache_strides_2,
            k_cache_strides_3,
            v_cache_strides_0,
            v_cache_strides_1,
            v_cache_strides_2,
            v_cache_strides_3,
            k_strides_0,
            k_strides_1,
            k_strides_2,
            k_strides_3,
            v_strides_0,
            v_strides_1,
            v_strides_2,
            v_strides_3});
    }
};
} // namespace op::kv_caching
#endif // __KV_CACHING_INFO_H__
src/infiniop/ops/kv_caching/kv_caching.h
0 → 100644
View file @
c70805c9
#ifndef KV_CACHING_H
#define KV_CACHING_H

#include "../../operator.h"
#include "info.h"

// Generates the backend-specific Descriptor class for the KV-caching operator
// inside namespace op::kv_caching::NAMESPACE. Each backend (nvidia, ...) expands
// this once and supplies the out-of-line definitions of ~Descriptor(),
// create() and calculate(). The Descriptor owns an opaque backend state
// (struct Opaque), the validated KVCachingInfo, and a workspace size that
// callers query via get_workspace_size() before invoking calculate().
// NOTE: comments must stay outside the macro — a // comment before a trailing
// backslash would swallow the continuation.
#define DESCRIPTOR(NAMESPACE) \
    \
    namespace op::kv_caching::NAMESPACE { \
    class Descriptor final : public InfiniopDescriptor { \
        struct Opaque; \
        Opaque *_opaque; \
        KVCachingInfo _info; \
        size_t _workspace_size; \
    \
        Descriptor( \
            Opaque *opaque, \
            KVCachingInfo info, \
            size_t workspace_size, \
            infiniDevice_t device_type, \
            int device_id) \
            : InfiniopDescriptor{device_type, device_id}, \
              _opaque(opaque), \
              _info(info), \
              _workspace_size(workspace_size) {} \
    \
    public: \
        ~Descriptor(); \
    \
        size_t get_workspace_size() const { return _workspace_size; } \
    \
        static infiniStatus_t create( \
            infiniopHandle_t handle, \
            Descriptor **desc_ptr, \
            infiniopTensorDescriptor_t k_cache, \
            infiniopTensorDescriptor_t v_cache, \
            infiniopTensorDescriptor_t k, \
            infiniopTensorDescriptor_t v, \
            infiniopTensorDescriptor_t past_kv_lengths); \
    \
        infiniStatus_t calculate( \
            void *workspace, size_t workspace_size, \
            void *k_cache, void *v_cache, \
            const void *k, const void *v, const void *past_kv_lengths, \
            void *stream) const; \
    }; \
    }

#endif // KV_CACHING_H
src/infiniop/ops/kv_caching/nvidia/kv_caching_nvidia.cu
0 → 100644
View file @
c70805c9
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "kv_caching_nvidia.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
/// CUDA __global__ entry point for KV caching.
/// A thin trampoline: every launch parameter is forwarded unchanged to the
/// shared kvCachingKernel device function from ../cuda/kernel.cuh, which does
/// the per-element copy into the caches.
template <typename Tdata>
INFINIOP_CUDA_KERNEL kvCaching(
    Tdata *k_cache, Tdata *v_cache,
    const Tdata *k, const Tdata *v,
    const int64_t *past_kv_lengths,
    int batch_size, int num_kv_heads, int max_seq_len, int seq_len, int hidden_dim,
    ptrdiff_t k_cache_strides_0, ptrdiff_t k_cache_strides_1,
    ptrdiff_t k_cache_strides_2, ptrdiff_t k_cache_strides_3,
    ptrdiff_t v_cache_strides_0, ptrdiff_t v_cache_strides_1,
    ptrdiff_t v_cache_strides_2, ptrdiff_t v_cache_strides_3,
    ptrdiff_t k_strides_0, ptrdiff_t k_strides_1,
    ptrdiff_t k_strides_2, ptrdiff_t k_strides_3,
    ptrdiff_t v_strides_0, ptrdiff_t v_strides_1,
    ptrdiff_t v_strides_2, ptrdiff_t v_strides_3) {
    // Template argument is deduced from the pointer parameters.
    kvCachingKernel(
        k_cache, v_cache, k, v, past_kv_lengths,
        batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
        k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
        v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3,
        k_strides_0, k_strides_1, k_strides_2, k_strides_3,
        v_strides_0, v_strides_1, v_strides_2, v_strides_3);
}
namespace op::kv_caching::nvidia {

// Backend-private state carried by each Descriptor.
struct Descriptor::Opaque {
    // Shared internals of the infiniop NVIDIA handle; calculate() reads
    // maxThreadsPerBlock() from it to pick a launch block size.
    std::shared_ptr<device::nvidia::Handle::Internal> internal;
};

// _opaque is heap-allocated in create(); the Descriptor owns and frees it.
Descriptor::~Descriptor() {
    delete _opaque;
}
/// Builds an NVIDIA KV-caching descriptor.
/// Validates the tensor descriptors via KVCachingInfo, then wraps the
/// captured info plus the handle's internals into a new Descriptor.
/// Returns whatever status the validation produced on failure.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t k_cache,
    infiniopTensorDescriptor_t v_cache,
    infiniopTensorDescriptor_t k,
    infiniopTensorDescriptor_t v,
    infiniopTensorDescriptor_t past_kv_lengths) {
    // Validate shapes/dtypes and capture dims + strides up front.
    auto info = KVCachingInfo::createKVCachingInfo(k_cache, v_cache, k, v, past_kv_lengths);
    CHECK_RESULT(info);

    auto *nv_handle = reinterpret_cast<device::nvidia::Handle *>(handle);
    auto *opaque = new Opaque{nv_handle->internal()};

    // The operator needs no scratch memory, hence workspace size 0.
    *desc_ptr = new Descriptor(opaque, info.take(), 0, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
/// Launches the KV-caching kernel on `stream`: one thread per element of k/v,
/// grid-stride over batch_size * num_kv_heads * seq_len * hidden_dim.
/// `workspace` is accepted for interface symmetry but unused.
template <unsigned int BLOCK_SIZE, typename Tdata>
infiniStatus_t launchKernel(
    const KVCachingInfo &info,
    Tdata *k_cache,
    Tdata *v_cache,
    const Tdata *k,
    const Tdata *v,
    const int64_t *past_kv_lengths,
    cudaStream_t stream,
    void *workspace) {
    int batch_size = static_cast<int>(info.batch_size);
    int num_kv_heads = static_cast<int>(info.num_kv_heads);
    int max_seq_len = static_cast<int>(info.max_seq_len);
    int hidden_dim = static_cast<int>(info.hidden_dim);
    int seq_len = static_cast<int>(info.seq_len);
    int total = batch_size * num_kv_heads * seq_len * hidden_dim;
    // BUGFIX: a <<<0, ...>>> launch is an invalid configuration; with nothing
    // to copy we can simply report success.
    if (total == 0) {
        return INFINI_STATUS_SUCCESS;
    }
    // Explicit casts keep the ceil-division in signed arithmetic (BLOCK_SIZE
    // is unsigned and would otherwise promote the whole expression).
    int num_blocks = (total + static_cast<int>(BLOCK_SIZE) - 1) / static_cast<int>(BLOCK_SIZE);
    // Strides are passed straight from info — no need to copy them into
    // locals first.
    kvCaching<Tdata><<<num_blocks, BLOCK_SIZE, 0, stream>>>(
        k_cache, v_cache, k, v, past_kv_lengths,
        batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
        info.k_cache_strides_0, info.k_cache_strides_1,
        info.k_cache_strides_2, info.k_cache_strides_3,
        info.v_cache_strides_0, info.v_cache_strides_1,
        info.v_cache_strides_2, info.v_cache_strides_3,
        info.k_strides_0, info.k_strides_1, info.k_strides_2, info.k_strides_3,
        info.v_strides_0, info.v_strides_1, info.v_strides_2, info.v_strides_3);
    return INFINI_STATUS_SUCCESS;
}
/// Dispatches the KV-caching launch on dtype (F16/F32/BF16) and on the
/// device's reported max threads per block.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *k_cache,
    void *v_cache,
    const void *k,
    const void *v,
    const void *past_kv_lengths,
    void *stream_) const {
    cudaStream_t stream = (cudaStream_t)stream_;

// Instantiates launchKernel for a concrete block size + element type.
#define CALCULATE_KV_CACHING(BLOCK_SIZE, TDATA) \
    launchKernel<BLOCK_SIZE, TDATA>(_info, (TDATA *)k_cache, (TDATA *)v_cache, (const TDATA *)k, (const TDATA *)v, (const int64_t *)past_kv_lengths, stream, workspace)

// dtype dispatch for a fixed block size; every branch returns.
#define CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(BLOCK_SIZE)            \
    {                                                               \
        if (_info.dtype == INFINI_DTYPE_F16)                        \
            return CALCULATE_KV_CACHING(BLOCK_SIZE, half);          \
        else if (_info.dtype == INFINI_DTYPE_F32)                   \
            return CALCULATE_KV_CACHING(BLOCK_SIZE, float);         \
        else if (_info.dtype == INFINI_DTYPE_BF16)                  \
            return CALCULATE_KV_CACHING(BLOCK_SIZE, __nv_bfloat16); \
        else                                                        \
            return INFINI_STATUS_BAD_TENSOR_DTYPE;                  \
    }

    // NOTE(review): this matches maxThreadsPerBlock() EXACTLY against the
    // CUDA_BLOCK_SIZE_* constants and uses the matched value as the launch
    // block size. Current NVIDIA GPUs cap threads-per-block at 1024, so the
    // 2048/4096 branches look unreachable (or invalid as block sizes if ever
    // taken) — confirm the intended semantics of these constants.
    if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_1024)
    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_512)
    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_2048) {
        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_2048)
    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_4096)
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }
    // Unreachable: every branch above returns. Kept for -Wreturn-type.
    return INFINI_STATUS_SUCCESS;
}
} // namespace op::kv_caching::nvidia
src/infiniop/ops/kv_caching/nvidia/kv_caching_nvidia.cuh
0 → 100644
View file @
c70805c9
#ifndef __KV_CACHING_NVIDIA_API_H__
#define __KV_CACHING_NVIDIA_API_H__

#include "../kv_caching.h"

// Expands the shared DESCRIPTOR macro to declare
// op::kv_caching::nvidia::Descriptor; definitions live in
// kv_caching_nvidia.cu.
DESCRIPTOR(nvidia)

#endif // __KV_CACHING_NVIDIA_API_H__
src/infiniop/ops/kv_caching/operator.cc
View file @
c70805c9
...
...
@@ -8,6 +8,10 @@
#endif
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#include "nvidia/kv_caching_nvidia.cuh"
#endif
__C
infiniStatus_t
infiniopCreateKVCachingDescriptor
(
infiniopHandle_t
handle
,
infiniopKVCachingDescriptor_t
*
desc_ptr
,
...
...
@@ -42,6 +46,13 @@ __C infiniStatus_t infiniopCreateKVCachingDescriptor(
#endif
#endif
#ifdef ENABLE_NVIDIA_API
CREATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
);
#endif
#ifdef ENABLE_QY_API
CREATE
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
...
...
@@ -71,6 +82,13 @@ __C infiniStatus_t infiniopGetKVCachingWorkspaceSize(
#if defined(ENABLE_METAX_API)
GET_SIZE
(
INFINI_DEVICE_METAX
,
ninetoothed
);
#endif
#endif
#ifdef ENABLE_NVIDIA_API
GET_SIZE
(
INFINI_DEVICE_NVIDIA
,
nvidia
);
#endif
#ifdef ENABLE_QY_API
GET_SIZE
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -107,6 +125,13 @@ __C infiniStatus_t infiniopKVCaching(
#if defined(ENABLE_METAX_API)
CALCULATE
(
INFINI_DEVICE_METAX
,
ninetoothed
);
#endif
#endif
#ifdef ENABLE_NVIDIA_API
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
);
#endif
#ifdef ENABLE_QY_API
CALCULATE
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -135,6 +160,13 @@ __C infiniStatus_t infiniopDestroyKVCachingDescriptor(
#if defined(ENABLE_METAX_API)
DELETE
(
INFINI_DEVICE_METAX
,
ninetoothed
);
#endif
#endif
#ifdef ENABLE_NVIDIA_API
DELETE
(
INFINI_DEVICE_NVIDIA
,
nvidia
);
#endif
#ifdef ENABLE_QY_API
DELETE
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
test/infiniop/kv_caching.py
0 → 100644
View file @
c70805c9
import
torch
import
ctypes
from
ctypes
import
c_uint64
from
libinfiniop
import
(
LIBINFINIOP
,
TestTensor
,
get_test_devices
,
check_error
,
test_operator
,
get_args
,
debug
,
get_tolerance
,
profile_operation
,
InfiniDtype
,
InfiniDtypeNames
,
InfiniDeviceNames
,
infiniopOperatorDescriptor_t
,
TestWorkspace
,
)
# ==============================================================================
# Reference Implementation
# ==============================================================================
def torch_kv_caching(k_cache, v_cache, k, v, past_kv_lengths):
    """Reference KV-cache update, computed in place with PyTorch.

    Writes each batch's new K/V projections into the cache starting at that
    batch's past length.

    Shapes:
        k_cache, v_cache: [batch_size, num_kv_heads, max_seq_len, hidden_dim]
        k, v:             [batch_size, num_kv_heads, seq_len, hidden_dim]
        past_kv_lengths:  [batch_size]

    Returns the (mutated) k_cache and v_cache.
    """
    seq_len = k.shape[2]
    for batch_idx in range(k_cache.shape[0]):
        start = int(past_kv_lengths[batch_idx].item())
        end = start + seq_len
        # One slice assignment covers every head at once.
        k_cache[batch_idx, :, start:end, :] = k[batch_idx]
        v_cache[batch_idx, :, start:end, :] = v[batch_idx]
    return k_cache, v_cache
# ==============================================================================
#  Test Configuration (Internal Use Only)
# ==============================================================================
_TEST_CASES_ = [
    # (num_seqs, num_kv_heads, max_seq_len, hidden_dim), strides
    # strides=None means contiguous; explicit strides exercise non-contiguous
    # cache layouts (each tuple is element strides for the 4 cache dims).
    ((1, 1, 8, 1), None),
    ((1, 8, 32, 32), None),
    ((8, 8, 64, 32), None),
    ((1, 32, 8, 64), (32768, 1024, 64, 1)),
    ((4, 8, 32, 16), (65536, 8192, 256, 16)),
    ((8, 16, 64, 128), (8388608, 524288, 8192, 1)),
    ((1, 2, 2304, 128), (589824, 294912, 128, 1)),
]
# Data types for testing
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32]
# Tolerance map for different data types.
# KV caching is a pure copy, so results must match bit-for-bit: zero tolerance.
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 0, "rtol": 0},
    InfiniDtype.BF16: {"atol": 0, "rtol": 0},
    InfiniDtype.F32: {"atol": 0, "rtol": 0},
}
# Global flags for controlling test behavior (overridden from CLI in __main__)
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 100
def test(
    handle,
    device,
    cache_shape,
    strides,
    dtype=InfiniDtype.F16,
    sync=None,
):
    """Run one KV-caching case: build tensors, compute the torch reference,
    run the operator through libinfiniop, and compare the caches exactly.

    cache_shape is (num_seqs, num_kv_heads, max_seq_len, hidden_dim); strides
    (or None for contiguous) applies to both caches. A random seq_len in
    [1, max_seq_len) is drawn per call, and past lengths are drawn so the new
    entries fit within max_seq_len.
    """
    print(
        f"Testing KVCaching on {InfiniDeviceNames[device]} with cache_shape: {cache_shape}, strides: {strides}, dtype={InfiniDtypeNames[dtype]}"
    )
    import random

    # Incoming k/v share every cache dim except dim 2 (random seq_len).
    kv_shape = (
        cache_shape[0],
        cache_shape[1],
        random.randrange(1, cache_shape[2]),
        cache_shape[3],
    )
    past_shape = (cache_shape[0],)
    k_cache = TestTensor(cache_shape, strides, dtype, device)
    v_cache = TestTensor(cache_shape, strides, dtype, device)
    k = TestTensor(kv_shape, None, dtype, device)
    v = TestTensor(kv_shape, None, dtype, device)
    # Past lengths bounded so past_len + seq_len stays within max_seq_len.
    # NOTE(review): assumes TestTensor's randint_high bound keeps values in
    # range for this operator — confirm whether it is inclusive or exclusive.
    past_kv_lengths = TestTensor(
        past_shape,
        None,
        InfiniDtype.I64,
        device,
        randint_low=0,
        randint_high=cache_shape[2] - kv_shape[2],
    )
    # Run reference implementation
    k_cache_ref, v_cache_ref = torch_kv_caching(
        k_cache.torch_tensor(),
        v_cache.torch_tensor(),
        k.torch_tensor(),
        v.torch_tensor(),
        past_kv_lengths.torch_tensor(),
    )
    if sync:
        sync()
    # Create operator descriptor
    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateKVCachingDescriptor(
            handle,
            ctypes.byref(descriptor),
            k_cache.descriptor,
            v_cache.descriptor,
            k.descriptor,
            v.descriptor,
            past_kv_lengths.descriptor,
        )
    )
    # Get workspace size (likely 0 for this operator, but good practice to include)
    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetKVCachingWorkspaceSize(
            descriptor,
            ctypes.byref(workspace_size),
        )
    )
    workspace = TestWorkspace(workspace_size.value, device)
    # Invalidate descriptors to ensure kernel does not rely on them
    k.destroy_desc()
    v.destroy_desc()
    k_cache.destroy_desc()
    v_cache.destroy_desc()
    past_kv_lengths.destroy_desc()

    # Define the library call as a lambda for profiling
    def lib_kv_caching():
        check_error(
            LIBINFINIOP.infiniopKVCaching(
                descriptor,
                workspace.data(),
                workspace_size.value,
                k_cache.data(),
                v_cache.data(),
                k.data(),
                v.data(),
                past_kv_lengths.data(),
                None,
            )
        )

    # Execute the custom operator
    lib_kv_caching()
    if sync:
        sync()
    # Verify correctness (tolerances are 0 — the copy must be exact)
    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        print("Verifying K cache...")
        debug(k_cache.actual_tensor(), k_cache_ref, atol=atol, rtol=rtol)
        print("Verifying V cache...")
        debug(v_cache.actual_tensor(), v_cache_ref, atol=atol, rtol=rtol)
    assert torch.allclose(k_cache.actual_tensor(), k_cache_ref, atol=atol, rtol=rtol)
    assert torch.allclose(v_cache.actual_tensor(), v_cache_ref, atol=atol, rtol=rtol)
    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: torch_kv_caching(k_cache.torch_tensor(), v_cache.torch_tensor(), k.torch_tensor(), v.torch_tensor(), past_kv_lengths.torch_tensor()), device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lib_kv_caching, device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    # Clean up resources
    check_error(LIBINFINIOP.infiniopDestroyKVCachingDescriptor(descriptor))
if __name__ == "__main__":
    args = get_args()
    # Configure testing options from command line arguments
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations
    # Run every configured case on every requested device.
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES_, _TENSOR_DTYPES)
    print("\033[92mTest passed!\033[0m")
test/infiniop/libinfiniop/op_register.py
View file @
c70805c9
...
...
@@ -1054,6 +1054,48 @@ def scaled_mm_int8_(lib):
]
@OpRegister.operator
def kv_caching_(lib):
    # Registers the ctypes signatures of the KV-caching C API on the loaded
    # libinfiniop handle. All four entry points return an int32 status code.
    lib.infiniopCreateKVCachingDescriptor.restype = c_int32
    lib.infiniopCreateKVCachingDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,  # k_cache
        infiniopTensorDescriptor_t,  # v_cache
        infiniopTensorDescriptor_t,  # k
        infiniopTensorDescriptor_t,  # v
        infiniopTensorDescriptor_t,  # past_kv_lengths
    ]
    lib.infiniopGetKVCachingWorkspaceSize.restype = c_int32
    lib.infiniopGetKVCachingWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]
    lib.infiniopKVCaching.restype = c_int32
    lib.infiniopKVCaching.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,  # workspace
        c_size_t,  # workspace_size
        c_void_p,  # k_cache
        c_void_p,  # v_cache
        c_void_p,  # k
        c_void_p,  # v
        c_void_p,  # past_kv_lengths
        c_void_p,  # stream
    ]
    lib.infiniopDestroyKVCachingDescriptor.restype = c_int32
    lib.infiniopDestroyKVCachingDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]
@
OpRegister
.
operator
def
paged_attention_
(
lib
):
lib
.
infiniopCreatePagedAttentionDescriptor
.
restype
=
c_int32
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment