jerrrrry / infinicore · Commits

Commit a0cbae66, authored Mar 05, 2026 by wooway777 (parent a9503148)

    issue/1035 - support both int32 and int64 in kv caching

Showing 5 changed files with 114 additions and 83 deletions:
    src/infiniop/ops/kv_caching/cuda/kernel.cuh               +3  -3
    src/infiniop/ops/kv_caching/info.h                         +4  -1
    src/infiniop/ops/kv_caching/metax/kv_caching_metax.maca   +41 -28
    src/infiniop/ops/kv_caching/nvidia/kv_caching_nvidia.cu   +40 -27
    test/infinicore/ops/kv_caching.py                         +26 -24
src/infiniop/ops/kv_caching/cuda/kernel.cuh (+3 -3)

 #ifndef __KV_CACHING_KERNEL_CUH__
 #define __KV_CACHING_KERNEL_CUH__

-template <typename Tdata>
+template <typename Tdata, typename Tidx>
 __device__ void kvCachingKernel(
     Tdata *__restrict__ k_cache,
     Tdata *__restrict__ v_cache,
     const Tdata *__restrict__ k,
     const Tdata *__restrict__ v,
-    const int64_t *__restrict__ past_kv_lengths,
+    const Tidx *__restrict__ past_kv_lengths,
     int batch_size,
     int num_kv_heads,
     int max_seq_len,
...
@@ -47,7 +47,7 @@ __device__ void kvCachingKernel(
     int h = idx % num_kv_heads;
     int b = idx / num_kv_heads;
-    int past_len = static_cast<int32_t>(past_kv_lengths[b]);
+    int past_len = static_cast<int>(past_kv_lengths[b]); // Cast to int for both types
     // write position
     int cache_s = past_len + s;
     int k_cache_offset = d * (int)k_cache_strides_3 + cache_s * (int)k_cache_strides_2 + h * (int)k_cache_strides_1 + b * (int)k_cache_strides_0;
...
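The kernel change is minimal by design: the index type becomes a second template parameter (Tidx) in place of the hard-coded int64_t, and the loaded length is immediately narrowed with static_cast<int>, so the kernel body is identical for 32- and 64-bit lengths. A minimal standalone CUDA sketch of that pattern (the kernel name, buffers, and sizes below are illustrative, not the project's code):

 #include <cstdint>
 #include <cstdio>

 // Illustrative kernel: compiles for Tidx = int32_t or int64_t because the
 // loaded length is narrowed to int before any arithmetic, as in
 // kvCachingKernel above.
 template <typename Tidx>
 __global__ void readPastLengths(const Tidx *__restrict__ past_kv_lengths, int batch_size) {
     int b = blockIdx.x * blockDim.x + threadIdx.x;
     if (b < batch_size) {
         int past_len = static_cast<int>(past_kv_lengths[b]); // same cast for both widths
         printf("batch %d: past_len = %d\n", b, past_len);
     }
 }

 int main() {
     int32_t h32[2] = {3, 7};
     int64_t h64[2] = {3, 7};
     int32_t *d32;
     int64_t *d64;
     cudaMalloc(&d32, sizeof(h32));
     cudaMalloc(&d64, sizeof(h64));
     cudaMemcpy(d32, h32, sizeof(h32), cudaMemcpyHostToDevice);
     cudaMemcpy(d64, h64, sizeof(h64), cudaMemcpyHostToDevice);
     readPastLengths<int32_t><<<1, 2>>>(d32, 2); // one instantiation per index dtype
     readPastLengths<int64_t><<<1, 2>>>(d64, 2);
     cudaDeviceSynchronize();
     cudaFree(d32);
     cudaFree(d64);
     return 0;
 }

Each instantiation is compiled separately, so supporting the second index width adds no branches inside the kernel itself.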
src/infiniop/ops/kv_caching/info.h (+4 -1)

...
@@ -13,6 +13,7 @@ private:
 public:
     infiniDtype_t dtype;
+    infiniDtype_t past_len_dtype;
     size_t batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim;
     ptrdiff_t k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3;
     ptrdiff_t v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3;
...
@@ -32,7 +33,8 @@ public:
     const infiniDtype_t dtype = k_cache->dtype();
     CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
-    CHECK_DTYPE(past_kv_lengths->dtype(), INFINI_DTYPE_I64);
+    const infiniDtype_t past_len_dtype = past_kv_lengths->dtype();
+    CHECK_DTYPE(past_len_dtype, INFINI_DTYPE_I32, INFINI_DTYPE_I64);
     CHECK_OR_RETURN(k_cache->ndim() == 4
                     && v_cache->ndim() == 4
...
@@ -78,6 +80,7 @@ public:
     return utils::Result<KVCachingInfo>(KVCachingInfo{
         dtype,
+        past_len_dtype,
         batch_size,
         num_kv_heads,
         max_seq_len,
...
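This validation change is the hinge of the commit: instead of rejecting everything but I64, the constructor records whichever index dtype the past_kv_lengths tensor actually has and threads it through the info struct, so the launcher can dispatch on it later. A hedged host-side sketch of that record-then-dispatch-later shape, with stand-in types for infiniDtype_t and KVCachingInfo (names here are illustrative):

 #include <cstdio>

 // Stand-ins for infiniDtype_t / KVCachingInfo.
 enum class Dtype { I32, I64, F32 };

 struct KvInfo {
     Dtype dtype;          // element dtype of k/v
     Dtype past_len_dtype; // index dtype of past_kv_lengths (I32 or I64)
 };

 // Rough equivalent of CHECK_DTYPE(past_len_dtype, INFINI_DTYPE_I32, INFINI_DTYPE_I64):
 // accept either width, and record the one seen for dispatch at launch time.
 bool validatePastLenDtype(Dtype past_len_dtype, KvInfo &out) {
     if (past_len_dtype != Dtype::I32 && past_len_dtype != Dtype::I64)
         return false;
     out.past_len_dtype = past_len_dtype;
     return true;
 }

 int main() {
     KvInfo info{Dtype::F32, Dtype::I64};
     std::printf("i32 accepted: %d\n", validatePastLenDtype(Dtype::I32, info));
     std::printf("f32 accepted: %d\n", validatePastLenDtype(Dtype::F32, info)); // rejected
     return 0;
 }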
src/infiniop/ops/kv_caching/metax/kv_caching_metax.maca (+41 -28)

...
@@ -8,13 +8,13 @@
 #include "../cuda/kernel.cuh"

-template <typename Tdata>
+template <typename Tdata, typename Tidx>
 INFINIOP_METAX_KERNEL kvCaching(
     Tdata *k_cache,
     Tdata *v_cache,
     const Tdata *k,
     const Tdata *v,
-    const int64_t *past_kv_lengths,
+    const Tidx *past_kv_lengths,
     int batch_size,
     int num_kv_heads,
     int max_seq_len,
...
@@ -36,12 +36,12 @@ INFINIOP_METAX_KERNEL kvCaching(
     ptrdiff_t v_strides_1,
     ptrdiff_t v_strides_2,
     ptrdiff_t v_strides_3) {
-    kvCachingKernel<Tdata>(k_cache, v_cache, k, v, past_kv_lengths,
+    kvCachingKernel<Tdata, Tidx>(k_cache, v_cache, k, v, past_kv_lengths,
                            batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
                            k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
                            v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3,
                            k_strides_0, k_strides_1, k_strides_2, k_strides_3,
                            v_strides_0, v_strides_1, v_strides_2, v_strides_3);
 }

 namespace op::kv_caching::metax {
...
@@ -71,13 +71,13 @@ infiniStatus_t Descriptor::create(
     return INFINI_STATUS_SUCCESS;
 }

-template <unsigned int BLOCK_SIZE, typename Tdata>
+template <unsigned int BLOCK_SIZE, typename Tdata, typename Tidx>
 infiniStatus_t launchKernel(const KVCachingInfo &info,
                             Tdata *k_cache,
                             Tdata *v_cache,
                             const Tdata *k,
                             const Tdata *v,
-                            const int64_t *past_kv_lengths,
+                            const Tidx *past_kv_lengths,
                             hcStream_t stream, void *workspace) {
     int batch_size = static_cast<int>(info.batch_size);
...
@@ -111,7 +111,7 @@ infiniStatus_t launchKernel(const KVCachingInfo &info,
     int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
-    kvCaching<Tdata>
+    kvCaching<Tdata, Tidx>
         <<<num_blocks, BLOCK_SIZE, 0, stream>>>(k_cache, v_cache, k, v, past_kv_lengths,
                                                 batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
                                                 k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
...
@@ -129,28 +129,41 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
                                      const void *past_kv_lengths,
                                      void *stream_) const {
     hcStream_t stream = (hcStream_t)stream_;

-#define CALCULATE_KV_CACHING(BLOCK_SIZE, TDATA) \
-    launchKernel<BLOCK_SIZE, TDATA>(_info, (TDATA *)k_cache, (TDATA *)v_cache, (const TDATA *)k, (const TDATA *)v, (const int64_t *)past_kv_lengths, stream, workspace)
-
-#define CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(BLOCK_SIZE)                 \
-    {                                                                    \
-        if (_info.dtype == INFINI_DTYPE_F16)                             \
-            return CALCULATE_KV_CACHING(BLOCK_SIZE, half);               \
-        else if (_info.dtype == INFINI_DTYPE_F32)                        \
-            return CALCULATE_KV_CACHING(BLOCK_SIZE, float);              \
-        else if (_info.dtype == INFINI_DTYPE_BF16)                       \
-            return CALCULATE_KV_CACHING(BLOCK_SIZE, __hpcc_bfloat16);    \
-        else                                                             \
-            return INFINI_STATUS_BAD_TENSOR_DTYPE;                       \
-    }
+#define LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DTYPE(BLOCK_SIZE, TDATA, TIDX)             \
+    launchKernel<BLOCK_SIZE, TDATA, TIDX>(_info, (TDATA *)k_cache, (TDATA *)v_cache, \
+                                          (const TDATA *)k, (const TDATA *)v,        \
+                                          (const TIDX *)past_kv_lengths, stream, workspace)
+
+#define LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, TDATA)                            \
+    if (_info.past_len_dtype == INFINI_DTYPE_I32) {                                 \
+        return LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DTYPE(BLOCK_SIZE, TDATA, int32_t); \
+    } else { /* INFINI_DTYPE_I64 */                                                 \
+        return LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DTYPE(BLOCK_SIZE, TDATA, int64_t); \
+    }
+
+#define LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(BLOCK_SIZE)        \
+    {                                                                  \
+        if (_info.dtype == INFINI_DTYPE_F16) {                         \
+            LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, half)            \
+        } else if (_info.dtype == INFINI_DTYPE_F32) {                  \
+            LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, float)           \
+        } else if (_info.dtype == INFINI_DTYPE_BF16) {                 \
+            LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, __hpcc_bfloat16) \
+        } else {                                                       \
+            return INFINI_STATUS_BAD_TENSOR_DTYPE;                     \
+        }                                                              \
+    }
+
+    // Choose block size based on device capabilities
     if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
-        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_1024)
+        LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(METAX_BLOCK_SIZE_1024)
     } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) {
-        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_512)
+        LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(METAX_BLOCK_SIZE_512)
     } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_2048) {
-        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_2048)
+        LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(METAX_BLOCK_SIZE_2048)
     } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_4096) {
-        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_4096)
+        LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(METAX_BLOCK_SIZE_4096)
     } else {
         return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
     }
...
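The macro rewrite turns the old single-level dtype switch into a two-level dispatch: block size outermost, then element dtype, then index dtype, with each leaf instantiating launchKernel for one concrete (TDATA, TIDX) pair. Written out macro-free as plain C++ (names and the toy launcher are illustrative, not the committed code), the pattern the macros expand to looks roughly like this:

 #include <cstdint>
 #include <cstdio>

 // Stand-in for the runtime index-dtype enum.
 enum class IdxDtype { I32, I64 };

 // Stand-in for the templated launcher: one instantiation per index type.
 template <typename Tidx>
 void launch(const void *past_kv_lengths, int batch_size) {
     const Tidx *lengths = static_cast<const Tidx *>(past_kv_lengths);
     for (int b = 0; b < batch_size; ++b)
         std::printf("%d ", static_cast<int>(lengths[b])); // same narrowing as the kernel
     std::printf("\n");
 }

 // Mirrors LAUNCH_KERNEL_WITH_BLOCK_SIZE: map the runtime dtype of
 // past_kv_lengths onto a compile-time template argument.
 void dispatchOnIndexDtype(IdxDtype dt, const void *past_kv_lengths, int batch_size) {
     if (dt == IdxDtype::I32)
         launch<int32_t>(past_kv_lengths, batch_size);
     else // only I32/I64 survive validation in info.h
         launch<int64_t>(past_kv_lengths, batch_size);
 }

 int main() {
     int32_t a[3] = {1, 2, 3};
     int64_t b[3] = {4, 5, 6};
     dispatchOnIndexDtype(IdxDtype::I32, a, 3);
     dispatchOnIndexDtype(IdxDtype::I64, b, 3);
     return 0;
 }

Because the committed macros keep the return inside the index-dtype branch, every path exits immediately after launching; falling through to a second block-size branch is impossible.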
src/infiniop/ops/kv_caching/nvidia/kv_caching_nvidia.cu (+40 -27)

...
@@ -8,13 +8,13 @@
 #include "../cuda/kernel.cuh"

-template <typename Tdata>
+template <typename Tdata, typename Tidx>
 INFINIOP_CUDA_KERNEL kvCaching(
     Tdata *k_cache,
     Tdata *v_cache,
     const Tdata *k,
     const Tdata *v,
-    const int64_t *past_kv_lengths,
+    const Tidx *past_kv_lengths,
     int batch_size,
     int num_kv_heads,
     int max_seq_len,
...
@@ -36,12 +36,12 @@ INFINIOP_CUDA_KERNEL kvCaching(
     ptrdiff_t v_strides_1,
     ptrdiff_t v_strides_2,
     ptrdiff_t v_strides_3) {
-    kvCachingKernel<Tdata>(k_cache, v_cache, k, v, past_kv_lengths,
+    kvCachingKernel<Tdata, Tidx>(k_cache, v_cache, k, v, past_kv_lengths,
                            batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
                            k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
                            v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3,
                            k_strides_0, k_strides_1, k_strides_2, k_strides_3,
                            v_strides_0, v_strides_1, v_strides_2, v_strides_3);
 }

 namespace op::kv_caching::nvidia {
...
@@ -71,13 +71,13 @@ infiniStatus_t Descriptor::create(
     return INFINI_STATUS_SUCCESS;
 }

-template <unsigned int BLOCK_SIZE, typename Tdata>
+template <unsigned int BLOCK_SIZE, typename Tdata, typename Tidx>
 infiniStatus_t launchKernel(const KVCachingInfo &info,
                             Tdata *k_cache,
                             Tdata *v_cache,
                             const Tdata *k,
                             const Tdata *v,
-                            const int64_t *past_kv_lengths,
+                            const Tidx *past_kv_lengths,
                             cudaStream_t stream, void *workspace) {
     int batch_size = static_cast<int>(info.batch_size);
...
@@ -111,7 +111,7 @@ infiniStatus_t launchKernel(const KVCachingInfo &info,
     int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
-    kvCaching<Tdata>
+    kvCaching<Tdata, Tidx>
         <<<num_blocks, BLOCK_SIZE, 0, stream>>>(k_cache, v_cache, k, v, past_kv_lengths,
                                                 batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
                                                 k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
...
@@ -129,27 +129,40 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
                                      const void *past_kv_lengths,
                                      void *stream_) const {
     cudaStream_t stream = (cudaStream_t)stream_;

-#define CALCULATE_KV_CACHING(BLOCK_SIZE, TDATA) \
-    launchKernel<BLOCK_SIZE, TDATA>(_info, (TDATA *)k_cache, (TDATA *)v_cache, (const TDATA *)k, (const TDATA *)v, (const int64_t *)past_kv_lengths, stream, workspace)
-
-#define CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(BLOCK_SIZE)                \
-    {                                                                   \
-        if (_info.dtype == INFINI_DTYPE_F16)                            \
-            return CALCULATE_KV_CACHING(BLOCK_SIZE, half);              \
-        else if (_info.dtype == INFINI_DTYPE_F32)                       \
-            return CALCULATE_KV_CACHING(BLOCK_SIZE, float);             \
-        else if (_info.dtype == INFINI_DTYPE_BF16)                      \
-            return CALCULATE_KV_CACHING(BLOCK_SIZE, __nv_bfloat16);     \
-        else                                                            \
-            return INFINI_STATUS_BAD_TENSOR_DTYPE;                      \
-    }
+#define LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DTYPE(BLOCK_SIZE, TDATA, TIDX)             \
+    launchKernel<BLOCK_SIZE, TDATA, TIDX>(_info, (TDATA *)k_cache, (TDATA *)v_cache, \
+                                          (const TDATA *)k, (const TDATA *)v,        \
+                                          (const TIDX *)past_kv_lengths, stream, workspace)
+
+#define LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, TDATA)                            \
+    if (_info.past_len_dtype == INFINI_DTYPE_I32) {                                 \
+        return LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DTYPE(BLOCK_SIZE, TDATA, int32_t); \
+    } else {                                                                        \
+        return LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DTYPE(BLOCK_SIZE, TDATA, int64_t); \
+    }
+
+#define LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(BLOCK_SIZE)      \
+    {                                                                \
+        if (_info.dtype == INFINI_DTYPE_F16) {                       \
+            LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, half)          \
+        } else if (_info.dtype == INFINI_DTYPE_F32) {                \
+            LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, float)         \
+        } else if (_info.dtype == INFINI_DTYPE_BF16) {               \
+            LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, __nv_bfloat16) \
+        } else {                                                     \
+            return INFINI_STATUS_BAD_TENSOR_DTYPE;                   \
+        }                                                            \
+    }
     if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
-        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_1024)
+        LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(CUDA_BLOCK_SIZE_1024)
     } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
-        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_512)
+        LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(CUDA_BLOCK_SIZE_512)
     } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_2048) {
-        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_2048)
+        LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(CUDA_BLOCK_SIZE_2048)
     } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
-        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_4096)
+        LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(CUDA_BLOCK_SIZE_4096)
     } else {
         return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
     }
...
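The launch configuration itself is untouched by this commit: num_blocks is a ceiling division of the flattened element count by the block size selected from maxThreadsPerBlock(). A small host-only sketch of that arithmetic (the sizes are made-up examples, not values from the code):

 #include <cstdio>

 int main() {
     const unsigned int BLOCK_SIZE = 1024;  // e.g. CUDA_BLOCK_SIZE_1024
     int total = 2 * 8 * 16 * 128;          // hypothetical batch * heads * seq * hidden
     // Ceiling division: enough blocks that num_blocks * BLOCK_SIZE >= total,
     // matching "int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;" above.
     int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
     std::printf("%d elements -> %d blocks of %u threads\n", total, num_blocks, BLOCK_SIZE);
     return 0;
 }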
test/infinicore/ops/kv_caching.py (+26 -24)

...
@@ -37,6 +37,7 @@ _TOLERANCE_MAP = {
 # Data types to test
 _TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
+_PAST_LEN_DTYPES = [infinicore.int32, infinicore.int64]


 def parse_test_cases():
...
@@ -64,31 +65,32 @@ def parse_test_cases():
         cache_spec = TensorSpec.from_tensor(cache_shape, strides, dtype)
         kv_spec = TensorSpec.from_tensor(kv_shape, None, dtype)
-        past_kv_lengths_spec = TensorSpec.from_tensor(
-            past_shape,
-            None,
-            infinicore.int64,
-            init_mode=TensorInitializer.RANDINT,
-            low=past_length,
-            high=past_length + 1,
-        )
-
-        test_cases.append(
-            TestCase(
-                inputs=[
-                    cache_spec,
-                    cache_spec,
-                    kv_spec,
-                    kv_spec,
-                    past_kv_lengths_spec,
-                ],
-                kwargs={},
-                output_spec=None,
-                comparison_target=[0, 1],
-                tolerance=tolerance,
-                description=f"KV Caching",
-            )
-        )
+        for past_len_dtype in _PAST_LEN_DTYPES:
+            past_kv_lengths_spec = TensorSpec.from_tensor(
+                past_shape,
+                None,
+                past_len_dtype,
+                init_mode=TensorInitializer.RANDINT,
+                low=past_length,
+                high=past_length + 1,
+            )
+
+            test_cases.append(
+                TestCase(
+                    inputs=[
+                        cache_spec,
+                        cache_spec,
+                        kv_spec,
+                        kv_spec,
+                        past_kv_lengths_spec,
+                    ],
+                    kwargs={},
+                    output_spec=None,
+                    comparison_target=[0, 1],
+                    tolerance=tolerance,
+                    description=f"KV Caching",
+                )
+            )

     return test_cases
...