Commit a0cbae66 authored by wooway777

issue/1035 - support both int32 and int64 in kv caching

parent a9503148
#ifndef __KV_CACHING_KERNEL_CUH__
#define __KV_CACHING_KERNEL_CUH__
-template <typename Tdata>
+template <typename Tdata, typename Tidx>
__device__ void kvCachingKernel(
Tdata *__restrict__ k_cache,
Tdata *__restrict__ v_cache,
const Tdata *__restrict__ k,
const Tdata *__restrict__ v,
-const int64_t *__restrict__ past_kv_lengths,
+const Tidx *__restrict__ past_kv_lengths,
int batch_size,
int num_kv_heads,
int max_seq_len,
@@ -47,7 +47,7 @@ __device__ void kvCachingKernel(
int h = idx % num_kv_heads;
int b = idx / num_kv_heads;
-int past_len = static_cast<int32_t>(past_kv_lengths[b]);
+int past_len = static_cast<int>(past_kv_lengths[b]); // Cast to int for both types
// write position
int cache_s = past_len + s;
int k_cache_offset = d * (int)k_cache_strides_3 + cache_s * (int)k_cache_strides_2 + h * (int)k_cache_strides_1 + b * (int)k_cache_strides_0;
......
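The change above is the core of the patch: the device kernel is templated on an index type Tidx in addition to Tdata, and the per-batch past length is narrowed to int once, so the same body serves both int32 and int64 length buffers. A minimal self-contained sketch of the pattern (the appendOneToken helper and its simplified strides are hypothetical, not the project's actual kernel):

#include <cstdint>

// Sketch only: a device helper templated on both the element type and the
// index type, mirroring the kvCachingKernel signature change above.
template <typename Tdata, typename Tidx>
__device__ void appendOneToken(Tdata *cache, const Tdata *src,
                               const Tidx *past_lengths,
                               int b, int s, int seq_len,
                               int cache_stride_b, int cache_stride_s) {
    // Tidx may be int32_t or int64_t; narrow once, then index with plain int.
    int past_len = static_cast<int>(past_lengths[b]);
    cache[b * cache_stride_b + (past_len + s) * cache_stride_s] = src[b * seq_len + s];
}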
@@ -13,6 +13,7 @@ private:
public:
infiniDtype_t dtype;
+infiniDtype_t past_len_dtype;
size_t batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim;
ptrdiff_t k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3;
ptrdiff_t v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3;
@@ -32,7 +33,8 @@ public:
const infiniDtype_t dtype = k_cache->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
-CHECK_DTYPE(past_kv_lengths->dtype(), INFINI_DTYPE_I64);
+const infiniDtype_t past_len_dtype = past_kv_lengths->dtype();
+CHECK_DTYPE(past_len_dtype, INFINI_DTYPE_I32, INFINI_DTYPE_I64);
CHECK_OR_RETURN(k_cache->ndim() == 4
&& v_cache->ndim() == 4
@@ -78,6 +80,7 @@ public:
return utils::Result<KVCachingInfo>(KVCachingInfo{
dtype,
+past_len_dtype,
batch_size,
num_kv_heads,
max_seq_len,
......
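create() now records the validated length dtype in KVCachingInfo::past_len_dtype instead of hard-rejecting everything but I64, and the launch path branches on it later. The same branch can also be written without macros; a hedged sketch using a hypothetical dispatchByIndexType helper (not part of this patch):

// Hypothetical alternative to the macro dispatch below: resolve the stored
// dtype to a concrete pointer type once, then invoke a generic callable.
template <typename F>
infiniStatus_t dispatchByIndexType(infiniDtype_t past_len_dtype,
                                   const void *past_kv_lengths, F &&launch) {
    switch (past_len_dtype) {
    case INFINI_DTYPE_I32:
        return launch(static_cast<const int32_t *>(past_kv_lengths));
    case INFINI_DTYPE_I64:
        return launch(static_cast<const int64_t *>(past_kv_lengths));
    default: // unreachable: create() accepts only I32 and I64
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}

A generic lambda at the call site would then let launchKernel deduce Tidx from the pointer type, with no preprocessor involvement.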
@@ -8,13 +8,13 @@
#include "../cuda/kernel.cuh"
-template <typename Tdata>
+template <typename Tdata, typename Tidx>
INFINIOP_METAX_KERNEL kvCaching(
Tdata *k_cache,
Tdata *v_cache,
const Tdata *k,
const Tdata *v,
-const int64_t *past_kv_lengths,
+const Tidx *past_kv_lengths,
int batch_size,
int num_kv_heads,
int max_seq_len,
@@ -36,7 +36,7 @@ INFINIOP_METAX_KERNEL kvCaching(
ptrdiff_t v_strides_1,
ptrdiff_t v_strides_2,
ptrdiff_t v_strides_3) {
-kvCachingKernel<Tdata>(k_cache, v_cache, k, v, past_kv_lengths,
+kvCachingKernel<Tdata, Tidx>(k_cache, v_cache, k, v, past_kv_lengths,
batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3,
@@ -71,13 +71,13 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_SUCCESS;
}
-template <unsigned int BLOCK_SIZE, typename Tdata>
+template <unsigned int BLOCK_SIZE, typename Tdata, typename Tidx>
infiniStatus_t launchKernel(const KVCachingInfo &info,
Tdata *k_cache,
Tdata *v_cache,
const Tdata *k,
const Tdata *v,
-const int64_t *past_kv_lengths,
+const Tidx *past_kv_lengths,
hcStream_t stream, void *workspace) {
int batch_size = static_cast<int>(info.batch_size);
@@ -111,7 +111,7 @@ infiniStatus_t launchKernel(const KVCachingInfo &info,
int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
-kvCaching<Tdata>
+kvCaching<Tdata, Tidx>
<<<num_blocks, BLOCK_SIZE, 0, stream>>>(k_cache, v_cache, k, v, past_kv_lengths,
batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
@@ -129,28 +129,41 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
const void *past_kv_lengths,
void *stream_) const {
hcStream_t stream = (hcStream_t)stream_;
-#define CALCULATE_KV_CACHING(BLOCK_SIZE, TDATA) \
-launchKernel<BLOCK_SIZE, TDATA>(_info, (TDATA *)k_cache, (TDATA *)v_cache, (const TDATA *)k, (const TDATA *)v, (const int64_t *)past_kv_lengths, stream, workspace)
-#define CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(BLOCK_SIZE) \
+#define LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DTYPE(BLOCK_SIZE, TDATA, TIDX) \
+launchKernel<BLOCK_SIZE, TDATA, TIDX>(_info, (TDATA *)k_cache, (TDATA *)v_cache, \
+(const TDATA *)k, (const TDATA *)v, \
+(const TIDX *)past_kv_lengths, stream, workspace)
+#define LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, TDATA) \
+if (_info.past_len_dtype == INFINI_DTYPE_I32) { \
+return LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DTYPE(BLOCK_SIZE, TDATA, int32_t); \
+} else { /* INFINI_DTYPE_I64 */ \
+return LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DTYPE(BLOCK_SIZE, TDATA, int64_t); \
+}
+#define LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(BLOCK_SIZE) \
{ \
-if (_info.dtype == INFINI_DTYPE_F16) \
-return CALCULATE_KV_CACHING(BLOCK_SIZE, half); \
-else if (_info.dtype == INFINI_DTYPE_F32) \
-return CALCULATE_KV_CACHING(BLOCK_SIZE, float); \
-else if (_info.dtype == INFINI_DTYPE_BF16) \
-return CALCULATE_KV_CACHING(BLOCK_SIZE, __hpcc_bfloat16); \
-else \
+if (_info.dtype == INFINI_DTYPE_F16) { \
+LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, half) \
+} else if (_info.dtype == INFINI_DTYPE_F32) { \
+LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, float) \
+} else if (_info.dtype == INFINI_DTYPE_BF16) { \
+LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, __hpcc_bfloat16) \
+} else { \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
} \
}
// Choose block size based on device capabilities
if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
-CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_1024)
+LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(METAX_BLOCK_SIZE_1024)
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) {
-CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_512)
+LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(METAX_BLOCK_SIZE_512)
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_2048) {
-CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_2048)
+LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(METAX_BLOCK_SIZE_2048)
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_4096) {
-CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_4096)
+LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(METAX_BLOCK_SIZE_4096)
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
......
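To make the macro tower above concrete, here is approximately what LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(METAX_BLOCK_SIZE_1024) generates for the F16 branch; an illustrative expansion, not literal preprocessor output:

// Approximate expansion for _info.dtype == INFINI_DTYPE_F16:
if (_info.past_len_dtype == INFINI_DTYPE_I32) {
    return launchKernel<METAX_BLOCK_SIZE_1024, half, int32_t>(
        _info, (half *)k_cache, (half *)v_cache,
        (const half *)k, (const half *)v,
        (const int32_t *)past_kv_lengths, stream, workspace);
} else { // INFINI_DTYPE_I64
    return launchKernel<METAX_BLOCK_SIZE_1024, half, int64_t>(
        _info, (half *)k_cache, (half *)v_cache,
        (const half *)k, (const half *)v,
        (const int64_t *)past_kv_lengths, stream, workspace);
}

Each (block size, Tdata, Tidx) combination is a separate template instantiation, so int32 support doubles the number of compiled kernel variants but adds no runtime cost inside the kernel itself.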
@@ -8,13 +8,13 @@
#include "../cuda/kernel.cuh"
-template <typename Tdata>
+template <typename Tdata, typename Tidx>
INFINIOP_CUDA_KERNEL kvCaching(
Tdata *k_cache,
Tdata *v_cache,
const Tdata *k,
const Tdata *v,
-const int64_t *past_kv_lengths,
+const Tidx *past_kv_lengths,
int batch_size,
int num_kv_heads,
int max_seq_len,
@@ -36,7 +36,7 @@ INFINIOP_CUDA_KERNEL kvCaching(
ptrdiff_t v_strides_1,
ptrdiff_t v_strides_2,
ptrdiff_t v_strides_3) {
-kvCachingKernel<Tdata>(k_cache, v_cache, k, v, past_kv_lengths,
+kvCachingKernel<Tdata, Tidx>(k_cache, v_cache, k, v, past_kv_lengths,
batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3,
@@ -71,13 +71,13 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_SUCCESS;
}
-template <unsigned int BLOCK_SIZE, typename Tdata>
+template <unsigned int BLOCK_SIZE, typename Tdata, typename Tidx>
infiniStatus_t launchKernel(const KVCachingInfo &info,
Tdata *k_cache,
Tdata *v_cache,
const Tdata *k,
const Tdata *v,
-const int64_t *past_kv_lengths,
+const Tidx *past_kv_lengths,
cudaStream_t stream, void *workspace) {
int batch_size = static_cast<int>(info.batch_size);
@@ -111,7 +111,7 @@ infiniStatus_t launchKernel(const KVCachingInfo &info,
int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
-kvCaching<Tdata>
+kvCaching<Tdata, Tidx>
<<<num_blocks, BLOCK_SIZE, 0, stream>>>(k_cache, v_cache, k, v, past_kv_lengths,
batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
@@ -129,27 +129,40 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
const void *past_kv_lengths,
void *stream_) const {
cudaStream_t stream = (cudaStream_t)stream_;
-#define CALCULATE_KV_CACHING(BLOCK_SIZE, TDATA) \
-launchKernel<BLOCK_SIZE, TDATA>(_info, (TDATA *)k_cache, (TDATA *)v_cache, (const TDATA *)k, (const TDATA *)v, (const int64_t *)past_kv_lengths, stream, workspace)
-#define CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(BLOCK_SIZE) \
+#define LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DTYPE(BLOCK_SIZE, TDATA, TIDX) \
+launchKernel<BLOCK_SIZE, TDATA, TIDX>(_info, (TDATA *)k_cache, (TDATA *)v_cache, \
+(const TDATA *)k, (const TDATA *)v, \
+(const TIDX *)past_kv_lengths, stream, workspace)
+#define LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, TDATA) \
+if (_info.past_len_dtype == INFINI_DTYPE_I32) { \
+return LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DTYPE(BLOCK_SIZE, TDATA, int32_t); \
+} else { /* INFINI_DTYPE_I64 */ \
+return LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DTYPE(BLOCK_SIZE, TDATA, int64_t); \
+}
+#define LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(BLOCK_SIZE) \
{ \
-if (_info.dtype == INFINI_DTYPE_F16) \
-return CALCULATE_KV_CACHING(BLOCK_SIZE, half); \
-else if (_info.dtype == INFINI_DTYPE_F32) \
-return CALCULATE_KV_CACHING(BLOCK_SIZE, float); \
-else if (_info.dtype == INFINI_DTYPE_BF16) \
-return CALCULATE_KV_CACHING(BLOCK_SIZE, __nv_bfloat16); \
-else \
+if (_info.dtype == INFINI_DTYPE_F16) { \
+LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, half) \
+} else if (_info.dtype == INFINI_DTYPE_F32) { \
+LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, float) \
+} else if (_info.dtype == INFINI_DTYPE_BF16) { \
+LAUNCH_KERNEL_WITH_BLOCK_SIZE(BLOCK_SIZE, __nv_bfloat16) \
+} else { \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
} \
}
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
-CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_1024)
+LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(CUDA_BLOCK_SIZE_1024)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
-CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_512)
+LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(CUDA_BLOCK_SIZE_512)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_2048) {
-CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_2048)
+LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(CUDA_BLOCK_SIZE_2048)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
-CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_4096)
+LAUNCH_KERNEL_WITH_BLOCK_SIZE_AND_DATA_TYPE(CUDA_BLOCK_SIZE_4096)
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
......
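For callers, the only visible change is that past_kv_lengths may now be an int32 tensor. A hedged host-side sketch of preparing such a buffer with the CUDA runtime API (hypothetical sizes; descriptor creation and the calculate call are elided):

#include <cstdint>
#include <cuda_runtime.h>

int main() {
    // Per-batch past lengths, now allowed to be 32-bit.
    int32_t h_past_lengths[2] = {16, 40};
    int32_t *d_past_lengths = nullptr;
    cudaMalloc(&d_past_lengths, sizeof(h_past_lengths));
    cudaMemcpy(d_past_lengths, h_past_lengths, sizeof(h_past_lengths),
               cudaMemcpyHostToDevice);
    // A descriptor created with an INFINI_DTYPE_I32 past_kv_lengths tensor now
    // routes to the int32_t instantiation of kvCaching at launch time.
    cudaFree(d_past_lengths);
    return 0;
}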
@@ -37,6 +37,7 @@ _TOLERANCE_MAP = {
# Data types to test
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
+_PAST_LEN_DTYPES = [infinicore.int32, infinicore.int64]
def parse_test_cases():
@@ -64,10 +65,11 @@ def parse_test_cases():
cache_spec = TensorSpec.from_tensor(cache_shape, strides, dtype)
kv_spec = TensorSpec.from_tensor(kv_shape, None, dtype)
+for past_len_dtype in _PAST_LEN_DTYPES:
past_kv_lengths_spec = TensorSpec.from_tensor(
past_shape,
None,
-infinicore.int64,
+past_len_dtype,
init_mode=TensorInitializer.RANDINT,
low=past_length,
high=past_length + 1,
......