Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
83e28ec3
Unverified
Commit
83e28ec3
authored
Mar 04, 2026
by
thatPepe
Committed by
GitHub
Mar 04, 2026
Browse files
Merge pull request #1039 from InfiniTensor/issue/1035
issue/1035: 添加NVIDIA平台上的kv_caching算子
parents
811ffab3
af394a32
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
862 additions
and
31 deletions
+862
-31
src/infiniop/ops/kv_caching/cuda/kernel.cuh
src/infiniop/ops/kv_caching/cuda/kernel.cuh
+63
-0
src/infiniop/ops/kv_caching/info.h
src/infiniop/ops/kv_caching/info.h
+106
-0
src/infiniop/ops/kv_caching/kv_caching.h
src/infiniop/ops/kv_caching/kv_caching.h
+49
-0
src/infiniop/ops/kv_caching/metax/kv_caching_metax.h
src/infiniop/ops/kv_caching/metax/kv_caching_metax.h
+7
-0
src/infiniop/ops/kv_caching/metax/kv_caching_metax.maca
src/infiniop/ops/kv_caching/metax/kv_caching_metax.maca
+160
-0
src/infiniop/ops/kv_caching/nvidia/kv_caching_nvidia.cu
src/infiniop/ops/kv_caching/nvidia/kv_caching_nvidia.cu
+159
-0
src/infiniop/ops/kv_caching/nvidia/kv_caching_nvidia.cuh
src/infiniop/ops/kv_caching/nvidia/kv_caching_nvidia.cuh
+7
-0
src/infiniop/ops/kv_caching/operator.cc
src/infiniop/ops/kv_caching/operator.cc
+64
-31
test/infiniop/kv_caching.py
test/infiniop/kv_caching.py
+205
-0
test/infiniop/libinfiniop/op_register.py
test/infiniop/libinfiniop/op_register.py
+42
-0
No files found.
src/infiniop/ops/kv_caching/cuda/kernel.cuh
0 → 100644
View file @
83e28ec3
#ifndef __KV_CACHING_KERNEL_CUH__
#define __KV_CACHING_KERNEL_CUH__

/**
 * @brief Device-side body shared by the NVIDIA/METAX kvCaching kernels:
 *        appends this step's K/V projections into the KV caches.
 *
 * Each thread handles logical elements of the flattened
 * (batch, head, step, dim) index space in a grid-stride loop and copies
 * k[b,h,s,d] -> k_cache[b,h,past_len+s,d] (same for v), where past_len is
 * the number of entries already cached for batch b.
 *
 * @param k_cache/v_cache  destination caches, addressed via the *_cache strides
 * @param k/v              new projections, addressed via the k_/v_ strides
 * @param past_kv_lengths  per-batch count of already-cached entries (int64, device memory)
 * @param max_seq_len      cache capacity along the sequence dim (unused here;
 *                         kept for interface parity with callers)
 *
 * Strides are in elements, one per dimension of the 4-D tensors.
 */
template <typename Tdata>
__device__ void kvCachingKernel(
    Tdata *__restrict__ k_cache,
    Tdata *__restrict__ v_cache,
    const Tdata *__restrict__ k,
    const Tdata *__restrict__ v,
    const int64_t *__restrict__ past_kv_lengths,
    int batch_size,
    int num_kv_heads,
    int max_seq_len,
    int seq_len,
    int hidden_dim,
    ptrdiff_t k_cache_strides_0,
    ptrdiff_t k_cache_strides_1,
    ptrdiff_t k_cache_strides_2,
    ptrdiff_t k_cache_strides_3,
    ptrdiff_t v_cache_strides_0,
    ptrdiff_t v_cache_strides_1,
    ptrdiff_t v_cache_strides_2,
    ptrdiff_t v_cache_strides_3,
    ptrdiff_t k_strides_0,
    ptrdiff_t k_strides_1,
    ptrdiff_t k_strides_2,
    ptrdiff_t k_strides_3,
    ptrdiff_t v_strides_0,
    ptrdiff_t v_strides_1,
    ptrdiff_t v_strides_2,
    ptrdiff_t v_strides_3) {
    // num of ele = B * H * seq_len * D
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    int total = batch_size * num_kv_heads * seq_len * hidden_dim;
    const int grid_size = blockDim.x * gridDim.x;
    for (int idx = tid; idx < total; idx += grid_size) {
        // Unravel into (b, h, s, d) using a scratch copy of idx.
        // BUG FIX: the original divided `idx` itself (idx /= hidden_dim; ...),
        // corrupting the loop variable so that any grid-stride iteration after
        // the first processed the wrong element.
        int rem = idx;
        int d = rem % hidden_dim;
        rem /= hidden_dim;
        int s = rem % seq_len;
        rem /= seq_len;
        int h = rem % num_kv_heads;
        int b = rem / num_kv_heads;
        int past_len = static_cast<int32_t>(past_kv_lengths[b]);
        // write position: append after the entries already cached for batch b
        int cache_s = past_len + s;
        // Offsets accumulate in ptrdiff_t; the original truncated every stride
        // product through (int), which overflows for large caches.
        ptrdiff_t k_cache_offset = d * k_cache_strides_3
                                 + cache_s * k_cache_strides_2
                                 + h * k_cache_strides_1
                                 + b * k_cache_strides_0;
        ptrdiff_t v_cache_offset = d * v_cache_strides_3
                                 + cache_s * v_cache_strides_2
                                 + h * v_cache_strides_1
                                 + b * v_cache_strides_0;
        ptrdiff_t k_src_offset = d * k_strides_3
                               + s * k_strides_2
                               + h * k_strides_1
                               + b * k_strides_0;
        ptrdiff_t v_src_offset = d * v_strides_3
                               + s * v_strides_2
                               + h * v_strides_1
                               + b * v_strides_0;
        k_cache[k_cache_offset] = k[k_src_offset];
        v_cache[v_cache_offset] = v[v_src_offset];
    }
}

#endif // __KV_CACHING_KERNEL_CUH__
src/infiniop/ops/kv_caching/info.h
0 → 100644
View file @
83e28ec3
#ifndef __KV_CACHING_INFO_H__
#define __KV_CACHING_INFO_H__

#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"

namespace op::kv_caching {

// Validated, plain-data description of one kv_caching invocation.
// Built once at descriptor-creation time so backend kernels receive
// pre-checked sizes and per-dimension strides (in elements) instead of
// tensor descriptors.
class KVCachingInfo {
private:
    KVCachingInfo() = default;

public:
    infiniDtype_t dtype;
    // caches are [batch_size, num_kv_heads, max_seq_len, hidden_dim];
    // incoming k/v are [batch_size, num_kv_heads, seq_len, hidden_dim]
    size_t batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim;
    ptrdiff_t k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3;
    ptrdiff_t v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3;
    ptrdiff_t k_strides_0, k_strides_1, k_strides_2, k_strides_3;
    ptrdiff_t v_strides_0, v_strides_1, v_strides_2, v_strides_3;

    // Validates dtypes/ranks/shapes of all five descriptors and snapshots
    // sizes and strides. Returns INFINI_STATUS_NULL_POINTER,
    // INFINI_STATUS_BAD_TENSOR_DTYPE, or INFINI_STATUS_BAD_TENSOR_SHAPE
    // through the Result on failure.
    static utils::Result<KVCachingInfo> createKVCachingInfo(
        infiniopTensorDescriptor_t k_cache,
        infiniopTensorDescriptor_t v_cache,
        infiniopTensorDescriptor_t k,
        infiniopTensorDescriptor_t v,
        infiniopTensorDescriptor_t past_kv_lengths) {
        CHECK_OR_RETURN(k_cache != nullptr && v_cache != nullptr && k != nullptr && v != nullptr && past_kv_lengths != nullptr,
                        INFINI_STATUS_NULL_POINTER);

        const infiniDtype_t dtype = k_cache->dtype();
        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
        CHECK_DTYPE(past_kv_lengths->dtype(), INFINI_DTYPE_I64);

        CHECK_OR_RETURN(k_cache->ndim() == 4 && v_cache->ndim() == 4 && k->ndim() == 4 && v->ndim() == 4,
                        INFINI_STATUS_BAD_TENSOR_SHAPE);

        auto shape = k_cache->shape();
        CHECK_SAME_SHAPE(shape, v_cache->shape());
        CHECK_SAME_SHAPE(k->shape(), v->shape());

        size_t batch_size = shape[0];
        size_t num_kv_heads = shape[1];
        size_t max_seq_len = shape[2];
        size_t hidden_dim = shape[3];
        size_t seq_len = k->shape()[2];

        // BUG FIX: the original chained these checks with ||, so a single
        // matching dimension was enough to accept mismatched k/k_cache
        // shapes. All non-sequence dims must agree.
        CHECK_OR_RETURN(batch_size == k->dim(0) && num_kv_heads == k->dim(1) && hidden_dim == k->dim(3),
                        INFINI_STATUS_BAD_TENSOR_SHAPE);
        // The new tokens must fit within the cache's sequence capacity.
        CHECK_OR_RETURN(seq_len <= max_seq_len, INFINI_STATUS_BAD_TENSOR_SHAPE);

        ptrdiff_t k_cache_strides_0 = k_cache->strides()[0];
        ptrdiff_t k_cache_strides_1 = k_cache->strides()[1];
        ptrdiff_t k_cache_strides_2 = k_cache->strides()[2];
        ptrdiff_t k_cache_strides_3 = k_cache->strides()[3];
        ptrdiff_t v_cache_strides_0 = v_cache->strides()[0];
        ptrdiff_t v_cache_strides_1 = v_cache->strides()[1];
        ptrdiff_t v_cache_strides_2 = v_cache->strides()[2];
        ptrdiff_t v_cache_strides_3 = v_cache->strides()[3];
        ptrdiff_t k_strides_0 = k->strides()[0];
        ptrdiff_t k_strides_1 = k->strides()[1];
        ptrdiff_t k_strides_2 = k->strides()[2];
        ptrdiff_t k_strides_3 = k->strides()[3];
        ptrdiff_t v_strides_0 = v->strides()[0];
        ptrdiff_t v_strides_1 = v->strides()[1];
        ptrdiff_t v_strides_2 = v->strides()[2];
        ptrdiff_t v_strides_3 = v->strides()[3];

        return utils::Result<KVCachingInfo>(KVCachingInfo{
            dtype,
            batch_size,
            num_kv_heads,
            max_seq_len,
            seq_len,
            hidden_dim,
            k_cache_strides_0,
            k_cache_strides_1,
            k_cache_strides_2,
            k_cache_strides_3,
            v_cache_strides_0,
            v_cache_strides_1,
            v_cache_strides_2,
            v_cache_strides_3,
            k_strides_0,
            k_strides_1,
            k_strides_2,
            k_strides_3,
            v_strides_0,
            v_strides_1,
            v_strides_2,
            v_strides_3});
    }
};

} // namespace op::kv_caching

#endif // __KV_CACHING_INFO_H__
src/infiniop/ops/kv_caching/kv_caching.h
0 → 100644
View file @
83e28ec3
#ifndef KV_CACHING_H
#define KV_CACHING_H

#include "../../operator.h"
#include "info.h"

// Stamps out the per-backend Descriptor class inside
// namespace op::kv_caching::<NAMESPACE> (invoked by each backend's API
// header, e.g. DESCRIPTOR(nvidia)). Each backend then defines, in its own
// translation unit: struct Descriptor::Opaque (backend-private state),
// ~Descriptor(), create() (validates tensors via KVCachingInfo and
// allocates the descriptor), and calculate() (launches the kernel).
#define DESCRIPTOR(NAMESPACE) \
                              \
    namespace op::kv_caching::NAMESPACE { \
    class Descriptor final : public InfiniopDescriptor { \
        struct Opaque; \
        Opaque *_opaque; \
        KVCachingInfo _info; \
        size_t _workspace_size; \
                              \
        Descriptor( \
            Opaque *opaque, \
            KVCachingInfo info, \
            size_t workspace_size, \
            infiniDevice_t device_type, \
            int device_id) \
            : InfiniopDescriptor{device_type, device_id}, \
              _opaque(opaque), \
              _info(info), \
              _workspace_size(workspace_size) {} \
                              \
    public: \
        ~Descriptor(); \
                              \
        size_t get_workspace_size() const { return _workspace_size; } \
                              \
        static infiniStatus_t create( \
            infiniopHandle_t handle, \
            Descriptor **desc_ptr, \
            infiniopTensorDescriptor_t k_cache, \
            infiniopTensorDescriptor_t v_cache, \
            infiniopTensorDescriptor_t k, \
            infiniopTensorDescriptor_t v, \
            infiniopTensorDescriptor_t past_kv_lengths); \
                              \
        infiniStatus_t calculate( \
            void *workspace, size_t workspace_size, \
            void *k_cache, void *v_cache, \
            const void *k, const void *v, const void *past_kv_lengths, \
            void *stream) const; \
    }; \
    }

#endif // KV_CACHING_H
src/infiniop/ops/kv_caching/metax/kv_caching_metax.h
0 → 100644
View file @
83e28ec3
#ifndef __KV_CACHING_METAX_API_H__
#define __KV_CACHING_METAX_API_H__

#include "../kv_caching.h"

// Instantiates the common kv_caching Descriptor class in
// namespace op::kv_caching::metax (see DESCRIPTOR in ../kv_caching.h).
DESCRIPTOR(metax)

#endif // __KV_CACHING_METAX_API_H__
src/infiniop/ops/kv_caching/metax/kv_caching_metax.maca
0 → 100644
View file @
83e28ec3
#include "../../../devices/metax/metax_common.h"
#include "kv_caching_metax.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
// Global kernel entry point for the METAX backend: each thread forwards all
// arguments to the shared device-side implementation in ../cuda/kernel.cuh.
template <typename Tdata>
INFINIOP_METAX_KERNEL kvCaching(
    Tdata *k_cache,
    Tdata *v_cache,
    const Tdata *k,
    const Tdata *v,
    const int64_t *past_kv_lengths,
    int batch_size,
    int num_kv_heads,
    int max_seq_len,
    int seq_len,
    int hidden_dim,
    ptrdiff_t k_cache_strides_0,
    ptrdiff_t k_cache_strides_1,
    ptrdiff_t k_cache_strides_2,
    ptrdiff_t k_cache_strides_3,
    ptrdiff_t v_cache_strides_0,
    ptrdiff_t v_cache_strides_1,
    ptrdiff_t v_cache_strides_2,
    ptrdiff_t v_cache_strides_3,
    ptrdiff_t k_strides_0,
    ptrdiff_t k_strides_1,
    ptrdiff_t k_strides_2,
    ptrdiff_t k_strides_3,
    ptrdiff_t v_strides_0,
    ptrdiff_t v_strides_1,
    ptrdiff_t v_strides_2,
    ptrdiff_t v_strides_3) {
    kvCachingKernel<Tdata>(k_cache, v_cache, k, v, past_kv_lengths,
                           batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
                           k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
                           v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3,
                           k_strides_0, k_strides_1, k_strides_2, k_strides_3,
                           v_strides_0, v_strides_1, v_strides_2, v_strides_3);
}
namespace op::kv_caching::metax {

// Backend-private state: keeps the METAX device handle internals alive for
// the lifetime of the descriptor.
struct Descriptor::Opaque {
    std::shared_ptr<device::metax::Handle::Internal> internal;
};

// Releases the backend-private state.
Descriptor::~Descriptor() {
    delete _opaque;
}
// Validates the tensor descriptors (via KVCachingInfo) and allocates the
// descriptor. The operator needs no scratch memory, so workspace_size is 0.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t k_cache,
    infiniopTensorDescriptor_t v_cache,
    infiniopTensorDescriptor_t k,
    infiniopTensorDescriptor_t v,
    infiniopTensorDescriptor_t past_kv_lengths) {
    auto info = KVCachingInfo::createKVCachingInfo(k_cache, v_cache, k, v, past_kv_lengths);
    CHECK_RESULT(info);
    *desc_ptr = new Descriptor(
        new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
        info.take(), 0, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Sizes the grid so each element of the B*H*seq_len*D index space is covered
// and launches the kvCaching kernel on the given stream.
// BLOCK_SIZE is the threads-per-block. `workspace` is accepted for interface
// uniformity across operators but unused here.
template <unsigned int BLOCK_SIZE, typename Tdata>
infiniStatus_t launchKernel(const KVCachingInfo &info,
                            Tdata *k_cache,
                            Tdata *v_cache,
                            const Tdata *k,
                            const Tdata *v,
                            const int64_t *past_kv_lengths,
                            hcStream_t stream, void *workspace) {
    int batch_size = static_cast<int>(info.batch_size);
    int num_kv_heads = static_cast<int>(info.num_kv_heads);
    int max_seq_len = static_cast<int>(info.max_seq_len);
    int hidden_dim = static_cast<int>(info.hidden_dim);
    int seq_len = static_cast<int>(info.seq_len);
    int total = batch_size * num_kv_heads * seq_len * hidden_dim;
    // BUG FIX: with total == 0 the original computed num_blocks == 0 and
    // issued a zero-sized-grid launch (an invalid launch configuration).
    // Nothing needs copying, so succeed without launching.
    if (total == 0) {
        return INFINI_STATUS_SUCCESS;
    }
    ptrdiff_t k_cache_strides_0 = info.k_cache_strides_0;
    ptrdiff_t k_cache_strides_1 = info.k_cache_strides_1;
    ptrdiff_t k_cache_strides_2 = info.k_cache_strides_2;
    ptrdiff_t k_cache_strides_3 = info.k_cache_strides_3;
    ptrdiff_t v_cache_strides_0 = info.v_cache_strides_0;
    ptrdiff_t v_cache_strides_1 = info.v_cache_strides_1;
    ptrdiff_t v_cache_strides_2 = info.v_cache_strides_2;
    ptrdiff_t v_cache_strides_3 = info.v_cache_strides_3;
    ptrdiff_t k_strides_0 = info.k_strides_0;
    ptrdiff_t k_strides_1 = info.k_strides_1;
    ptrdiff_t k_strides_2 = info.k_strides_2;
    ptrdiff_t k_strides_3 = info.k_strides_3;
    ptrdiff_t v_strides_0 = info.v_strides_0;
    ptrdiff_t v_strides_1 = info.v_strides_1;
    ptrdiff_t v_strides_2 = info.v_strides_2;
    ptrdiff_t v_strides_3 = info.v_strides_3;
    // one thread per element, rounded up to whole blocks
    int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
    kvCaching<Tdata>
        <<<num_blocks, BLOCK_SIZE, 0, stream>>>(k_cache, v_cache, k, v, past_kv_lengths,
                                                batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
                                                k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
                                                v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3,
                                                k_strides_0, k_strides_1, k_strides_2, k_strides_3,
                                                v_strides_0, v_strides_1, v_strides_2, v_strides_3);
    return INFINI_STATUS_SUCCESS;
}
// Dispatches the kernel launch for the runtime dtype, choosing the block
// size from the device's max-threads-per-block capability.
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
                                     void *k_cache,
                                     void *v_cache,
                                     const void *k,
                                     const void *v,
                                     const void *past_kv_lengths,
                                     void *stream_) const {
    hcStream_t stream = (hcStream_t)stream_;
// Invokes launchKernel with the concrete element type.
#define CALCULATE_KV_CACHING(BLOCK_SIZE, TDATA) \
    launchKernel<BLOCK_SIZE, TDATA>(_info, (TDATA *)k_cache, (TDATA *)v_cache, (const TDATA *)k, (const TDATA *)v, (const int64_t *)past_kv_lengths, stream, workspace)
// Selects the element type from _info.dtype (F16/F32/BF16) and returns.
#define CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(BLOCK_SIZE) \
    { \
        if (_info.dtype == INFINI_DTYPE_F16) \
            return CALCULATE_KV_CACHING(BLOCK_SIZE, half); \
        else if (_info.dtype == INFINI_DTYPE_F32) \
            return CALCULATE_KV_CACHING(BLOCK_SIZE, float); \
        else if (_info.dtype == INFINI_DTYPE_BF16) \
            return CALCULATE_KV_CACHING(BLOCK_SIZE, __hpcc_bfloat16); \
        else \
            return INFINI_STATUS_BAD_TENSOR_DTYPE; \
    }
    if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_1024)
    } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) {
        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_512)
    } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_2048) {
        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_2048)
    } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_4096) {
        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_4096)
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }
    // unreachable: every branch above returns
    return INFINI_STATUS_SUCCESS;
}
} // namespace op::kv_caching::metax
src/infiniop/ops/kv_caching/nvidia/kv_caching_nvidia.cu
0 → 100644
View file @
83e28ec3
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "kv_caching_nvidia.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
// Global kernel entry point for the NVIDIA backend: each thread forwards all
// arguments to the shared device-side implementation in ../cuda/kernel.cuh.
template <typename Tdata>
INFINIOP_CUDA_KERNEL kvCaching(
    Tdata *k_cache,
    Tdata *v_cache,
    const Tdata *k,
    const Tdata *v,
    const int64_t *past_kv_lengths,
    int batch_size,
    int num_kv_heads,
    int max_seq_len,
    int seq_len,
    int hidden_dim,
    ptrdiff_t k_cache_strides_0,
    ptrdiff_t k_cache_strides_1,
    ptrdiff_t k_cache_strides_2,
    ptrdiff_t k_cache_strides_3,
    ptrdiff_t v_cache_strides_0,
    ptrdiff_t v_cache_strides_1,
    ptrdiff_t v_cache_strides_2,
    ptrdiff_t v_cache_strides_3,
    ptrdiff_t k_strides_0,
    ptrdiff_t k_strides_1,
    ptrdiff_t k_strides_2,
    ptrdiff_t k_strides_3,
    ptrdiff_t v_strides_0,
    ptrdiff_t v_strides_1,
    ptrdiff_t v_strides_2,
    ptrdiff_t v_strides_3) {
    kvCachingKernel<Tdata>(k_cache, v_cache, k, v, past_kv_lengths,
                           batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
                           k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
                           v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3,
                           k_strides_0, k_strides_1, k_strides_2, k_strides_3,
                           v_strides_0, v_strides_1, v_strides_2, v_strides_3);
}
namespace op::kv_caching::nvidia {

// Backend-private state: keeps the NVIDIA device handle internals alive for
// the lifetime of the descriptor.
struct Descriptor::Opaque {
    std::shared_ptr<device::nvidia::Handle::Internal> internal;
};

// Releases the backend-private state.
Descriptor::~Descriptor() {
    delete _opaque;
}
// Validates the tensor descriptors (via KVCachingInfo) and allocates the
// descriptor. The operator needs no scratch memory, so workspace_size is 0.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t k_cache,
    infiniopTensorDescriptor_t v_cache,
    infiniopTensorDescriptor_t k,
    infiniopTensorDescriptor_t v,
    infiniopTensorDescriptor_t past_kv_lengths) {
    auto info = KVCachingInfo::createKVCachingInfo(k_cache, v_cache, k, v, past_kv_lengths);
    CHECK_RESULT(info);
    *desc_ptr = new Descriptor(
        new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
        info.take(), 0, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Sizes the grid so each element of the B*H*seq_len*D index space is covered
// and launches the kvCaching kernel on the given stream.
// BLOCK_SIZE is the threads-per-block. `workspace` is accepted for interface
// uniformity across operators but unused here.
template <unsigned int BLOCK_SIZE, typename Tdata>
infiniStatus_t launchKernel(const KVCachingInfo &info,
                            Tdata *k_cache,
                            Tdata *v_cache,
                            const Tdata *k,
                            const Tdata *v,
                            const int64_t *past_kv_lengths,
                            cudaStream_t stream, void *workspace) {
    int batch_size = static_cast<int>(info.batch_size);
    int num_kv_heads = static_cast<int>(info.num_kv_heads);
    int max_seq_len = static_cast<int>(info.max_seq_len);
    int hidden_dim = static_cast<int>(info.hidden_dim);
    int seq_len = static_cast<int>(info.seq_len);
    int total = batch_size * num_kv_heads * seq_len * hidden_dim;
    // BUG FIX: with total == 0 the original computed num_blocks == 0 and
    // issued a zero-sized-grid launch (an invalid launch configuration).
    // Nothing needs copying, so succeed without launching.
    if (total == 0) {
        return INFINI_STATUS_SUCCESS;
    }
    ptrdiff_t k_cache_strides_0 = info.k_cache_strides_0;
    ptrdiff_t k_cache_strides_1 = info.k_cache_strides_1;
    ptrdiff_t k_cache_strides_2 = info.k_cache_strides_2;
    ptrdiff_t k_cache_strides_3 = info.k_cache_strides_3;
    ptrdiff_t v_cache_strides_0 = info.v_cache_strides_0;
    ptrdiff_t v_cache_strides_1 = info.v_cache_strides_1;
    ptrdiff_t v_cache_strides_2 = info.v_cache_strides_2;
    ptrdiff_t v_cache_strides_3 = info.v_cache_strides_3;
    ptrdiff_t k_strides_0 = info.k_strides_0;
    ptrdiff_t k_strides_1 = info.k_strides_1;
    ptrdiff_t k_strides_2 = info.k_strides_2;
    ptrdiff_t k_strides_3 = info.k_strides_3;
    ptrdiff_t v_strides_0 = info.v_strides_0;
    ptrdiff_t v_strides_1 = info.v_strides_1;
    ptrdiff_t v_strides_2 = info.v_strides_2;
    ptrdiff_t v_strides_3 = info.v_strides_3;
    // one thread per element, rounded up to whole blocks
    int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
    kvCaching<Tdata>
        <<<num_blocks, BLOCK_SIZE, 0, stream>>>(k_cache, v_cache, k, v, past_kv_lengths,
                                                batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
                                                k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
                                                v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3,
                                                k_strides_0, k_strides_1, k_strides_2, k_strides_3,
                                                v_strides_0, v_strides_1, v_strides_2, v_strides_3);
    return INFINI_STATUS_SUCCESS;
}
// Dispatches the kernel launch for the runtime dtype, choosing the block
// size from the device's max-threads-per-block capability.
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
                                     void *k_cache,
                                     void *v_cache,
                                     const void *k,
                                     const void *v,
                                     const void *past_kv_lengths,
                                     void *stream_) const {
    cudaStream_t stream = (cudaStream_t)stream_;
// Invokes launchKernel with the concrete element type.
#define CALCULATE_KV_CACHING(BLOCK_SIZE, TDATA) \
    launchKernel<BLOCK_SIZE, TDATA>(_info, (TDATA *)k_cache, (TDATA *)v_cache, (const TDATA *)k, (const TDATA *)v, (const int64_t *)past_kv_lengths, stream, workspace)
// Selects the element type from _info.dtype (F16/F32/BF16) and returns.
#define CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(BLOCK_SIZE) \
    { \
        if (_info.dtype == INFINI_DTYPE_F16) \
            return CALCULATE_KV_CACHING(BLOCK_SIZE, half); \
        else if (_info.dtype == INFINI_DTYPE_F32) \
            return CALCULATE_KV_CACHING(BLOCK_SIZE, float); \
        else if (_info.dtype == INFINI_DTYPE_BF16) \
            return CALCULATE_KV_CACHING(BLOCK_SIZE, __nv_bfloat16); \
        else \
            return INFINI_STATUS_BAD_TENSOR_DTYPE; \
    }
    if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_1024)
    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_512)
    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_2048) {
        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_2048)
    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_4096)
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }
    // unreachable: every branch above returns
    return INFINI_STATUS_SUCCESS;
}
} // namespace op::kv_caching::nvidia
src/infiniop/ops/kv_caching/nvidia/kv_caching_nvidia.cuh
0 → 100644
View file @
83e28ec3
#ifndef __KV_CACHING_NVIDIA_API_H__
#define __KV_CACHING_NVIDIA_API_H__

#include "../kv_caching.h"

// Instantiates the common kv_caching Descriptor class in
// namespace op::kv_caching::nvidia (see DESCRIPTOR in ../kv_caching.h).
DESCRIPTOR(nvidia)

#endif // __KV_CACHING_NVIDIA_API_H__
src/infiniop/ops/kv_caching/operator.cc
View file @
83e28ec3
...
...
@@ -2,10 +2,11 @@
#include "../../handle.h"
#include "infiniop/ops/kv_caching.h"
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_METAX_API) || defined(ENABLE_MOORE_API)
#include "ninetoothed/kv_caching.h"
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_ALI_API) || defined(ENABLE_HYGON_API)
#include "nvidia/kv_caching_nvidia.cuh"
#endif
#if defined(ENABLE_METAX_API)
#include "metax/kv_caching_metax.h"
#endif
__C
infiniStatus_t
infiniopCreateKVCachingDescriptor
(
...
...
@@ -30,16 +31,23 @@ __C infiniStatus_t infiniopCreateKVCachingDescriptor(
switch
(
handle
->
device
)
{
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
CREATE
(
INFINI_DEVICE_NVIDIA
,
ninetoothed
);
#ifdef ENABLE_NVIDIA_API
CREATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
);
#endif
#if
def
ined(
ENABLE_
ILUVATAR
_API
)
CREATE
(
INFINI_DEVICE_
ILUVATAR
,
ninetoothed
);
#ifdef
ENABLE_
QY
_API
CREATE
(
INFINI_DEVICE_
QY
,
nvidia
);
#endif
#if defined(ENABLE_METAX_API)
CREATE
(
INFINI_DEVICE_METAX
,
ninetoothed
);
#ifdef ENABLE_ILUVATAR_API
CREATE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#ifdef ENABLE_ALI_API
CREATE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#ifdef ENABLE_HYGON_API
CREATE
(
INFINI_DEVICE_HYGON
,
nvidia
);
#endif
#if defined(ENABLE_METAX_API)
CREATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
default:
...
...
@@ -61,17 +69,25 @@ __C infiniStatus_t infiniopGetKVCachingWorkspaceSize(
switch
(
desc
->
device_type
)
{
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
GET_SIZE
(
INFINI_DEVICE_NVIDIA
,
ninetoothed
);
#ifdef ENABLE_NVIDIA_API
GET_SIZE
(
INFINI_DEVICE_NVIDIA
,
nvidia
);
#endif
#if
def
ined(
ENABLE_
ILUVATAR
_API
)
GET_SIZE
(
INFINI_DEVICE_
ILUVATAR
,
ninetoothed
);
#ifdef
ENABLE_
QY
_API
GET_SIZE
(
INFINI_DEVICE_
QY
,
nvidia
);
#endif
#if
def
ined(
ENABLE_
ME
TA
X
_API
)
GET_SIZE
(
INFINI_DEVICE_
ME
TA
X
,
n
inetoothed
);
#ifdef
ENABLE_
ILUVA
TA
R
_API
GET_SIZE
(
INFINI_DEVICE_
ILUVA
TA
R
,
n
vidia
);
#endif
#ifdef ENABLE_ALI_API
GET_SIZE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#ifdef ENABLE_HYGON_API
GET_SIZE
(
INFINI_DEVICE_HYGON
,
nvidia
);
#endif
#if defined(ENABLE_METAX_API)
GET_SIZE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
...
...
@@ -97,17 +113,25 @@ __C infiniStatus_t infiniopKVCaching(
switch
(
desc
->
device_type
)
{
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
ninetoothed
);
#ifdef ENABLE_NVIDIA_API
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
);
#endif
#if
def
ined(
ENABLE_
ILUVATAR
_API
)
CALCULATE
(
INFINI_DEVICE_
ILUVATAR
,
ninetoothed
);
#ifdef
ENABLE_
QY
_API
CALCULATE
(
INFINI_DEVICE_
QY
,
nvidia
);
#endif
#if
def
ined(
ENABLE_
ME
TA
X
_API
)
CALCULATE
(
INFINI_DEVICE_
ME
TA
X
,
n
inetoothed
);
#ifdef
ENABLE_
ILUVA
TA
R
_API
CALCULATE
(
INFINI_DEVICE_
ILUVA
TA
R
,
n
vidia
);
#endif
#ifdef ENABLE_ALI_API
CALCULATE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#ifdef ENABLE_HYGON_API
CALCULATE
(
INFINI_DEVICE_HYGON
,
nvidia
);
#endif
#if defined(ENABLE_METAX_API)
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
...
...
@@ -125,19 +149,28 @@ __C infiniStatus_t infiniopDestroyKVCachingDescriptor(
switch
(
desc
->
device_type
)
{
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
DELETE
(
INFINI_DEVICE_NVIDIA
,
ninetoothed
);
#ifdef ENABLE_NVIDIA_API
DELETE
(
INFINI_DEVICE_NVIDIA
,
nvidia
);
#endif
#if
def
ined(
ENABLE_
ILUVATAR
_API
)
DELETE
(
INFINI_DEVICE_
ILUVATAR
,
ninetoothed
);
#ifdef
ENABLE_
QY
_API
DELETE
(
INFINI_DEVICE_
QY
,
nvidia
);
#endif
#if
def
ined(
ENABLE_
ME
TA
X
_API
)
DELETE
(
INFINI_DEVICE_
ME
TA
X
,
n
inetoothed
);
#ifdef
ENABLE_
ILUVA
TA
R
_API
DELETE
(
INFINI_DEVICE_
ILUVA
TA
R
,
n
vidia
);
#endif
#ifdef ENABLE_ALI_API
DELETE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#ifdef ENABLE_HYGON_API
DELETE
(
INFINI_DEVICE_HYGON
,
nvidia
);
#endif
#if defined(ENABLE_METAX_API)
DELETE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef DELETE
}
test/infiniop/kv_caching.py
0 → 100644
View file @
83e28ec3
import
torch
import
ctypes
from
ctypes
import
c_uint64
from
libinfiniop
import
(
LIBINFINIOP
,
TestTensor
,
get_test_devices
,
check_error
,
test_operator
,
get_args
,
debug
,
get_tolerance
,
profile_operation
,
InfiniDtype
,
InfiniDtypeNames
,
InfiniDeviceNames
,
infiniopOperatorDescriptor_t
,
TestWorkspace
,
)
# ==============================================================================
# Reference Implementation
# ==============================================================================
def torch_kv_caching(k_cache, v_cache, k, v, past_kv_lengths):
    """Reference implementation: append this step's K/V into the caches.

    Shapes:
        k_cache, v_cache: [batch_size, num_kv_heads, max_seq_len, hidden_dim]
        k, v:             [batch_size, num_kv_heads, seq_len, hidden_dim]
        past_kv_lengths:  [batch_size] int64 — entries already cached per batch.

    Returns:
        (k_cache, v_cache), updated in place: for every batch b, rows
        [past_len, past_len + seq_len) of all heads are overwritten with k/v.
    """
    batch_size = k_cache.shape[0]
    seq_len = k.shape[2]
    for b in range(batch_size):
        past_len = past_kv_lengths[b].item()
        # Write all heads of batch b at once; the per-head inner loop in the
        # original assigned the same elements one head at a time.
        k_cache[b, :, past_len : past_len + seq_len, :] = k[b]
        v_cache[b, :, past_len : past_len + seq_len, :] = v[b]
    return k_cache, v_cache
# ==============================================================================
#  Test Configuration (Internal Use Only)
# ==============================================================================
# Each case is ((num_seqs, num_kv_heads, max_seq_len, hidden_dim), strides)
# where `strides` are the cache tensors' element strides (None = contiguous).
_TEST_CASES_ = [
    # (num_seqs, num_kv_heads, max_seq_len, hidden_dim), strides
    ((1, 1, 8, 1), None),
    ((1, 8, 32, 32), None),
    ((8, 8, 64, 32), None),
    ((1, 32, 8, 64), (32768, 1024, 64, 1)),
    ((4, 8, 32, 16), (65536, 8192, 256, 16)),
    ((8, 16, 64, 128), (8388608, 524288, 8192, 1)),
    ((1, 2, 2304, 128), (589824, 294912, 128, 1)),
]

# Data types for testing
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32]

# Tolerance map for different data types
# (zero tolerance: the operator is a pure copy, so results must match exactly)
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 0, "rtol": 0},
    InfiniDtype.BF16: {"atol": 0, "rtol": 0},
    InfiniDtype.F32: {"atol": 0, "rtol": 0},
}

# Global flags for controlling test behavior (overridden from CLI in __main__)
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 100
def test(
    handle,
    device,
    cache_shape,
    strides,
    dtype=InfiniDtype.F16,
    sync=None,
):
    """Runs one kv_caching case: builds random tensors, computes the torch
    reference, executes the library operator, and asserts exact equality of
    both caches.

    Args:
        handle: infiniop handle for the target device.
        device: device id under test.
        cache_shape: (batch, num_kv_heads, max_seq_len, hidden_dim) of the caches.
        strides: cache element strides, or None for contiguous.
        dtype: element dtype of k/v and caches.
        sync: optional device-synchronize callable.
    """
    print(
        f"Testing KVCaching on {InfiniDeviceNames[device]} with cache_shape: {cache_shape}, strides: {strides}, dtype={InfiniDtypeNames[dtype]}"
    )
    import random

    # Pick a random new-token count in [1, max_seq_len); past lengths below
    # are drawn so past_len + seq_len stays within the cache capacity.
    # NOTE(review): no seed is set, so each run exercises a different seq_len.
    kv_shape = (
        cache_shape[0],
        cache_shape[1],
        random.randrange(1, cache_shape[2]),
        cache_shape[3],
    )
    past_shape = (cache_shape[0],)
    k_cache = TestTensor(cache_shape, strides, dtype, device)
    v_cache = TestTensor(cache_shape, strides, dtype, device)
    k = TestTensor(kv_shape, None, dtype, device)
    v = TestTensor(kv_shape, None, dtype, device)
    past_kv_lengths = TestTensor(past_shape, None, InfiniDtype.I64, device, randint_low=0, randint_high=cache_shape[2] - kv_shape[2])
    # Run reference implementation
    k_cache_ref, v_cache_ref = torch_kv_caching(k_cache.torch_tensor(), v_cache.torch_tensor(), k.torch_tensor(), v.torch_tensor(), past_kv_lengths.torch_tensor())
    if sync:
        sync()
    # Create operator descriptor
    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateKVCachingDescriptor(
            handle,
            ctypes.byref(descriptor),
            k_cache.descriptor,
            v_cache.descriptor,
            k.descriptor,
            v.descriptor,
            past_kv_lengths.descriptor,
        )
    )
    # Get workspace size (likely 0 for this operator, but good practice to include)
    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetKVCachingWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, device)
    # Invalidate descriptors to ensure kernel does not rely on them
    k.destroy_desc()
    v.destroy_desc()
    k_cache.destroy_desc()
    v_cache.destroy_desc()
    past_kv_lengths.destroy_desc()

    # Define the library call as a lambda for profiling
    def lib_kv_caching():
        check_error(
            LIBINFINIOP.infiniopKVCaching(
                descriptor,
                workspace.data(),
                workspace_size.value,
                k_cache.data(),
                v_cache.data(),
                k.data(),
                v.data(),
                past_kv_lengths.data(),
                None,
            )
        )

    # Execute the custom operator
    lib_kv_caching()
    if sync:
        sync()
    # Verify correctness (zero tolerance: the operator is a pure copy)
    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        print("Verifying K cache...")
        debug(k_cache.actual_tensor(), k_cache_ref, atol=atol, rtol=rtol)
        print("Verifying V cache...")
        debug(v_cache.actual_tensor(), v_cache_ref, atol=atol, rtol=rtol)
    assert torch.allclose(k_cache.actual_tensor(), k_cache_ref, atol=atol, rtol=rtol)
    assert torch.allclose(v_cache.actual_tensor(), v_cache_ref, atol=atol, rtol=rtol)
    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: torch_kv_caching(k_cache.torch_tensor(), v_cache.torch_tensor(), k.torch_tensor(), v.torch_tensor(), past_kv_lengths.torch_tensor()), device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation(" lib", lib_kv_caching, device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    # Clean up resources
    check_error(LIBINFINIOP.infiniopDestroyKVCachingDescriptor(descriptor))
if __name__ == "__main__":
    args = get_args()

    # Configure testing options from command line arguments
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Run every test case for every dtype on each requested device.
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES_, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")
test/infiniop/libinfiniop/op_register.py
View file @
83e28ec3
...
...
@@ -1054,6 +1054,48 @@ def scaled_mm_int8_(lib):
]
@OpRegister.operator
def kv_caching_(lib):
    # Registers ctypes signatures for the kv_caching operator's C API.
    # All four entry points return an infiniStatus_t (c_int32).

    # (handle, *desc, k_cache, v_cache, k, v, past_kv_lengths)
    lib.infiniopCreateKVCachingDescriptor.restype = c_int32
    lib.infiniopCreateKVCachingDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    # (desc, *workspace_size)
    lib.infiniopGetKVCachingWorkspaceSize.restype = c_int32
    lib.infiniopGetKVCachingWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    # (desc, workspace, workspace_size,
    #  k_cache, v_cache, k, v, past_kv_lengths, stream)
    lib.infiniopKVCaching.restype = c_int32
    lib.infiniopKVCaching.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    # (desc)
    lib.infiniopDestroyKVCachingDescriptor.restype = c_int32
    lib.infiniopDestroyKVCachingDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]
@
OpRegister
.
operator
def
paged_attention_
(
lib
):
lib
.
infiniopCreatePagedAttentionDescriptor
.
restype
=
c_int32
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment