Commit a0cbae66 authored by wooway777

issue/1035 - support both int32 and int64 in kv caching

parent a9503148
 #ifndef __KV_CACHING_KERNEL_CUH__
 #define __KV_CACHING_KERNEL_CUH__

-template <typename Tdata>
+template <typename Tdata, typename Tidx>
 __device__ void kvCachingKernel(
     Tdata *__restrict__ k_cache,
     Tdata *__restrict__ v_cache,
     const Tdata *__restrict__ k,
     const Tdata *__restrict__ v,
-    const int64_t *__restrict__ past_kv_lengths,
+    const Tidx *__restrict__ past_kv_lengths,
     int batch_size,
     int num_kv_heads,
     int max_seq_len,
@@ -47,7 +47,7 @@ __device__ void kvCachingKernel(
     int h = idx % num_kv_heads;
     int b = idx / num_kv_heads;
-    int past_len = static_cast<int32_t>(past_kv_lengths[b]);
+    int past_len = static_cast<int>(past_kv_lengths[b]); // Cast to int for both types
     // write position
     int cache_s = past_len + s;
     int k_cache_offset = d * (int)k_cache_strides_3 + cache_s * (int)k_cache_strides_2 + h * (int)k_cache_strides_1 + b * (int)k_cache_strides_0;
...
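Note: the kernel change works because the only use of past_kv_lengths is a single per-batch read that is immediately narrowed to int; the value is bounded by max_seq_len, which is itself an int, so the narrowing cannot overflow. A minimal standalone C++ sketch of that pattern (editor's illustration; readPastLen is a hypothetical helper, not from the repo):

#include <cstdint>
#include <cstdio>

// Templated on the index type, mirroring the kernel's read of past_kv_lengths.
template <typename Tidx>
int readPastLen(const Tidx *past_kv_lengths, int b) {
    // Safe narrowing: cached lengths are bounded by max_seq_len, an int.
    return static_cast<int>(past_kv_lengths[b]);
}

int main() {
    const int32_t lens32[] = {0, 7, 128};
    const int64_t lens64[] = {0, 7, 128};
    // Same call site works for both index widths; prints "128 128".
    std::printf("%d %d\n", readPastLen(lens32, 2), readPastLen(lens64, 2));
    return 0;
}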
@@ -13,6 +13,7 @@ private:
 public:
     infiniDtype_t dtype;
+    infiniDtype_t past_len_dtype;
     size_t batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim;
     ptrdiff_t k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3;
     ptrdiff_t v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3;
@@ -32,7 +33,8 @@ public:
     const infiniDtype_t dtype = k_cache->dtype();
     CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
-    CHECK_DTYPE(past_kv_lengths->dtype(), INFINI_DTYPE_I64);
+    const infiniDtype_t past_len_dtype = past_kv_lengths->dtype();
+    CHECK_DTYPE(past_len_dtype, INFINI_DTYPE_I32, INFINI_DTYPE_I64);
     CHECK_OR_RETURN(k_cache->ndim() == 4
                     && v_cache->ndim() == 4
@@ -78,6 +80,7 @@ public:
     return utils::Result<KVCachingInfo>(KVCachingInfo{
         dtype,
+        past_len_dtype,
         batch_size,
         num_kv_heads,
         max_seq_len,
...
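Note: the index dtype is validated once at descriptor-creation time and recorded in the info struct, so the launch path can branch on it without re-inspecting tensor descriptors. A minimal illustration of that validate-then-record pattern (editor's sketch; Dtype, Info, and makeInfo are hypothetical stand-ins, not the library's API):

#include <cstdint>

// Hypothetical stand-ins for infiniDtype_t and KVCachingInfo.
enum Dtype { DT_I32, DT_I64, DT_F16 };

struct Info {
    Dtype past_len_dtype; // recorded at create time, branched on at launch time
};

bool makeInfo(Dtype past_len_dtype, Info &out) {
    // Accept either index width, reject everything else (mirrors CHECK_DTYPE).
    if (past_len_dtype != DT_I32 && past_len_dtype != DT_I64) {
        return false;
    }
    out.past_len_dtype = past_len_dtype;
    return true;
}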
@@ -8,13 +8,13 @@
 #include "../cuda/kernel.cuh"

-template <typename Tdata>
+template <typename Tdata, typename Tidx>
 INFINIOP_METAX_KERNEL kvCaching(
     Tdata *k_cache,
     Tdata *v_cache,
     const Tdata *k,
     const Tdata *v,
-    const int64_t *past_kv_lengths,
+    const Tidx *past_kv_lengths,
     int batch_size,
     int num_kv_heads,
     int max_seq_len,
@@ -36,12 +36,12 @@ INFINIOP_METAX_KERNEL kvCaching(
     ptrdiff_t v_strides_1,
     ptrdiff_t v_strides_2,
     ptrdiff_t v_strides_3) {
-    kvCachingKernel<Tdata>(k_cache, v_cache, k, v, past_kv_lengths,
+    kvCachingKernel<Tdata, Tidx>(k_cache, v_cache, k, v, past_kv_lengths,
                            batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
                            k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
                            v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3,
                            k_strides_0, k_strides_1, k_strides_2, k_strides_3,
                            v_strides_0, v_strides_1, v_strides_2, v_strides_3);
 }

 namespace op::kv_caching::metax {
@@ -71,13 +71,13 @@ infiniStatus_t Descriptor::create(
     return INFINI_STATUS_SUCCESS;
 }

-template <unsigned int BLOCK_SIZE, typename Tdata>
+template <unsigned int BLOCK_SIZE, typename Tdata, typename Tidx>
 infiniStatus_t launchKernel(const KVCachingInfo &info,
                             Tdata *k_cache,
                             Tdata *v_cache,
                             const Tdata *k,
                             const Tdata *v,
-                            const int64_t *past_kv_lengths,
+                            const Tidx *past_kv_lengths,
                             hcStream_t stream, void *workspace) {
     int batch_size = static_cast<int>(info.batch_size);
@@ -111,7 +111,7 @@ infiniStatus_t launchKernel(const KVCachingInfo &info,
     int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
-    kvCaching<Tdata>
+    kvCaching<Tdata, Tidx>
         <<<num_blocks, BLOCK_SIZE, 0, stream>>>(k_cache, v_cache, k, v, past_kv_lengths,
                                                 batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
                                                 k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
@@ -129,28 +129,41 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
                                      const void *past_kv_lengths,
                                      void *stream_) const {
     hcStream_t stream = (hcStream_t)stream_;

-#define CALCULATE_KV_CACHING(BLOCK_SIZE, TDATA) \
-    launchKernel<BLOCK_SIZE, TDATA>(_info, (TDATA *)k_cache, (TDATA *)v_cache, (const TDATA *)k, (const TDATA *)v, (const int64_t *)past_kv_lengths, stream, workspace)
-
-#define CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(BLOCK_SIZE) \
-    { \
-        if (_info.dtype == INFINI_DTYPE_F16) \
-            return CALCULATE_KV_CACHING(BLOCK_SIZE, half); \
-        else if (_info.dtype == INFINI_DTYPE_F32) \
-            return CALCULATE_KV_CACHING(BLOCK_SIZE, float); \
-        else if (_info.dtype == INFINI_DTYPE_BF16) \
-            return CALCULATE_KV_CACHING(BLOCK_SIZE, __hpcc_bfloat16); \
-        else \
-            return INFINI_STATUS_BAD_TENSOR_DTYPE; \
-    }
+#define LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DTYPE(BLOCK_SIZE, TDATA, TIDX) \
+    launchKernel<BLOCK_SIZE, TDATA, TIDX>(_info, (TDATA *)k_cache, (TDATA *)v_cache, \
+                                          (const TDATA *)k, (const TDATA *)v, \
+                                          (const TIDX *)past_kv_lengths, stream, workspace)
+
+#define LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, TDATA) \
+    if (_info.past_len_dtype == INFINI_DTYPE_I32) { \
+        return LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DTYPE(BLOCK_SIZE, TDATA, int32_t); \
+    } else { /* INFINI_DTYPE_I64 */ \
+        return LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DTYPE(BLOCK_SIZE, TDATA, int64_t); \
+    }
+
+#define LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(BLOCK_SIZE) \
+    { \
+        if (_info.dtype == INFINI_DTYPE_F16) { \
+            LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, half) \
+        } else if (_info.dtype == INFINI_DTYPE_F32) { \
+            LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, float) \
+        } else if (_info.dtype == INFINI_DTYPE_BF16) { \
+            LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, __hpcc_bfloat16) \
+        } else { \
+            return INFINI_STATUS_BAD_TENSOR_DTYPE; \
+        } \
+    }
+
+    // Choose block size based on device capabilities
     if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
-        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_1024)
+        LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(METAX_BLOCK_SIZE_1024)
     } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) {
-        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_512)
+        LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(METAX_BLOCK_SIZE_512)
     } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_2048) {
-        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_2048)
+        LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(METAX_BLOCK_SIZE_2048)
     } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_4096) {
-        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_4096)
+        LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(METAX_BLOCK_SIZE_4096)
     } else {
         return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
     }
...
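Note: the macros above implement a two-level runtime-to-compile-time dispatch: the outer macro branches on the data dtype, the inner one on the index dtype, so each (TDATA, TIDX) pair resolves to a distinct template instantiation of launchKernel. The same shape is easier to see without macros (editor's sketch; all names here are illustrative, not the library's):

#include <cstdint>
#include <cstdio>

enum Dtype { DT_F16, DT_F32, DT_I32, DT_I64 };

// One template instantiation per (data type, index type) pair.
template <typename Tdata, typename Tidx>
int launch() {
    std::printf("launch<%zu-byte data, %zu-byte index>\n", sizeof(Tdata), sizeof(Tidx));
    return 0;
}

// Inner level: pick the index type.
template <typename Tdata>
int dispatchIdx(Dtype idx_dtype) {
    return idx_dtype == DT_I32 ? launch<Tdata, int32_t>() : launch<Tdata, int64_t>();
}

// Outer level: pick the data type.
int dispatch(Dtype data_dtype, Dtype idx_dtype) {
    switch (data_dtype) {
    case DT_F32: return dispatchIdx<float>(idx_dtype);
    case DT_F16: return dispatchIdx<uint16_t>(idx_dtype); // uint16_t stands in for half
    default:     return -1; // bad tensor dtype
    }
}

int main() { return dispatch(DT_F32, DT_I32); }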
@@ -8,13 +8,13 @@
 #include "../cuda/kernel.cuh"

-template <typename Tdata>
+template <typename Tdata, typename Tidx>
 INFINIOP_CUDA_KERNEL kvCaching(
     Tdata *k_cache,
     Tdata *v_cache,
     const Tdata *k,
     const Tdata *v,
-    const int64_t *past_kv_lengths,
+    const Tidx *past_kv_lengths,
     int batch_size,
     int num_kv_heads,
     int max_seq_len,
@@ -36,12 +36,12 @@ INFINIOP_CUDA_KERNEL kvCaching(
     ptrdiff_t v_strides_1,
     ptrdiff_t v_strides_2,
     ptrdiff_t v_strides_3) {
-    kvCachingKernel<Tdata>(k_cache, v_cache, k, v, past_kv_lengths,
+    kvCachingKernel<Tdata, Tidx>(k_cache, v_cache, k, v, past_kv_lengths,
                            batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
                            k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
                            v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3,
                            k_strides_0, k_strides_1, k_strides_2, k_strides_3,
                            v_strides_0, v_strides_1, v_strides_2, v_strides_3);
 }

 namespace op::kv_caching::nvidia {
@@ -71,13 +71,13 @@ infiniStatus_t Descriptor::create(
     return INFINI_STATUS_SUCCESS;
 }

-template <unsigned int BLOCK_SIZE, typename Tdata>
+template <unsigned int BLOCK_SIZE, typename Tdata, typename Tidx>
 infiniStatus_t launchKernel(const KVCachingInfo &info,
                             Tdata *k_cache,
                             Tdata *v_cache,
                             const Tdata *k,
                             const Tdata *v,
-                            const int64_t *past_kv_lengths,
+                            const Tidx *past_kv_lengths,
                             cudaStream_t stream, void *workspace) {
     int batch_size = static_cast<int>(info.batch_size);
@@ -111,7 +111,7 @@ infiniStatus_t launchKernel(const KVCachingInfo &info,
     int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
-    kvCaching<Tdata>
+    kvCaching<Tdata, Tidx>
         <<<num_blocks, BLOCK_SIZE, 0, stream>>>(k_cache, v_cache, k, v, past_kv_lengths,
                                                 batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
                                                 k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
@@ -129,27 +129,40 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
                                      const void *past_kv_lengths,
                                      void *stream_) const {
     cudaStream_t stream = (cudaStream_t)stream_;

-#define CALCULATE_KV_CACHING(BLOCK_SIZE, TDATA) \
-    launchKernel<BLOCK_SIZE, TDATA>(_info, (TDATA *)k_cache, (TDATA *)v_cache, (const TDATA *)k, (const TDATA *)v, (const int64_t *)past_kv_lengths, stream, workspace)
-
-#define CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(BLOCK_SIZE) \
-    { \
-        if (_info.dtype == INFINI_DTYPE_F16) \
-            return CALCULATE_KV_CACHING(BLOCK_SIZE, half); \
-        else if (_info.dtype == INFINI_DTYPE_F32) \
-            return CALCULATE_KV_CACHING(BLOCK_SIZE, float); \
-        else if (_info.dtype == INFINI_DTYPE_BF16) \
-            return CALCULATE_KV_CACHING(BLOCK_SIZE, __nv_bfloat16); \
-        else \
-            return INFINI_STATUS_BAD_TENSOR_DTYPE; \
-    }
+#define LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DTYPE(BLOCK_SIZE, TDATA, TIDX) \
+    launchKernel<BLOCK_SIZE, TDATA, TIDX>(_info, (TDATA *)k_cache, (TDATA *)v_cache, \
+                                          (const TDATA *)k, (const TDATA *)v, \
+                                          (const TIDX *)past_kv_lengths, stream, workspace)
+
+#define LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, TDATA) \
+    if (_info.past_len_dtype == INFINI_DTYPE_I32) { \
+        return LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DTYPE(BLOCK_SIZE, TDATA, int32_t); \
+    } else { \
+        return LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DTYPE(BLOCK_SIZE, TDATA, int64_t); \
+    }
+
+#define LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(BLOCK_SIZE) \
+    { \
+        if (_info.dtype == INFINI_DTYPE_F16) { \
+            LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, half) \
+        } else if (_info.dtype == INFINI_DTYPE_F32) { \
+            LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, float) \
+        } else if (_info.dtype == INFINI_DTYPE_BF16) { \
+            LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, __nv_bfloat16) \
+        } else { \
+            return INFINI_STATUS_BAD_TENSOR_DTYPE; \
+        } \
+    }
+
     if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
-        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_1024)
+        LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(CUDA_BLOCK_SIZE_1024)
     } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
-        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_512)
+        LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(CUDA_BLOCK_SIZE_512)
     } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_2048) {
-        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_2048)
+        LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(CUDA_BLOCK_SIZE_2048)
     } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
-        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_4096)
+        LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(CUDA_BLOCK_SIZE_4096)
     } else {
         return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
     }
...
@@ -37,6 +37,7 @@ _TOLERANCE_MAP = {
 # Data types to test
 _TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
+_PAST_LEN_DTYPES = [infinicore.int32, infinicore.int64]

 def parse_test_cases():
@@ -64,31 +65,32 @@ def parse_test_cases():
             cache_spec = TensorSpec.from_tensor(cache_shape, strides, dtype)
             kv_spec = TensorSpec.from_tensor(kv_shape, None, dtype)

-            past_kv_lengths_spec = TensorSpec.from_tensor(
-                past_shape,
-                None,
-                infinicore.int64,
-                init_mode=TensorInitializer.RANDINT,
-                low=past_length,
-                high=past_length + 1,
-            )
-
-            test_cases.append(
-                TestCase(
-                    inputs=[
-                        cache_spec,
-                        cache_spec,
-                        kv_spec,
-                        kv_spec,
-                        past_kv_lengths_spec,
-                    ],
-                    kwargs={},
-                    output_spec=None,
-                    comparison_target=[0, 1],
-                    tolerance=tolerance,
-                    description=f"KV Caching",
-                )
-            )
+            for past_len_dtype in _PAST_LEN_DTYPES:
+                past_kv_lengths_spec = TensorSpec.from_tensor(
+                    past_shape,
+                    None,
+                    past_len_dtype,
+                    init_mode=TensorInitializer.RANDINT,
+                    low=past_length,
+                    high=past_length + 1,
+                )
+
+                test_cases.append(
+                    TestCase(
+                        inputs=[
+                            cache_spec,
+                            cache_spec,
+                            kv_spec,
+                            kv_spec,
+                            past_kv_lengths_spec,
+                        ],
+                        kwargs={},
+                        output_spec=None,
+                        comparison_target=[0, 1],
+                        tolerance=tolerance,
+                        description=f"KV Caching",
+                    )
+                )

     return test_cases
...