Commit c70805c9 authored by xgqdut2016, committed by wooway777

issue/1035: kv caching on nvidia

parent abd45713
#ifndef __KV_CACHING_KERNEL_CUH__
#define __KV_CACHING_KERNEL_CUH__
template <typename Tdata>
__device__ void kvCachingKernel(
Tdata *__restrict__ k_cache,
Tdata *__restrict__ v_cache,
const Tdata *__restrict__ k,
const Tdata *__restrict__ v,
const int64_t *__restrict__ past_kv_lengths,
int batch_size,
int num_kv_heads,
int max_seq_len,
int seq_len,
int hidden_dim,
ptrdiff_t k_cache_strides_0,
ptrdiff_t k_cache_strides_1,
ptrdiff_t k_cache_strides_2,
ptrdiff_t k_cache_strides_3,
ptrdiff_t v_cache_strides_0,
ptrdiff_t v_cache_strides_1,
ptrdiff_t v_cache_strides_2,
ptrdiff_t v_cache_strides_3,
ptrdiff_t k_strides_0,
ptrdiff_t k_strides_1,
ptrdiff_t k_strides_2,
ptrdiff_t k_strides_3,
ptrdiff_t v_strides_0,
ptrdiff_t v_strides_1,
ptrdiff_t v_strides_2,
ptrdiff_t v_strides_3) {
    // Total element count = batch_size * num_kv_heads * seq_len * hidden_dim.
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    int total = batch_size * num_kv_heads * seq_len * hidden_dim;
    const int grid_size = blockDim.x * gridDim.x;
    for (int idx = tid; idx < total; idx += grid_size) {
        // Decompose the flat index into (b, h, s, d). Work on a local copy so
        // the divisions do not corrupt the grid-stride loop variable.
        int rem = idx;
        int d = rem % hidden_dim;
        rem /= hidden_dim;
        int s = rem % seq_len;
        rem /= seq_len;
        int h = rem % num_kv_heads;
        int b = rem / num_kv_heads;
        int past_len = static_cast<int>(past_kv_lengths[b]);
        // Destination row in the cache.
        int cache_s = past_len + s;
        // Keep offsets in ptrdiff_t so large strided tensors cannot overflow int.
        ptrdiff_t k_cache_offset = d * k_cache_strides_3 + cache_s * k_cache_strides_2 + h * k_cache_strides_1 + b * k_cache_strides_0;
        ptrdiff_t v_cache_offset = d * v_cache_strides_3 + cache_s * v_cache_strides_2 + h * v_cache_strides_1 + b * v_cache_strides_0;
        ptrdiff_t k_src_offset = d * k_strides_3 + s * k_strides_2 + h * k_strides_1 + b * k_strides_0;
        ptrdiff_t v_src_offset = d * v_strides_3 + s * v_strides_2 + h * v_strides_1 + b * v_strides_0;
        k_cache[k_cache_offset] = k[k_src_offset];
        v_cache[v_cache_offset] = v[v_src_offset];
}
}
#endif // __KV_CACHING_KERNEL_CUH__
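
The flat index enumerates (b, h, s, d) in row-major order, which is why the kernel peels off hidden_dim first and the batch index last. As a quick sanity check of that arithmetic, here is a minimal Python sketch (illustration only, arbitrary shapes; not part of the patch) that mirrors the same modulo/divide steps and verifies they round-trip:

# Mirror the kernel's flat-index decomposition and verify it round-trips.
batch_size, num_kv_heads, seq_len, hidden_dim = 2, 4, 3, 8
for idx in range(batch_size * num_kv_heads * seq_len * hidden_dim):
    rem = idx
    d = rem % hidden_dim
    rem //= hidden_dim
    s = rem % seq_len
    rem //= seq_len
    h = rem % num_kv_heads
    b = rem // num_kv_heads
    # Row-major recomposition must reproduce the flat index.
    assert idx == ((b * num_kv_heads + h) * seq_len + s) * hidden_dim + d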
#ifndef __KV_CACHING_INFO_H__
#define __KV_CACHING_INFO_H__
#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"
namespace op::kv_caching {
class KVCachingInfo {
private:
KVCachingInfo() = default;
public:
infiniDtype_t dtype;
size_t batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim;
ptrdiff_t k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3;
ptrdiff_t v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3;
ptrdiff_t k_strides_0, k_strides_1, k_strides_2, k_strides_3;
ptrdiff_t v_strides_0, v_strides_1, v_strides_2, v_strides_3;
static utils::Result<KVCachingInfo> createKVCachingInfo(
infiniopTensorDescriptor_t k_cache,
infiniopTensorDescriptor_t v_cache,
infiniopTensorDescriptor_t k,
infiniopTensorDescriptor_t v,
infiniopTensorDescriptor_t past_kv_lengths) {
CHECK_OR_RETURN(
k_cache != nullptr && v_cache != nullptr && k != nullptr && v != nullptr && past_kv_lengths != nullptr,
INFINI_STATUS_NULL_POINTER);
const infiniDtype_t dtype = k_cache->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
CHECK_OR_RETURN(k_cache->ndim() == 4
&& v_cache->ndim() == 4
&& k->ndim() == 4
&& v->ndim() == 4,
INFINI_STATUS_BAD_TENSOR_SHAPE);
auto shape = k_cache->shape();
CHECK_SAME_SHAPE(shape, v_cache->shape());
CHECK_SAME_SHAPE(k->shape(), v->shape());
size_t batch_size = shape[0];
size_t num_kv_heads = shape[1];
size_t max_seq_len = shape[2];
size_t hidden_dim = shape[3];
size_t seq_len = k->shape()[2];
        CHECK_OR_RETURN(batch_size == k->dim(0)
                            && num_kv_heads == k->dim(1)
                            && hidden_dim == k->dim(3),
                        INFINI_STATUS_BAD_TENSOR_SHAPE);
ptrdiff_t k_cache_strides_0 = k_cache->strides()[0];
ptrdiff_t k_cache_strides_1 = k_cache->strides()[1];
ptrdiff_t k_cache_strides_2 = k_cache->strides()[2];
ptrdiff_t k_cache_strides_3 = k_cache->strides()[3];
ptrdiff_t v_cache_strides_0 = v_cache->strides()[0];
ptrdiff_t v_cache_strides_1 = v_cache->strides()[1];
ptrdiff_t v_cache_strides_2 = v_cache->strides()[2];
ptrdiff_t v_cache_strides_3 = v_cache->strides()[3];
ptrdiff_t k_strides_0 = k->strides()[0];
ptrdiff_t k_strides_1 = k->strides()[1];
ptrdiff_t k_strides_2 = k->strides()[2];
ptrdiff_t k_strides_3 = k->strides()[3];
ptrdiff_t v_strides_0 = v->strides()[0];
ptrdiff_t v_strides_1 = v->strides()[1];
ptrdiff_t v_strides_2 = v->strides()[2];
ptrdiff_t v_strides_3 = v->strides()[3];
return utils::Result<KVCachingInfo>(KVCachingInfo{
dtype,
batch_size,
num_kv_heads,
max_seq_len,
seq_len,
hidden_dim,
k_cache_strides_0,
k_cache_strides_1,
k_cache_strides_2,
k_cache_strides_3,
v_cache_strides_0,
v_cache_strides_1,
v_cache_strides_2,
v_cache_strides_3,
k_strides_0,
k_strides_1,
k_strides_2,
k_strides_3,
v_strides_0,
v_strides_1,
v_strides_2,
v_strides_3});
}
};
} // namespace op::kv_caching
#endif // __KV_CACHING_INFO_H__
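
The per-dimension strides recorded here are what let the operator address non-contiguous caches. A short PyTorch sketch (illustration only, assuming standard PyTorch stride semantics; shapes are arbitrary) of how the strides used by the offset formula b*s0 + h*s1 + s*s2 + d*s3 behave, including a cache over-allocated along the sequence axis:

import torch

# For a contiguous [B, H, S, D] tensor, the strides are (H*S*D, S*D, D, 1),
# measured in elements.
B, H, S, D = 2, 4, 8, 16
t = torch.zeros(B, H, S, D)
assert t.stride() == (H * S * D, S * D, D, 1)

# A cache over-allocated along the sequence axis keeps the same logical shape
# but carries larger strides, which is why the info struct records all four
# strides per tensor instead of assuming contiguity.
padded = torch.zeros(B, H, 2 * S, D)[:, :, :S, :]
assert padded.shape == (B, H, S, D)
assert padded.stride() == (2 * H * S * D, 2 * S * D, D, 1)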
#ifndef KV_CACHING_H
#define KV_CACHING_H
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::kv_caching::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
KVCachingInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
KVCachingInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t get_workspace_size() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t k_cache, \
infiniopTensorDescriptor_t v_cache, \
infiniopTensorDescriptor_t k, \
infiniopTensorDescriptor_t v, \
infiniopTensorDescriptor_t past_kv_lengths); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *k_cache, void *v_cache, \
const void *k, const void *v, const void *past_kv_lengths, \
void *stream) const; \
}; \
}
#endif // KV_CACHING_H
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "kv_caching_nvidia.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../cuda/kernel.cuh"
template <typename Tdata>
INFINIOP_CUDA_KERNEL kvCaching(
Tdata *k_cache,
Tdata *v_cache,
const Tdata *k,
const Tdata *v,
const int64_t *past_kv_lengths,
int batch_size,
int num_kv_heads,
int max_seq_len,
int seq_len,
int hidden_dim,
ptrdiff_t k_cache_strides_0,
ptrdiff_t k_cache_strides_1,
ptrdiff_t k_cache_strides_2,
ptrdiff_t k_cache_strides_3,
ptrdiff_t v_cache_strides_0,
ptrdiff_t v_cache_strides_1,
ptrdiff_t v_cache_strides_2,
ptrdiff_t v_cache_strides_3,
ptrdiff_t k_strides_0,
ptrdiff_t k_strides_1,
ptrdiff_t k_strides_2,
ptrdiff_t k_strides_3,
ptrdiff_t v_strides_0,
ptrdiff_t v_strides_1,
ptrdiff_t v_strides_2,
ptrdiff_t v_strides_3) {
kvCachingKernel<Tdata>(k_cache, v_cache, k, v, past_kv_lengths,
batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3,
k_strides_0, k_strides_1, k_strides_2, k_strides_3,
v_strides_0, v_strides_1, v_strides_2, v_strides_3);
}
namespace op::kv_caching::nvidia {
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t k_cache,
infiniopTensorDescriptor_t v_cache,
infiniopTensorDescriptor_t k,
infiniopTensorDescriptor_t v,
infiniopTensorDescriptor_t past_kv_lengths) {
auto info = KVCachingInfo::createKVCachingInfo(k_cache, v_cache, k, v, past_kv_lengths);
CHECK_RESULT(info);
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <unsigned int BLOCK_SIZE, typename Tdata>
infiniStatus_t launchKernel(const KVCachingInfo &info,
Tdata *k_cache,
Tdata *v_cache,
const Tdata *k,
const Tdata *v,
const int64_t *past_kv_lengths,
cudaStream_t stream, void *workspace) {
int batch_size = static_cast<int>(info.batch_size);
int num_kv_heads = static_cast<int>(info.num_kv_heads);
int max_seq_len = static_cast<int>(info.max_seq_len);
int hidden_dim = static_cast<int>(info.hidden_dim);
int seq_len = static_cast<int>(info.seq_len);
int total = batch_size * num_kv_heads * seq_len * hidden_dim;
ptrdiff_t k_cache_strides_0 = info.k_cache_strides_0;
ptrdiff_t k_cache_strides_1 = info.k_cache_strides_1;
ptrdiff_t k_cache_strides_2 = info.k_cache_strides_2;
ptrdiff_t k_cache_strides_3 = info.k_cache_strides_3;
ptrdiff_t v_cache_strides_0 = info.v_cache_strides_0;
ptrdiff_t v_cache_strides_1 = info.v_cache_strides_1;
ptrdiff_t v_cache_strides_2 = info.v_cache_strides_2;
ptrdiff_t v_cache_strides_3 = info.v_cache_strides_3;
ptrdiff_t k_strides_0 = info.k_strides_0;
ptrdiff_t k_strides_1 = info.k_strides_1;
ptrdiff_t k_strides_2 = info.k_strides_2;
ptrdiff_t k_strides_3 = info.k_strides_3;
ptrdiff_t v_strides_0 = info.v_strides_0;
ptrdiff_t v_strides_1 = info.v_strides_1;
ptrdiff_t v_strides_2 = info.v_strides_2;
ptrdiff_t v_strides_3 = info.v_strides_3;
int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
kvCaching<Tdata>
<<<num_blocks, BLOCK_SIZE, 0, stream>>>(k_cache, v_cache, k, v, past_kv_lengths,
batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3,
k_strides_0, k_strides_1, k_strides_2, k_strides_3,
v_strides_0, v_strides_1, v_strides_2, v_strides_3);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
void *k_cache,
void *v_cache,
const void *k,
const void *v,
const void *past_kv_lengths,
void *stream_) const {
cudaStream_t stream = (cudaStream_t)stream_;
#define CALCULATE_KV_CACHING(BLOCK_SIZE, TDATA) \
launchKernel<BLOCK_SIZE, TDATA>(_info, (TDATA *)k_cache, (TDATA *)v_cache, (const TDATA *)k, (const TDATA *)v, (const int64_t *)past_kv_lengths, stream, workspace)
#define CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(BLOCK_SIZE) \
{ \
if (_info.dtype == INFINI_DTYPE_F16) \
return CALCULATE_KV_CACHING(BLOCK_SIZE, half); \
else if (_info.dtype == INFINI_DTYPE_F32) \
return CALCULATE_KV_CACHING(BLOCK_SIZE, float); \
else if (_info.dtype == INFINI_DTYPE_BF16) \
return CALCULATE_KV_CACHING(BLOCK_SIZE, __nv_bfloat16); \
else \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_1024)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_512)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_2048) {
CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_2048)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_4096)
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::kv_caching::nvidia
#ifndef __KV_CACHING_NVIDIA_API_H__
#define __KV_CACHING_NVIDIA_API_H__
#include "../kv_caching.h"
DESCRIPTOR(nvidia)
#endif // __KV_CACHING_NVIDIA_API_H__
@@ -8,6 +8,10 @@
#endif
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#include "nvidia/kv_caching_nvidia.cuh"
#endif
__C infiniStatus_t infiniopCreateKVCachingDescriptor(
    infiniopHandle_t handle,
    infiniopKVCachingDescriptor_t *desc_ptr,
@@ -42,6 +46,13 @@ __C infiniStatus_t infiniopCreateKVCachingDescriptor(
#endif
#endif
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
@@ -71,6 +82,13 @@ __C infiniStatus_t infiniopGetKVCachingWorkspaceSize(
#if defined(ENABLE_METAX_API)
    GET_SIZE(INFINI_DEVICE_METAX, ninetoothed);
#endif
#endif
#ifdef ENABLE_NVIDIA_API
GET_SIZE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_QY_API
GET_SIZE(INFINI_DEVICE_QY, nvidia);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -107,6 +125,13 @@ __C infiniStatus_t infiniopKVCaching(
#if defined(ENABLE_METAX_API)
    CALCULATE(INFINI_DEVICE_METAX, ninetoothed);
#endif
#endif
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -135,6 +160,13 @@ __C infiniStatus_t infiniopDestroyKVCachingDescriptor(
#if defined(ENABLE_METAX_API)
    DELETE(INFINI_DEVICE_METAX, ninetoothed);
#endif
#endif
#ifdef ENABLE_NVIDIA_API
DELETE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...
import torch
import ctypes
import random
from ctypes import c_uint64
from libinfiniop import (
LIBINFINIOP,
TestTensor,
get_test_devices,
check_error,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
InfiniDtype,
InfiniDtypeNames,
InfiniDeviceNames,
infiniopOperatorDescriptor_t,
TestWorkspace,
)
# ==============================================================================
# Reference Implementation
# ==============================================================================
def torch_kv_caching(k_cache, v_cache, k, v, past_kv_lengths):
    # k_cache.shape = [batch_size, num_kv_heads, max_seq_len, hidden_dim]
    # v_cache.shape = [batch_size, num_kv_heads, max_seq_len, hidden_dim]
    # k.shape = [batch_size, num_kv_heads, seq_len, hidden_dim]
    # v.shape = [batch_size, num_kv_heads, seq_len, hidden_dim]
    # past_kv_lengths.shape = [batch_size]
batch_size, num_kv_heads, _, head_dim = k_cache.shape
seq_len = k.shape[2]
for b in range(batch_size):
past_len = past_kv_lengths[b].item()
for h in range(num_kv_heads):
k_cache[b, h, past_len : past_len + seq_len, :] = k[b, h, :, :]
v_cache[b, h, past_len : past_len + seq_len, :] = v[b, h, :, :]
return k_cache, v_cache
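
The loop-based reference above is deliberately simple and easy to audit. For larger shapes, a vectorized equivalent is possible; the sketch below (for comparison only, not used by the test) scatters both caches along the sequence dimension in one call per tensor:

def torch_kv_caching_vectorized(k_cache, v_cache, k, v, past_kv_lengths):
    # Build per-batch destination rows past_len + [0, seq_len) and scatter
    # along dim 2; the index must be int64 and broadcast to k's full shape.
    seq_len = k.shape[2]
    pos = past_kv_lengths.view(-1, 1) + torch.arange(seq_len, device=k.device)
    idx = pos[:, None, :, None].expand_as(k)  # [B, H, seq_len, D]
    k_cache.scatter_(2, idx, k)
    v_cache.scatter_(2, idx, v)
    return k_cache, v_cache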
# ==============================================================================
# Test Configuration (Internal Use Only)
# ==============================================================================
_TEST_CASES_ = [
    # (batch_size, num_kv_heads, max_seq_len, hidden_dim), cache strides
((1, 1, 8, 1), None),
((1, 8, 32, 32), None),
((8, 8, 64, 32), None),
((1, 32, 8, 64), (32768, 1024, 64, 1)),
((4, 8, 32, 16), (65536, 8192, 256, 16)),
((8, 16, 64, 128), (8388608, 524288, 8192, 1)),
((1, 2, 2304, 128), (589824, 294912, 128, 1)),
]
# Data types for testing
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32]
# Tolerance map for different data types
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 0, "rtol": 0},
InfiniDtype.BF16: {"atol": 0, "rtol": 0},
InfiniDtype.F32: {"atol": 0, "rtol": 0},
}
# Global flags for controlling test behavior
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 100
def test(
handle,
device,
cache_shape,
strides,
dtype=InfiniDtype.F16,
sync=None,
):
print(
f"Testing KVCaching on {InfiniDeviceNames[device]} with cache_shape:{cache_shape}, strides:{strides}, dtype={InfiniDtypeNames[dtype]}"
)
kv_shape = (
cache_shape[0],
cache_shape[1],
random.randrange(1, cache_shape[2]),
cache_shape[3],
)
past_shape = (cache_shape[0],)
k_cache = TestTensor(cache_shape, strides, dtype, device)
v_cache = TestTensor(cache_shape, strides, dtype, device)
k = TestTensor(kv_shape, None, dtype, device)
v = TestTensor(kv_shape, None, dtype, device)
    # Bound the past lengths so that past_len + seq_len <= max_seq_len always holds.
    past_kv_lengths = TestTensor(
        past_shape, None, InfiniDtype.I64, device,
        randint_low=0, randint_high=cache_shape[2] - kv_shape[2],
    )
# Run reference implementation
k_cache_ref, v_cache_ref = torch_kv_caching(
k_cache.torch_tensor(),
v_cache.torch_tensor(),
k.torch_tensor(),
v.torch_tensor(),
past_kv_lengths.torch_tensor())
if sync:
sync()
# Create operator descriptor
descriptor = infiniopOperatorDescriptor_t()
check_error(
LIBINFINIOP.infiniopCreateKVCachingDescriptor(
handle,
ctypes.byref(descriptor),
k_cache.descriptor,
v_cache.descriptor,
k.descriptor,
v.descriptor,
past_kv_lengths.descriptor,
)
)
# Get workspace size (likely 0 for this operator, but good practice to include)
workspace_size = c_uint64(0)
check_error(
LIBINFINIOP.infiniopGetKVCachingWorkspaceSize(
descriptor, ctypes.byref(workspace_size)
)
)
workspace = TestWorkspace(workspace_size.value, device)
# Invalidate descriptors to ensure kernel does not rely on them
k.destroy_desc()
v.destroy_desc()
k_cache.destroy_desc()
v_cache.destroy_desc()
past_kv_lengths.destroy_desc()
# Define the library call as a lambda for profiling
def lib_kv_caching():
check_error(
LIBINFINIOP.infiniopKVCaching(
descriptor,
workspace.data(),
workspace_size.value,
k_cache.data(),
v_cache.data(),
k.data(),
v.data(),
past_kv_lengths.data(),
None,
)
)
# Execute the custom operator
lib_kv_caching()
if sync:
sync()
# Verify correctness
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
print("Verifying K cache...")
debug(k_cache.actual_tensor(), k_cache_ref, atol=atol, rtol=rtol)
print("Verifying V cache...")
debug(v_cache.actual_tensor(), v_cache_ref, atol=atol, rtol=rtol)
assert torch.allclose(
k_cache.actual_tensor(), k_cache_ref, atol=atol, rtol=rtol
)
assert torch.allclose(
v_cache.actual_tensor(), v_cache_ref, atol=atol, rtol=rtol
)
# Profiling workflow
if PROFILE:
# fmt: off
profile_operation("PyTorch", lambda: torch_kv_caching(k_cache.torch_tensor(), v_cache.torch_tensor(), k.torch_tensor(), v.torch_tensor(), past_kv_lengths.torch_tensor()), device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lib_kv_caching, device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
# Clean up resources
check_error(LIBINFINIOP.infiniopDestroyKVCachingDescriptor(descriptor))
if __name__ == "__main__":
args = get_args()
# Configure testing options from command line arguments
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
for device in get_test_devices(args):
test_operator(device, test, _TEST_CASES_, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")
@@ -1054,6 +1054,48 @@ def scaled_mm_int8_(lib):
]
@OpRegister.operator
def kv_caching_(lib):
lib.infiniopCreateKVCachingDescriptor.restype = c_int32
lib.infiniopCreateKVCachingDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopOperatorDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
]
lib.infiniopGetKVCachingWorkspaceSize.restype = c_int32
lib.infiniopGetKVCachingWorkspaceSize.argtypes = [
infiniopOperatorDescriptor_t,
POINTER(c_size_t),
]
lib.infiniopKVCaching.restype = c_int32
lib.infiniopKVCaching.argtypes = [
infiniopOperatorDescriptor_t,
c_void_p,
c_size_t,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroyKVCachingDescriptor.restype = c_int32
lib.infiniopDestroyKVCachingDescriptor.argtypes = [
infiniopOperatorDescriptor_t,
]
@OpRegister.operator
def paged_attention_(lib):
    lib.infiniopCreatePagedAttentionDescriptor.restype = c_int32
...