Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
af394a32
Commit
af394a32
authored
Mar 03, 2026
by
wooway777
Browse files
issue/1035 - kv caching on ali, ilu, hygon, qy, metax
parent
c70805c9
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
227 additions
and
58 deletions
+227
-58
src/infiniop/ops/kv_caching/cuda/kernel.cuh
src/infiniop/ops/kv_caching/cuda/kernel.cuh
+3
-3
src/infiniop/ops/kv_caching/info.h
src/infiniop/ops/kv_caching/info.h
+1
-0
src/infiniop/ops/kv_caching/metax/kv_caching_metax.h
src/infiniop/ops/kv_caching/metax/kv_caching_metax.h
+7
-0
src/infiniop/ops/kv_caching/metax/kv_caching_metax.maca
src/infiniop/ops/kv_caching/metax/kv_caching_metax.maca
+160
-0
src/infiniop/ops/kv_caching/operator.cc
src/infiniop/ops/kv_caching/operator.cc
+56
-55
No files found.
src/infiniop/ops/kv_caching/cuda/kernel.cuh
View file @
af394a32
...
...
@@ -29,14 +29,14 @@ __device__ void kvCachingKernel(
ptrdiff_t
v_strides_1
,
ptrdiff_t
v_strides_2
,
ptrdiff_t
v_strides_3
)
{
//
总元素数
= B * H * seq_len * D
//
num of ele
= B * H * seq_len * D
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
total
=
batch_size
*
num_kv_heads
*
seq_len
*
hidden_dim
;
const
int
grid_size
=
blockDim
.
x
*
gridDim
.
x
;
for
(
int
idx
=
tid
;
idx
<
total
;
idx
+=
grid_size
)
{
//
反解
index
//
unravel
index
int
d
=
idx
%
hidden_dim
;
idx
/=
hidden_dim
;
...
...
@@ -48,7 +48,7 @@ __device__ void kvCachingKernel(
int
b
=
idx
/
num_kv_heads
;
int
past_len
=
static_cast
<
int32_t
>
(
past_kv_lengths
[
b
]);
//
写入位置
//
write position
int
cache_s
=
past_len
+
s
;
int
k_cache_offset
=
d
*
(
int
)
k_cache_strides_3
+
cache_s
*
(
int
)
k_cache_strides_2
+
h
*
(
int
)
k_cache_strides_1
+
b
*
(
int
)
k_cache_strides_0
;
int
v_cache_offset
=
d
*
(
int
)
v_cache_strides_3
+
cache_s
*
(
int
)
v_cache_strides_2
+
h
*
(
int
)
v_cache_strides_1
+
b
*
(
int
)
v_cache_strides_0
;
...
...
src/infiniop/ops/kv_caching/info.h
View file @
af394a32
...
...
@@ -32,6 +32,7 @@ public:
const
infiniDtype_t
dtype
=
k_cache
->
dtype
();
CHECK_DTYPE
(
dtype
,
INFINI_DTYPE_F16
,
INFINI_DTYPE_BF16
,
INFINI_DTYPE_F32
);
CHECK_DTYPE
(
past_kv_lengths
->
dtype
(),
INFINI_DTYPE_I64
);
CHECK_OR_RETURN
(
k_cache
->
ndim
()
==
4
&&
v_cache
->
ndim
()
==
4
...
...
src/infiniop/ops/kv_caching/metax/kv_caching_metax.h
0 → 100644
View file @
af394a32
#ifndef __KV_CACHING_METAX_API_H__
#define __KV_CACHING_METAX_API_H__
#include "../kv_caching.h"
DESCRIPTOR
(
metax
)
#endif // __KV_CACHING_METAX_API_H__
src/infiniop/ops/kv_caching/metax/kv_caching_metax.maca
0 → 100644
View file @
af394a32
#include "../../../devices/metax/metax_common.h"
#include "kv_caching_metax.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
template <typename Tdata>
INFINIOP_METAX_KERNEL kvCaching(
Tdata *k_cache,
Tdata *v_cache,
const Tdata *k,
const Tdata *v,
const int64_t *past_kv_lengths,
int batch_size,
int num_kv_heads,
int max_seq_len,
int seq_len,
int hidden_dim,
ptrdiff_t k_cache_strides_0,
ptrdiff_t k_cache_strides_1,
ptrdiff_t k_cache_strides_2,
ptrdiff_t k_cache_strides_3,
ptrdiff_t v_cache_strides_0,
ptrdiff_t v_cache_strides_1,
ptrdiff_t v_cache_strides_2,
ptrdiff_t v_cache_strides_3,
ptrdiff_t k_strides_0,
ptrdiff_t k_strides_1,
ptrdiff_t k_strides_2,
ptrdiff_t k_strides_3,
ptrdiff_t v_strides_0,
ptrdiff_t v_strides_1,
ptrdiff_t v_strides_2,
ptrdiff_t v_strides_3) {
kvCachingKernel<Tdata>(k_cache, v_cache, k, v, past_kv_lengths,
batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3,
k_strides_0, k_strides_1, k_strides_2, k_strides_3,
v_strides_0, v_strides_1, v_strides_2, v_strides_3);
}
namespace op::kv_caching::metax {
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t k_cache,
infiniopTensorDescriptor_t v_cache,
infiniopTensorDescriptor_t k,
infiniopTensorDescriptor_t v,
infiniopTensorDescriptor_t past_kv_lengths) {
auto info = KVCachingInfo::createKVCachingInfo(k_cache, v_cache, k, v, past_kv_lengths);
CHECK_RESULT(info);
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <unsigned int BLOCK_SIZE, typename Tdata>
infiniStatus_t launchKernel(const KVCachingInfo &info,
Tdata *k_cache,
Tdata *v_cache,
const Tdata *k,
const Tdata *v,
const int64_t *past_kv_lengths,
hcStream_t stream, void *workspace) {
int batch_size = static_cast<int>(info.batch_size);
int num_kv_heads = static_cast<int>(info.num_kv_heads);
int max_seq_len = static_cast<int>(info.max_seq_len);
int hidden_dim = static_cast<int>(info.hidden_dim);
int seq_len = static_cast<int>(info.seq_len);
int total = batch_size * num_kv_heads * seq_len * hidden_dim;
ptrdiff_t k_cache_strides_0 = info.k_cache_strides_0;
ptrdiff_t k_cache_strides_1 = info.k_cache_strides_1;
ptrdiff_t k_cache_strides_2 = info.k_cache_strides_2;
ptrdiff_t k_cache_strides_3 = info.k_cache_strides_3;
ptrdiff_t v_cache_strides_0 = info.v_cache_strides_0;
ptrdiff_t v_cache_strides_1 = info.v_cache_strides_1;
ptrdiff_t v_cache_strides_2 = info.v_cache_strides_2;
ptrdiff_t v_cache_strides_3 = info.v_cache_strides_3;
ptrdiff_t k_strides_0 = info.k_strides_0;
ptrdiff_t k_strides_1 = info.k_strides_1;
ptrdiff_t k_strides_2 = info.k_strides_2;
ptrdiff_t k_strides_3 = info.k_strides_3;
ptrdiff_t v_strides_0 = info.v_strides_0;
ptrdiff_t v_strides_1 = info.v_strides_1;
ptrdiff_t v_strides_2 = info.v_strides_2;
ptrdiff_t v_strides_3 = info.v_strides_3;
int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
kvCaching<Tdata>
<<<num_blocks, BLOCK_SIZE, 0, stream>>>(k_cache, v_cache, k, v, past_kv_lengths,
batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3,
k_strides_0, k_strides_1, k_strides_2, k_strides_3,
v_strides_0, v_strides_1, v_strides_2, v_strides_3);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
void *k_cache,
void *v_cache,
const void *k,
const void *v,
const void *past_kv_lengths,
void *stream_) const {
hcStream_t stream = (hcStream_t)stream_;
#define CALCULATE_KV_CACHING(BLOCK_SIZE, TDATA) \
launchKernel<BLOCK_SIZE, TDATA>(_info, (TDATA *)k_cache, (TDATA *)v_cache, (const TDATA *)k, (const TDATA *)v, (const int64_t *)past_kv_lengths, stream, workspace)
#define CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(BLOCK_SIZE) \
{ \
if (_info.dtype == INFINI_DTYPE_F16) \
return CALCULATE_KV_CACHING(BLOCK_SIZE, half); \
else if (_info.dtype == INFINI_DTYPE_F32) \
return CALCULATE_KV_CACHING(BLOCK_SIZE, float); \
else if (_info.dtype == INFINI_DTYPE_BF16) \
return CALCULATE_KV_CACHING(BLOCK_SIZE, __hpcc_bfloat16); \
else \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_1024)
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) {
CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_512)
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_2048) {
CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_2048)
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_4096) {
CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_4096)
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::kv_caching::metax
src/infiniop/ops/kv_caching/operator.cc
View file @
af394a32
...
...
@@ -2,15 +2,12 @@
#include "../../handle.h"
#include "infiniop/ops/kv_caching.h"
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_METAX_API) || defined(ENABLE_MOORE_API)
#include "ninetoothed/kv_caching.h"
#endif
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_ALI_API) || defined(ENABLE_HYGON_API)
#include "nvidia/kv_caching_nvidia.cuh"
#endif
#if defined(ENABLE_METAX_API)
#include "metax/kv_caching_metax.h"
#endif
__C
infiniStatus_t
infiniopCreateKVCachingDescriptor
(
infiniopHandle_t
handle
,
...
...
@@ -34,24 +31,24 @@ __C infiniStatus_t infiniopCreateKVCachingDescriptor(
switch
(
handle
->
device
)
{
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
CREATE
(
INFINI_DEVICE_NVIDIA
,
ninetoothed
);
#endif
#if defined(ENABLE_ILUVATAR_API)
CREATE
(
INFINI_DEVICE_ILUVATAR
,
ninetoothed
);
#endif
#if defined(ENABLE_METAX_API)
CREATE
(
INFINI_DEVICE_METAX
,
ninetoothed
);
#endif
#endif
#ifdef ENABLE_NVIDIA_API
CREATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
);
#endif
#ifdef ENABLE_QY_API
CREATE
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
#ifdef ENABLE_ILUVATAR_API
CREATE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#ifdef ENABLE_ALI_API
CREATE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#ifdef ENABLE_HYGON_API
CREATE
(
INFINI_DEVICE_HYGON
,
nvidia
);
#endif
#if defined(ENABLE_METAX_API)
CREATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
@@ -72,24 +69,25 @@ __C infiniStatus_t infiniopGetKVCachingWorkspaceSize(
switch
(
desc
->
device_type
)
{
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
GET_SIZE
(
INFINI_DEVICE_NVIDIA
,
ninetoothed
);
#endif
#if defined(ENABLE_ILUVATAR_API)
GET_SIZE
(
INFINI_DEVICE_ILUVATAR
,
ninetoothed
);
#endif
#if defined(ENABLE_METAX_API)
GET_SIZE
(
INFINI_DEVICE_METAX
,
ninetoothed
);
#endif
#endif
#ifdef ENABLE_NVIDIA_API
GET_SIZE
(
INFINI_DEVICE_NVIDIA
,
nvidia
);
#endif
#ifdef ENABLE_QY_API
GET_SIZE
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
#ifdef ENABLE_ILUVATAR_API
GET_SIZE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#ifdef ENABLE_ALI_API
GET_SIZE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#ifdef ENABLE_HYGON_API
GET_SIZE
(
INFINI_DEVICE_HYGON
,
nvidia
);
#endif
#if defined(ENABLE_METAX_API)
GET_SIZE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
...
...
@@ -115,24 +113,25 @@ __C infiniStatus_t infiniopKVCaching(
switch
(
desc
->
device_type
)
{
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
ninetoothed
);
#endif
#if defined(ENABLE_ILUVATAR_API)
CALCULATE
(
INFINI_DEVICE_ILUVATAR
,
ninetoothed
);
#endif
#if defined(ENABLE_METAX_API)
CALCULATE
(
INFINI_DEVICE_METAX
,
ninetoothed
);
#endif
#endif
#ifdef ENABLE_NVIDIA_API
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
);
#endif
#ifdef ENABLE_QY_API
CALCULATE
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
#ifdef ENABLE_ILUVATAR_API
CALCULATE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#ifdef ENABLE_ALI_API
CALCULATE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#ifdef ENABLE_HYGON_API
CALCULATE
(
INFINI_DEVICE_HYGON
,
nvidia
);
#endif
#if defined(ENABLE_METAX_API)
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
...
...
@@ -150,26 +149,28 @@ __C infiniStatus_t infiniopDestroyKVCachingDescriptor(
switch
(
desc
->
device_type
)
{
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
DELETE
(
INFINI_DEVICE_NVIDIA
,
ninetoothed
);
#endif
#if defined(ENABLE_ILUVATAR_API)
DELETE
(
INFINI_DEVICE_ILUVATAR
,
ninetoothed
);
#endif
#if defined(ENABLE_METAX_API)
DELETE
(
INFINI_DEVICE_METAX
,
ninetoothed
);
#endif
#endif
#ifdef ENABLE_NVIDIA_API
DELETE
(
INFINI_DEVICE_NVIDIA
,
nvidia
);
#endif
#ifdef ENABLE_QY_API
DELETE
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
#ifdef ENABLE_ILUVATAR_API
DELETE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#ifdef ENABLE_ALI_API
DELETE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#ifdef ENABLE_HYGON_API
DELETE
(
INFINI_DEVICE_HYGON
,
nvidia
);
#endif
#if defined(ENABLE_METAX_API)
DELETE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef DELETE
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment