Commit af394a32 authored by wooway777's avatar wooway777
Browse files

issue/1035 - kv caching on ali, ilu, hygon, qy, metax

parent c70805c9
...@@ -29,14 +29,14 @@ __device__ void kvCachingKernel( ...@@ -29,14 +29,14 @@ __device__ void kvCachingKernel(
ptrdiff_t v_strides_1, ptrdiff_t v_strides_1,
ptrdiff_t v_strides_2, ptrdiff_t v_strides_2,
ptrdiff_t v_strides_3) { ptrdiff_t v_strides_3) {
// 总元素数 = B * H * seq_len * D // num of ele = B * H * seq_len * D
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
int total = batch_size * num_kv_heads * seq_len * hidden_dim; int total = batch_size * num_kv_heads * seq_len * hidden_dim;
const int grid_size = blockDim.x * gridDim.x; const int grid_size = blockDim.x * gridDim.x;
for (int idx = tid; idx < total; idx += grid_size) { for (int idx = tid; idx < total; idx += grid_size) {
// 反解 index // unravel index
int d = idx % hidden_dim; int d = idx % hidden_dim;
idx /= hidden_dim; idx /= hidden_dim;
...@@ -48,7 +48,7 @@ __device__ void kvCachingKernel( ...@@ -48,7 +48,7 @@ __device__ void kvCachingKernel(
int b = idx / num_kv_heads; int b = idx / num_kv_heads;
int past_len = static_cast<int32_t>(past_kv_lengths[b]); int past_len = static_cast<int32_t>(past_kv_lengths[b]);
// 写入位置 // write position
int cache_s = past_len + s; int cache_s = past_len + s;
int k_cache_offset = d * (int)k_cache_strides_3 + cache_s * (int)k_cache_strides_2 + h * (int)k_cache_strides_1 + b * (int)k_cache_strides_0; int k_cache_offset = d * (int)k_cache_strides_3 + cache_s * (int)k_cache_strides_2 + h * (int)k_cache_strides_1 + b * (int)k_cache_strides_0;
int v_cache_offset = d * (int)v_cache_strides_3 + cache_s * (int)v_cache_strides_2 + h * (int)v_cache_strides_1 + b * (int)v_cache_strides_0; int v_cache_offset = d * (int)v_cache_strides_3 + cache_s * (int)v_cache_strides_2 + h * (int)v_cache_strides_1 + b * (int)v_cache_strides_0;
......
...@@ -32,6 +32,7 @@ public: ...@@ -32,6 +32,7 @@ public:
const infiniDtype_t dtype = k_cache->dtype(); const infiniDtype_t dtype = k_cache->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
CHECK_DTYPE(past_kv_lengths->dtype(), INFINI_DTYPE_I64);
CHECK_OR_RETURN(k_cache->ndim() == 4 CHECK_OR_RETURN(k_cache->ndim() == 4
&& v_cache->ndim() == 4 && v_cache->ndim() == 4
......
#ifndef __KV_CACHING_METAX_API_H__
#define __KV_CACHING_METAX_API_H__
// Public API header for the Metax backend of the KV-caching operator.
#include "../kv_caching.h"
// Expands to the op::kv_caching::metax::Descriptor class declaration
// (see the DESCRIPTOR macro in ../kv_caching.h).
DESCRIPTOR(metax)
#endif // __KV_CACHING_METAX_API_H__
#include "../../../devices/metax/metax_common.h"
#include "kv_caching_metax.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
// 1-D grid entry point for KV caching on Metax: forwards all arguments to
// the shared device helper kvCachingKernel (../cuda/kernel.cuh), which walks
// the batch*num_kv_heads*seq_len*hidden_dim elements with a grid-stride loop
// and copies each element of k/v into k_cache/v_cache at sequence offset
// past_kv_lengths[b] + s.
// Strides are in elements, ordered [batch, head, seq, dim] (indices 0..3).
// NOTE(review): no bounds check against max_seq_len is visible in the helper;
// callers presumably guarantee past_kv_lengths[b] + seq_len <= max_seq_len.
template <typename Tdata>
INFINIOP_METAX_KERNEL kvCaching(
Tdata *k_cache,
Tdata *v_cache,
const Tdata *k,
const Tdata *v,
const int64_t *past_kv_lengths, // per-batch count of tokens already in the cache (device memory, int64)
int batch_size,
int num_kv_heads,
int max_seq_len,
int seq_len,
int hidden_dim,
ptrdiff_t k_cache_strides_0,
ptrdiff_t k_cache_strides_1,
ptrdiff_t k_cache_strides_2,
ptrdiff_t k_cache_strides_3,
ptrdiff_t v_cache_strides_0,
ptrdiff_t v_cache_strides_1,
ptrdiff_t v_cache_strides_2,
ptrdiff_t v_cache_strides_3,
ptrdiff_t k_strides_0,
ptrdiff_t k_strides_1,
ptrdiff_t k_strides_2,
ptrdiff_t k_strides_3,
ptrdiff_t v_strides_0,
ptrdiff_t v_strides_1,
ptrdiff_t v_strides_2,
ptrdiff_t v_strides_3) {
kvCachingKernel<Tdata>(k_cache, v_cache, k, v, past_kv_lengths,
batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3,
k_strides_0, k_strides_1, k_strides_2, k_strides_3,
v_strides_0, v_strides_1, v_strides_2, v_strides_3);
}
namespace op::kv_caching::metax {
// Device-specific state owned by the descriptor: shared access to the Metax
// handle internals (used in calculate() to query maxThreadsPerBlock()).
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
// Releases the Opaque state allocated in Descriptor::create().
Descriptor::~Descriptor() {
delete _opaque;
}
// Validates the five tensor descriptors and, on success, stores a freshly
// allocated Metax KV-caching descriptor in *desc_ptr.
// The descriptor requires no workspace (size 0).
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t k_cache,
    infiniopTensorDescriptor_t v_cache,
    infiniopTensorDescriptor_t k,
    infiniopTensorDescriptor_t v,
    infiniopTensorDescriptor_t past_kv_lengths) {
    auto result = KVCachingInfo::createKVCachingInfo(k_cache, v_cache, k, v, past_kv_lengths);
    CHECK_RESULT(result);
    auto *metax_handle = reinterpret_cast<device::metax::Handle *>(handle);
    auto *opaque = new Opaque{metax_handle->internal()};
    *desc_ptr = new Descriptor(opaque, result.take(), 0, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Launches the KV-caching kernel on `stream` with a 1-D grid of BLOCK_SIZE
// threads per block, sized to cover batch*heads*seq*dim elements.
// `workspace` is accepted for interface symmetry but unused (workspace size
// is 0 for this operator).
template <unsigned int BLOCK_SIZE, typename Tdata>
infiniStatus_t launchKernel(const KVCachingInfo &info,
                            Tdata *k_cache,
                            Tdata *v_cache,
                            const Tdata *k,
                            const Tdata *v,
                            const int64_t *past_kv_lengths,
                            hcStream_t stream, void *workspace) {
    int batch_size = static_cast<int>(info.batch_size);
    int num_kv_heads = static_cast<int>(info.num_kv_heads);
    int max_seq_len = static_cast<int>(info.max_seq_len);
    int hidden_dim = static_cast<int>(info.hidden_dim);
    int seq_len = static_cast<int>(info.seq_len);
    // NOTE(review): the kernel indexes with int, so the element count is
    // assumed to fit in 32 bits (B*H*S*D < 2^31) — confirm upstream
    // shape validation enforces this.
    int total = batch_size * num_kv_heads * seq_len * hidden_dim;
    if (total == 0) {
        // Empty input: a zero-block grid is an invalid launch configuration,
        // so skip the launch entirely.
        return INFINI_STATUS_SUCCESS;
    }
    // Ceil-division so the tail partial block is still launched.
    int num_blocks = (total + static_cast<int>(BLOCK_SIZE) - 1) / static_cast<int>(BLOCK_SIZE);
    kvCaching<Tdata>
        <<<num_blocks, BLOCK_SIZE, 0, stream>>>(
            k_cache, v_cache, k, v, past_kv_lengths,
            batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
            info.k_cache_strides_0, info.k_cache_strides_1, info.k_cache_strides_2, info.k_cache_strides_3,
            info.v_cache_strides_0, info.v_cache_strides_1, info.v_cache_strides_2, info.v_cache_strides_3,
            info.k_strides_0, info.k_strides_1, info.k_strides_2, info.k_strides_3,
            info.v_strides_0, info.v_strides_1, info.v_strides_2, info.v_strides_3);
    return INFINI_STATUS_SUCCESS;
}
// Copies the incoming k/v tensors into k_cache/v_cache on the given Metax
// stream. `workspace`/`workspace_size` are unused (this operator needs no
// workspace). Dispatches on the runtime dtype (F16/F32/BF16) and on the
// device's maximum threads per block; any other dtype or block-size
// capability is rejected with an error status.
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
void *k_cache,
void *v_cache,
const void *k,
const void *v,
const void *past_kv_lengths,
void *stream_) const {
hcStream_t stream = (hcStream_t)stream_;
#define CALCULATE_KV_CACHING(BLOCK_SIZE, TDATA) \
launchKernel<BLOCK_SIZE, TDATA>(_info, (TDATA *)k_cache, (TDATA *)v_cache, (const TDATA *)k, (const TDATA *)v, (const int64_t *)past_kv_lengths, stream, workspace)
#define CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(BLOCK_SIZE) \
{ \
if (_info.dtype == INFINI_DTYPE_F16) \
return CALCULATE_KV_CACHING(BLOCK_SIZE, half); \
else if (_info.dtype == INFINI_DTYPE_F32) \
return CALCULATE_KV_CACHING(BLOCK_SIZE, float); \
else if (_info.dtype == INFINI_DTYPE_BF16) \
return CALCULATE_KV_CACHING(BLOCK_SIZE, __hpcc_bfloat16); \
else \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
// Select the template block size matching the device capability reported by
// the handle; unrecognized capabilities are treated as unsupported hardware.
if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_1024)
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) {
CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_512)
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_2048) {
CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_2048)
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_4096) {
CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_4096)
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
// Unreachable: every branch above returns, but kept for compiler warnings.
return INFINI_STATUS_SUCCESS;
}
} // namespace op::kv_caching::metax
...@@ -2,15 +2,12 @@ ...@@ -2,15 +2,12 @@
#include "../../handle.h" #include "../../handle.h"
#include "infiniop/ops/kv_caching.h" #include "infiniop/ops/kv_caching.h"
#if defined(ENABLE_NINETOOTHED) #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_ALI_API) || defined(ENABLE_HYGON_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_METAX_API) || defined(ENABLE_MOORE_API)
#include "ninetoothed/kv_caching.h"
#endif
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#include "nvidia/kv_caching_nvidia.cuh" #include "nvidia/kv_caching_nvidia.cuh"
#endif #endif
#if defined(ENABLE_METAX_API)
#include "metax/kv_caching_metax.h"
#endif
__C infiniStatus_t infiniopCreateKVCachingDescriptor( __C infiniStatus_t infiniopCreateKVCachingDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
...@@ -34,24 +31,24 @@ __C infiniStatus_t infiniopCreateKVCachingDescriptor( ...@@ -34,24 +31,24 @@ __C infiniStatus_t infiniopCreateKVCachingDescriptor(
switch (handle->device) { switch (handle->device) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
CREATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#if defined(ENABLE_ILUVATAR_API)
CREATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#endif
#if defined(ENABLE_METAX_API)
CREATE(INFINI_DEVICE_METAX, ninetoothed);
#endif
#endif
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia); CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif #endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia); CREATE(INFINI_DEVICE_QY, nvidia);
#endif #endif
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_HYGON_API
CREATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#if defined(ENABLE_METAX_API)
CREATE(INFINI_DEVICE_METAX, metax);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -72,24 +69,25 @@ __C infiniStatus_t infiniopGetKVCachingWorkspaceSize( ...@@ -72,24 +69,25 @@ __C infiniStatus_t infiniopGetKVCachingWorkspaceSize(
switch (desc->device_type) { switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
GET_SIZE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#if defined(ENABLE_ILUVATAR_API)
GET_SIZE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#endif
#if defined(ENABLE_METAX_API)
GET_SIZE(INFINI_DEVICE_METAX, ninetoothed);
#endif
#endif
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
GET_SIZE(INFINI_DEVICE_NVIDIA, nvidia); GET_SIZE(INFINI_DEVICE_NVIDIA, nvidia);
#endif #endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
GET_SIZE(INFINI_DEVICE_QY, nvidia); GET_SIZE(INFINI_DEVICE_QY, nvidia);
#endif #endif
#ifdef ENABLE_ILUVATAR_API
GET_SIZE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
GET_SIZE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_HYGON_API
GET_SIZE(INFINI_DEVICE_HYGON, nvidia);
#endif
#if defined(ENABLE_METAX_API)
GET_SIZE(INFINI_DEVICE_METAX, metax);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
...@@ -115,24 +113,25 @@ __C infiniStatus_t infiniopKVCaching( ...@@ -115,24 +113,25 @@ __C infiniStatus_t infiniopKVCaching(
switch (desc->device_type) { switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#if defined(ENABLE_ILUVATAR_API)
CALCULATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#endif
#if defined(ENABLE_METAX_API)
CALCULATE(INFINI_DEVICE_METAX, ninetoothed);
#endif
#endif
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif #endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia); CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif #endif
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_HYGON_API
CALCULATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#if defined(ENABLE_METAX_API)
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
...@@ -150,26 +149,28 @@ __C infiniStatus_t infiniopDestroyKVCachingDescriptor( ...@@ -150,26 +149,28 @@ __C infiniStatus_t infiniopDestroyKVCachingDescriptor(
switch (desc->device_type) { switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
DELETE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#if defined(ENABLE_ILUVATAR_API)
DELETE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#endif
#if defined(ENABLE_METAX_API)
DELETE(INFINI_DEVICE_METAX, ninetoothed);
#endif
#endif
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
DELETE(INFINI_DEVICE_NVIDIA, nvidia); DELETE(INFINI_DEVICE_NVIDIA, nvidia);
#endif #endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia); DELETE(INFINI_DEVICE_QY, nvidia);
#endif #endif
#ifdef ENABLE_ILUVATAR_API
DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_HYGON_API
DELETE(INFINI_DEVICE_HYGON, nvidia);
#endif
#if defined(ENABLE_METAX_API)
DELETE(INFINI_DEVICE_METAX, metax);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
#undef DELETE #undef DELETE
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment