Unverified commit 422a4e1e authored by PanZezhong1725, committed by GitHub

Merge pull request #916 from InfiniTensor/issue/867-2

issue/867: fix the paged caching API; paged attention now supports more head dims
parents 01a4a0c8 96551cb7
......@@ -8,10 +8,10 @@ namespace infinicore::op {
class PagedCaching {
public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor);
static void execute(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping);
static void execute(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping);
static common::OpDispatcher<schema> &dispatcher();
};
void paged_caching_(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping);
void paged_caching_(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping);
} // namespace infinicore::op
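
With this change the writable cache pools lead the argument list, so outputs come before inputs. A minimal call-site sketch under the new ordering (a hedged illustration only: cache_step is a hypothetical caller, not part of the library, and how the Tensor handles are produced is left out):

    // New order: destination cache pools first, then the source k/v, then
    // the slot mapping. execute() asserts all five tensors share one device.
    void cache_step(Tensor k_cache, Tensor v_cache,
                    Tensor k, Tensor v, Tensor slot_mapping) {
        infinicore::op::paged_caching_(k_cache, v_cache, k, v, slot_mapping);
    }
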
......@@ -14,20 +14,20 @@ typedef struct InfiniopDescriptor *infiniopPagedCachingDescriptor_t;
*
* @param handle The handle to the InfiniOP library context.
* @param desc_ptr A pointer to store the created descriptor.
* @param k_desc Descriptor for the source key tensor.
* @param v_desc Descriptor for the source value tensor.
* @param k_cache_desc Descriptor for the key cache pool tensor.
* @param v_cache_desc Descriptor for the value cache pool tensor.
* @param k_desc Descriptor for the source key tensor.
* @param v_desc Descriptor for the source value tensor.
* @param slot_mapping_desc Descriptor for the slot mapping tensor.
* @return infiniStatus_t Status code of the operation.
*/
__C __export infiniStatus_t infiniopCreatePagedCachingDescriptor(
infiniopHandle_t handle,
infiniopPagedCachingDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t slot_mapping_desc);
/**
......@@ -46,10 +46,10 @@ __C __export infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
* @param desc The Paged Caching descriptor.
* @param workspace Pointer to the workspace memory.
* @param workspace_size The size of the workspace.
* @param k Pointer to the source key tensor data.
* @param v Pointer to the source value tensor data.
* @param k_cache Pointer to the key cache pool data.
* @param v_cache Pointer to the value cache pool data.
* @param k Pointer to the source key tensor data.
* @param v Pointer to the source value tensor data.
* @param slot_mapping Pointer to the slot mapping data.
* @param stream The CUDA stream for the operation. Can be NULL.
* @return infiniStatus_t Status code of the operation.
......@@ -58,10 +58,10 @@ __C __export infiniStatus_t infiniopPagedCaching(
infiniopPagedCachingDescriptor_t desc,
void *workspace,
size_t workspace_size,
const void *k,
const void *v,
void *k_cache,
void *v_cache,
const void *k,
const void *v,
const void *slot_mapping,
void *stream);
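
Taken together, the C API lifecycle under the new ordering looks like the sketch below. This is a hedged illustration: the descriptors and device buffers are assumed to exist already, the workspace is assumed to be pre-allocated large enough (a real caller would query the size before allocating), and the parameter list of infiniopGetPagedCachingWorkspaceSize is assumed to be (descriptor, size pointer):

    // Hedged sketch of the full create/run/destroy cycle, new argument order.
    #define CHECK(expr) \
        do { infiniStatus_t s_ = (expr); if (s_ != INFINI_STATUS_SUCCESS) return s_; } while (0)

    infiniStatus_t run_paged_caching(
        infiniopHandle_t handle,
        infiniopTensorDescriptor_t k_cache_desc, infiniopTensorDescriptor_t v_cache_desc,
        infiniopTensorDescriptor_t k_desc, infiniopTensorDescriptor_t v_desc,
        infiniopTensorDescriptor_t slot_mapping_desc,
        void *k_cache_data, void *v_cache_data,
        const void *k_data, const void *v_data,
        const void *slot_mapping_data,
        void *workspace, void *stream) {
        infiniopPagedCachingDescriptor_t desc;
        CHECK(infiniopCreatePagedCachingDescriptor(
            handle, &desc,
            k_cache_desc, v_cache_desc, // cache pool descriptors first (new order)
            k_desc, v_desc,             // then the source key/value descriptors
            slot_mapping_desc));

        size_t workspace_size = 0;
        CHECK(infiniopGetPagedCachingWorkspaceSize(desc, &workspace_size));

        CHECK(infiniopPagedCaching(
            desc, workspace, workspace_size,
            k_cache_data, v_cache_data, // writable cache pools
            k_data, v_data,             // const source tensors
            slot_mapping_data, stream));

        return infiniopDestroyPagedCachingDescriptor(desc);
    }
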
......
......@@ -3,18 +3,18 @@ from infinicore.tensor import Tensor
def paged_caching(
k: Tensor,
v: Tensor,
k_cache: Tensor,
v_cache: Tensor,
k: Tensor,
v: Tensor,
slot_mapping: Tensor,
):
Tensor(
_infinicore.paged_caching_(
k._underlying,
v._underlying,
k_cache._underlying,
v_cache._underlying,
k._underlying,
v._underlying,
slot_mapping._underlying,
)
)
......
......@@ -9,14 +9,14 @@ common::OpDispatcher<PagedCaching::schema> &PagedCaching::dispatcher() {
return dispatcher_;
};
void PagedCaching::execute(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k, v, k_cache, v_cache, slot_mapping);
infinicore::context::setDevice(k->device());
dispatcher().lookup(k->device().getType())(k, v, k_cache, v_cache, slot_mapping);
void PagedCaching::execute(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k_cache, v_cache, k, v, slot_mapping);
infinicore::context::setDevice(k_cache->device());
dispatcher().lookup(k_cache->device().getType())(k_cache, v_cache, k, v, slot_mapping);
}
void paged_caching_(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping) {
PagedCaching::execute(k, v, k_cache, v_cache, slot_mapping);
void paged_caching_(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping) {
PagedCaching::execute(k_cache, v_cache, k, v, slot_mapping);
}
} // namespace infinicore::op
......@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopPagedCachingDescriptor_t> caches(
}
});
void calculate(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping) {
size_t seed = hash_combine(k, v, k_cache, v_cache, slot_mapping);
void calculate(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping) {
size_t seed = hash_combine(k_cache, v_cache, k, v, slot_mapping);
auto device = context::getDevice();
auto &cache = caches.getCache(device);
......@@ -27,7 +27,7 @@ void calculate(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_m
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreatePagedCachingDescriptor(
context::getInfiniopHandle(device), &desc,
k->desc(), v->desc(), k_cache->desc(), v_cache->desc(), slot_mapping->desc()));
k_cache->desc(), v_cache->desc(), k->desc(), v->desc(), slot_mapping->desc()));
cache.put(seed, desc);
} else {
desc = *desc_opt;
......@@ -39,7 +39,7 @@ void calculate(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_m
INFINICORE_CHECK_ERROR(infiniopPagedCaching(
desc, workspace->data(), workspace_size,
k->data(), v->data(), k_cache->data(), v_cache->data(), slot_mapping->data(), context::getStream()));
k_cache->data(), v_cache->data(), k->data(), v->data(), slot_mapping->data(), context::getStream()));
}
static bool registered = []() {
......
......@@ -11,10 +11,10 @@ namespace infinicore::ops {
inline void bind_paged_caching(py::module &m) {
m.def("paged_caching_",
&op::paged_caching_,
py::arg("k"),
py::arg("v"),
py::arg("k_cache"),
py::arg("v_cache"),
py::arg("k"),
py::arg("v"),
py::arg("slot_mapping"),
R"doc(Paged caching of key and value tensors.)doc");
}
......
......@@ -67,11 +67,9 @@ public:
size_t num_heads = q_shape[1];
size_t head_size = q_shape[2];
if (head_size != 128) {
// Print the specific error cause and the current parameter value
std::cerr << "[Error] Now only supports head_size = 128, but got "
if (head_size != 16 && head_size != 32 && head_size != 64 && head_size != 128 && head_size != 256) {
std::cerr << "[Error] Now only supports head_size = 16/32/64/128/256, but got "
<< head_size << "." << std::endl;
// Suggest returning a SHAPE-related error code
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
......
......@@ -98,37 +98,49 @@ infiniStatus_t Descriptor::calculate(
const void *block_tables, const void *seq_lens, const void *alibi_slopes,
void *stream_) const {
cudaStream_t stream = (cudaStream_t)stream_;
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
if (_info.head_size == 128) {
launchKernel<128, CUDA_BLOCK_SIZE_1024>(
out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes,
_info.num_heads, _info.num_seqs,
_info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size,
_info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride,
#define LAUNCH_HEADSIZE_BLOCKSIZE(__H_SIZE, __B_SIZE) \
launchKernel<__H_SIZE, __B_SIZE>( \
out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes, \
_info.num_heads, _info.num_seqs, \
_info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size, \
_info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, \
stream);
#define SWITCH_HEAD_SIZE(__B_SIZE) \
switch (_info.head_size) { \
case 16: \
LAUNCH_HEADSIZE_BLOCKSIZE(16, __B_SIZE) \
break; \
case 32: \
LAUNCH_HEADSIZE_BLOCKSIZE(32, __B_SIZE) \
break; \
case 64: \
LAUNCH_HEADSIZE_BLOCKSIZE(64, __B_SIZE) \
break; \
case 128: \
LAUNCH_HEADSIZE_BLOCKSIZE(128, __B_SIZE) \
break; \
case 256: \
LAUNCH_HEADSIZE_BLOCKSIZE(256, __B_SIZE) \
break; \
default: \
return INFINI_STATUS_BAD_TENSOR_SHAPE; \
}
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_1024)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
if (_info.head_size == 128) {
launchKernel<128, CUDA_BLOCK_SIZE_512>(
out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes,
_info.num_heads, _info.num_seqs,
_info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size,
_info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride,
stream);
}
SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_512)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
if (_info.head_size == 128) {
launchKernel<128, CUDA_BLOCK_SIZE_4096>(
out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes,
_info.num_heads, _info.num_seqs,
_info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size,
_info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride,
stream);
}
SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_4096)
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
#undef LAUNCH_HEADSIZE_BLOCKSIZE
#undef SWITCH_HEAD_SIZE
return INFINI_STATUS_SUCCESS;
}
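
The macro pair replaces the three duplicated head_size == 128 launch blocks and scales them to five head sizes per block size. head_size is a template parameter of launchKernel, so each supported value needs its own instantiation, and the runtime switch selects among them; the default case also closes a gap in the old code, which silently returned success without launching anything when head_size was unsupported within a given block-size branch. For one block size the expansion is conceptually (kernel arguments elided; this mirrors what the macros generate rather than adding anything new):

    switch (_info.head_size) {
    case 16:  launchKernel<16,  CUDA_BLOCK_SIZE_1024>(/* ... */ stream); break;
    case 32:  launchKernel<32,  CUDA_BLOCK_SIZE_1024>(/* ... */ stream); break;
    case 64:  launchKernel<64,  CUDA_BLOCK_SIZE_1024>(/* ... */ stream); break;
    case 128: launchKernel<128, CUDA_BLOCK_SIZE_1024>(/* ... */ stream); break;
    case 256: launchKernel<256, CUDA_BLOCK_SIZE_1024>(/* ... */ stream); break;
    default:  return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }

Fixing head_size at compile time lets the kernel size its per-thread buffers and unroll its inner loops, which is why the supported dims are an explicit enumeration rather than a runtime value.
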
......
......@@ -5,9 +5,9 @@
#ifdef ENABLE_NVIDIA_API
#include "nvidia/paged_attention_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/paged_attention_metax.h"
#endif
// #ifdef ENABLE_METAX_API
// #include "metax/paged_attention_metax.h"
// #endif
__C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
infiniopHandle_t handle,
......@@ -34,11 +34,12 @@ __C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax)
#endif
}
// #ifdef ENABLE_METAX_API
// CREATE(INFINI_DEVICE_METAX, metax)
// #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
}
__C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
......@@ -54,11 +55,12 @@ __C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax)
#endif
}
// #ifdef ENABLE_METAX_API
// GET(INFINI_DEVICE_METAX, metax)
// #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
}
__C infiniStatus_t infiniopPagedAttention(
......@@ -78,11 +80,12 @@ __C infiniStatus_t infiniopPagedAttention(
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
}
// #ifdef ENABLE_METAX_API
// CALCULATE(INFINI_DEVICE_METAX, metax)
// #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
}
__C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
......@@ -97,9 +100,10 @@ __C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax)
#endif
}
// #ifdef ENABLE_METAX_API
// DESTROY(INFINI_DEVICE_METAX, metax)
// #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
}
......@@ -28,10 +28,10 @@ public:
ptrdiff_t v_cache_block_stride;
static utils::Result<PagedCachingInfo> create(
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t slot_mapping_desc) {
auto dtype = k_desc->dtype();
......
......@@ -31,13 +31,13 @@ Descriptor::~Descriptor() {
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t slot_mapping_desc) {
auto info = PagedCachingInfo::create(k_desc, v_desc, k_cache_desc, v_cache_desc, slot_mapping_desc);
auto info = PagedCachingInfo::create(k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc);
CHECK_RESULT(info);
// Create and return the Descriptor instance.
......@@ -121,8 +121,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
// Execution method implementation
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
const void *k, const void *v,
void *k_cache, void *v_cache,
const void *k, const void *v,
const void *slot_mapping,
void *stream_) const {
......
......@@ -5,17 +5,17 @@
#ifdef ENABLE_NVIDIA_API
#include "nvidia/paged_caching_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/paged_caching_metax.h"
#endif
// #ifdef ENABLE_METAX_API
// #include "metax/paged_caching_metax.h"
// #endif
__C infiniStatus_t infiniopCreatePagedCachingDescriptor(
infiniopHandle_t handle,
infiniopPagedCachingDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t slot_mapping_desc) {
#define CREATE(CASE, NAMESPACE) \
......@@ -23,17 +23,18 @@ __C infiniStatus_t infiniopCreatePagedCachingDescriptor(
return op::paged_caching::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor **>(desc_ptr), \
k_desc, v_desc, k_cache_desc, v_cache_desc, slot_mapping_desc);
k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc);
switch (handle->device) {
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax)
#endif
}
// #ifdef ENABLE_METAX_API
// CREATE(INFINI_DEVICE_METAX, metax)
// #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
}
__C infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
......@@ -49,35 +50,37 @@ __C infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax)
#endif
}
// #ifdef ENABLE_METAX_API
// GET(INFINI_DEVICE_METAX, metax)
// #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
}
__C infiniStatus_t infiniopPagedCaching(
infiniopPagedCachingDescriptor_t desc,
void *workspace, size_t workspace_size,
const void *k, const void *v,
void *k_cache, void *v_cache,
const void *k, const void *v,
const void *slot_mapping,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size, k, v, k_cache, v_cache, slot_mapping, stream);
workspace, workspace_size, k_cache, v_cache, k, v, slot_mapping, stream);
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
}
// #ifdef ENABLE_METAX_API
// CALCULATE(INFINI_DEVICE_METAX, metax)
// #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
}
__C infiniStatus_t infiniopDestroyPagedCachingDescriptor(
......@@ -92,9 +95,10 @@ __C infiniStatus_t infiniopDestroyPagedCachingDescriptor(
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax)
#endif
}
// #ifdef ENABLE_METAX_API
// DESTROY(INFINI_DEVICE_METAX, metax)
// #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
}
......@@ -32,16 +32,16 @@
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t k_desc, \
infiniopTensorDescriptor_t v_desc, \
infiniopTensorDescriptor_t k_cache_desc, \
infiniopTensorDescriptor_t v_cache_desc, \
infiniopTensorDescriptor_t k_desc, \
infiniopTensorDescriptor_t v_desc, \
infiniopTensorDescriptor_t slot_mapping_desc); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
const void *k, const void *v, \
void *k_cache, void *v_cache, \
const void *k, const void *v, \
const void *slot_mapping, \
void *stream) const; \
}; \
......
......@@ -25,6 +25,7 @@ _TEST_CASES_DATA = [
(4, 40, 40, 128, 16, 1024, False),
(6, 40, 40, 128, 16, 1024, False),
(3, 8, 8, 128, 16, 1024, False),
(3, 8, 8, 64, 16, 1024, False),
(8, 64, 8, 128, 16, 2048, False),
]
......@@ -68,8 +69,6 @@ def parse_test_cases():
0, num_seqs * max_blocks_per_seq, dtype=torch.int64
).view(num_seqs, max_blocks_per_seq)
print("block_tables.shape", block_tables.shape, block_tables)
q_shape = (num_seqs, num_heads, head_size)
k_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
v_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
......
......@@ -28,9 +28,9 @@ _TEST_CASES_DATA = [
# Tolerance configuration
_TOLERANCE_MAP = {
infinicore.float16: {"atol": 0, "rtol": 1e-2},
infinicore.float32: {"atol": 1e-4, "rtol": 1e-3},
infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
infinicore.float16: {"atol": 0, "rtol": 1e-5},
infinicore.float32: {"atol": 0, "rtol": 1e-5},
infinicore.bfloat16: {"atol": 0, "rtol": 1e-5},
}
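
Tightening every tolerance to atol 0 / rtol 1e-5 is presumably safe because paged caching is a pure gather/scatter copy with no arithmetic, so the operator's output should match the reference bit for bit in every dtype.
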
# Data types to test
......@@ -40,15 +40,15 @@ _TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# ==============================================================================
# Reference Implementation
# ==============================================================================
def ref_paged_caching(key, value, key_cache_pool, value_cache_pool, slot_mapping):
def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping):
"""
Reference implementation for paged_caching operator.
Args:
key (torch.Tensor): Keys, shape [ntok, nkvh, dh]
value (torch.Tensor): Values, shape [ntok, nkvh, dh]
key_cache_pool (torch.Tensor): K cache pool, shape [num_blocks, nkvh, block_size, dh]
value_cache_pool (torch.Tensor): V cache pool, shape [num_blocks, nkvh, block_size, dh]
key (torch.Tensor): Keys, shape [ntok, nkvh, dh]
value (torch.Tensor): Values, shape [ntok, nkvh, dh]
slot_mapping (torch.Tensor): Slot mapping, shape [ntok]
"""
ntok = key.shape[0]
......@@ -56,8 +56,8 @@ def ref_paged_caching(key, value, key_cache_pool, value_cache_pool, slot_mapping
# This reference implementation operates on a cloned cache to avoid modifying the original input tensor,
# mimicking the behavior where the custom operator writes to its output tensor.
k_cache_ref = key_cache_pool.clone()
v_cache_ref = value_cache_pool.clone()
k_cache_ref = key_cache_pool
v_cache_ref = value_cache_pool
for i in range(ntok):
slot = slot_mapping[i].item()
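
Each entry of slot_mapping is a flat position in the pooled cache; the scatter decomposes it into a block index and an in-block row. A hedged C++ sketch of the addressing implied by the [num_blocks, nkvh, block_size, dh] layout from the docstring (float stands in for the actual cache dtype; the real kernel's indexing may differ in detail):

    #include <cstdint>

    void scatter_kv(float *k_cache, float *v_cache,
                    const float *k, const float *v,
                    const int64_t *slot_mapping,
                    int64_t ntok, int64_t nkvh, int64_t block_size, int64_t dh) {
        for (int64_t tok = 0; tok < ntok; ++tok) {
            const int64_t slot   = slot_mapping[tok];
            const int64_t block  = slot / block_size;  // which cache block
            const int64_t offset = slot % block_size;  // row inside that block
            for (int64_t h = 0; h < nkvh; ++h) {
                // Flat index into a [num_blocks, nkvh, block_size, dh] pool.
                const int64_t dst = ((block * nkvh + h) * block_size + offset) * dh;
                const int64_t src = (tok * nkvh + h) * dh;
                for (int64_t e = 0; e < dh; ++e) {
                    k_cache[dst + e] = k[src + e];
                    v_cache[dst + e] = v[src + e];
                }
            }
        }
    }

So a token's slot fully determines where its key/value rows land, independent of the order in which tokens arrive.
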
......@@ -98,9 +98,9 @@ def parse_test_cases():
current_slot += length.item()
# Ensure we don't exceed the total number of slots in the cache
assert (
current_slot <= num_blocks * block_size
), "Not enough blocks in the cache pool for this test case"
assert current_slot <= num_blocks * block_size, (
"Not enough blocks in the cache pool for this test case"
)
slot_mapping = torch.tensor(slot_mapping_list, dtype=torch.int64)
......@@ -119,8 +119,12 @@ def parse_test_cases():
# Create typed tensor specs
k_spec = TensorSpec.from_tensor(k_shape, None, dtype)
v_spec = TensorSpec.from_tensor(v_shape, None, dtype)
k_cache_spec = TensorSpec.from_tensor(k_cache_shape, None, dtype)
v_cache_spec = TensorSpec.from_tensor(v_cache_shape, None, dtype)
k_cache_spec = TensorSpec.from_tensor(
k_cache_shape, None, dtype, init_mode=TensorInitializer.ZEROS
)
v_cache_spec = TensorSpec.from_tensor(
v_cache_shape, None, dtype, init_mode=TensorInitializer.ZEROS
)
slot_mapping_spec = TensorSpec.from_tensor(
slot_mapping_shape,
init_mode=TensorInitializer.MANUAL,
......@@ -132,10 +136,10 @@ def parse_test_cases():
test_cases.append(
TestCase(
inputs=[
k_spec,
v_spec,
k_cache_spec,
v_cache_spec,
k_spec,
v_spec,
slot_mapping_spec,
],
kwargs=None,
......
......@@ -1066,10 +1066,10 @@ def paged_caching_(lib):
lib.infiniopCreatePagedCachingDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopOperatorDescriptor_t),
infiniopTensorDescriptor_t, # k_desc
infiniopTensorDescriptor_t, # v_desc
infiniopTensorDescriptor_t, # k_cache_desc
infiniopTensorDescriptor_t, # v_cache_desc
infiniopTensorDescriptor_t, # k_desc
infiniopTensorDescriptor_t, # v_desc
infiniopTensorDescriptor_t, # slot_mapping_desc
]
......@@ -1086,10 +1086,10 @@ def paged_caching_(lib):
infiniopOperatorDescriptor_t,
c_void_p, # workspace
c_size_t, # workspace_size
c_void_p, # k
c_void_p, # v
c_void_p, # k_cache
c_void_p, # v_cache
c_void_p, # k
c_void_p, # v
c_void_p, # slot_mapping
c_void_p, # stream
]
......
......@@ -95,6 +95,7 @@ _TEST_CASES_ = [
(4, 40, 40, 128, 16, 1024, False),
(6, 40, 40, 128, 16, 1024, False),
(3, 8, 8, 128, 16, 1024, False),
(3, 8, 8, 64, 16, 1024, False),
(8, 64, 8, 128, 16, 2048, False),
]
......
......@@ -22,15 +22,15 @@ from libinfiniop import (
# ==============================================================================
# Reference Implementation
# ==============================================================================
def ref_paged_caching(key, value, key_cache_pool, value_cache_pool, slot_mapping):
def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping):
"""
Reference implementation for paged_caching operator.
Args:
key (torch.Tensor): Keys, shape [ntok, nkvh, dh]
value (torch.Tensor): Values, shape [ntok, nkvh, dh]
key_cache_pool (torch.Tensor): K cache pool, shape [num_blocks, nkvh, block_size, dh]
value_cache_pool (torch.Tensor): V cache pool, shape [num_blocks, nkvh, block_size, dh]
key (torch.Tensor): Keys, shape [ntok, nkvh, dh]
value (torch.Tensor): Values, shape [ntok, nkvh, dh]
slot_mapping (torch.Tensor): Slot mapping, shape [ntok]
"""
ntok = key.shape[0]
......@@ -71,9 +71,9 @@ _TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32]
# Tolerance map for different data types
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2},
InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2},
InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5},
InfiniDtype.F16: {"atol": 0, "rtol": 1e-5},
InfiniDtype.BF16: {"atol": 0, "rtol": 1e-5},
InfiniDtype.F32: {"atol": 0, "rtol": 1e-5},
}
# Global flags for controlling test behavior
......@@ -123,9 +123,9 @@ def test(
current_slot += length.item()
# Ensure we don't exceed the total number of slots in the cache
assert (
current_slot <= num_blocks * block_size
), "Not enough blocks in the cache pool for this test case"
assert current_slot <= num_blocks * block_size, (
"Not enough blocks in the cache pool for this test case"
)
slot_mapping_torch = torch.tensor(slot_mapping_list, dtype=torch.int64)
......@@ -144,10 +144,10 @@ def test(
# Run reference implementation
k_cache_ref, v_cache_ref = ref_paged_caching(
k.torch_tensor(),
v.torch_tensor(),
k_cache_pool.torch_tensor(),
v_cache_pool.torch_tensor(),
k.torch_tensor(),
v.torch_tensor(),
slot_mapping.torch_tensor(),
)
......@@ -160,10 +160,10 @@ def test(
LIBINFINIOP.infiniopCreatePagedCachingDescriptor(
handle,
ctypes.byref(descriptor),
k.descriptor,
v.descriptor,
k_cache_pool.descriptor,
v_cache_pool.descriptor,
k.descriptor,
v.descriptor,
slot_mapping.descriptor,
)
)
......@@ -191,10 +191,10 @@ def test(
descriptor,
workspace.data(),
workspace_size.value,
k.data(),
v.data(),
k_cache_pool.data(),
v_cache_pool.data(),
k.data(),
v.data(),
slot_mapping.data(),
None,
)
......
......@@ -80,7 +80,7 @@ class SimpleCacheManager:
return torch.tensor(slots, dtype=torch.int32)
def ref_paged_caching(k_new, v_new, k_pool, v_pool, slots, block_size):
def ref_paged_caching(k_pool, v_pool, k_new, v_new, slots, block_size):
"""Reference implementation for incremental caching."""
for i in range(k_new.shape[0]):
slot = slots[i].item()
......@@ -152,10 +152,10 @@ def test(
def torch_caching():
nonlocal k_pool_ref, v_pool_ref
return ref_paged_caching(
k_in.torch_tensor(),
v_in.torch_tensor(),
k_pool_ref,
v_pool_ref,
k_in.torch_tensor(),
v_in.torch_tensor(),
slots_torch,
block_size,
)
......@@ -168,10 +168,10 @@ def test(
LIBINFINIOP.infiniopCreatePagedCachingDescriptor(
handle,
ctypes.byref(descriptor),
k_in.descriptor,
v_in.descriptor,
k_cache_pool.descriptor,
v_cache_pool.descriptor,
k_in.descriptor,
v_in.descriptor,
slot_mapping.descriptor,
)
)
......@@ -190,10 +190,10 @@ def test(
descriptor,
workspace.data(),
workspace_size.value,
k_in.data(),
v_in.data(),
k_cache_pool.data(),
v_cache_pool.data(),
k_in.data(),
v_in.data(),
slot_mapping.data(),
None,
)
......