"git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "aa43220fe5d34f7283f3a55e11dda5e3f25e5632"
Commit 96551cb7 authored by PanZezhong's avatar PanZezhong
Browse files

issue/867 fix page caching api, paged attn support more head dims

parent 01a4a0c8
...@@ -8,10 +8,10 @@ namespace infinicore::op { ...@@ -8,10 +8,10 @@ namespace infinicore::op {
class PagedCaching { class PagedCaching {
public: public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor); using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor);
static void execute(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping); static void execute(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping);
static common::OpDispatcher<schema> &dispatcher(); static common::OpDispatcher<schema> &dispatcher();
}; };
void paged_caching_(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping); void paged_caching_(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping);
} // namespace infinicore::op } // namespace infinicore::op
...@@ -14,20 +14,20 @@ typedef struct InfiniopDescriptor *infiniopPagedCachingDescriptor_t; ...@@ -14,20 +14,20 @@ typedef struct InfiniopDescriptor *infiniopPagedCachingDescriptor_t;
* *
* @param handle The handle to the InfiniOP library context. * @param handle The handle to the InfiniOP library context.
* @param desc_ptr A pointer to store the created descriptor. * @param desc_ptr A pointer to store the created descriptor.
* @param k_desc Descriptor for the source key tensor.
* @param v_desc Descriptor for the source value tensor.
* @param k_cache_desc Descriptor for the key cache pool tensor. * @param k_cache_desc Descriptor for the key cache pool tensor.
* @param v_cache_desc Descriptor for the value cache pool tensor. * @param v_cache_desc Descriptor for the value cache pool tensor.
* @param k_desc Descriptor for the source key tensor.
* @param v_desc Descriptor for the source value tensor.
* @param slot_mapping_desc Descriptor for the slot mapping tensor. * @param slot_mapping_desc Descriptor for the slot mapping tensor.
* @return infiniStatus_t Status code of the operation. * @return infiniStatus_t Status code of the operation.
*/ */
__C __export infiniStatus_t infiniopCreatePagedCachingDescriptor( __C __export infiniStatus_t infiniopCreatePagedCachingDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
infiniopPagedCachingDescriptor_t *desc_ptr, infiniopPagedCachingDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t k_cache_desc, infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc, infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t slot_mapping_desc); infiniopTensorDescriptor_t slot_mapping_desc);
/** /**
...@@ -46,10 +46,10 @@ __C __export infiniStatus_t infiniopGetPagedCachingWorkspaceSize( ...@@ -46,10 +46,10 @@ __C __export infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
* @param desc The Paged Caching descriptor. * @param desc The Paged Caching descriptor.
* @param workspace Pointer to the workspace memory. * @param workspace Pointer to the workspace memory.
* @param workspace_size The size of the workspace. * @param workspace_size The size of the workspace.
* @param k Pointer to the source key tensor data.
* @param v Pointer to the source value tensor data.
* @param k_cache Pointer to the key cache pool data. * @param k_cache Pointer to the key cache pool data.
* @param v_cache Pointer to the value cache pool data. * @param v_cache Pointer to the value cache pool data.
* @param k Pointer to the source key tensor data.
* @param v Pointer to the source value tensor data.
* @param slot_mapping Pointer to the slot mapping data. * @param slot_mapping Pointer to the slot mapping data.
* @param stream The CUDA stream for the operation. Can be NULL. * @param stream The CUDA stream for the operation. Can be NULL.
* @return infiniStatus_t Status code of the operation. * @return infiniStatus_t Status code of the operation.
...@@ -58,10 +58,10 @@ __C __export infiniStatus_t infiniopPagedCaching( ...@@ -58,10 +58,10 @@ __C __export infiniStatus_t infiniopPagedCaching(
infiniopPagedCachingDescriptor_t desc, infiniopPagedCachingDescriptor_t desc,
void *workspace, void *workspace,
size_t workspace_size, size_t workspace_size,
const void *k,
const void *v,
void *k_cache, void *k_cache,
void *v_cache, void *v_cache,
const void *k,
const void *v,
const void *slot_mapping, const void *slot_mapping,
void *stream); void *stream);
......
...@@ -3,18 +3,18 @@ from infinicore.tensor import Tensor ...@@ -3,18 +3,18 @@ from infinicore.tensor import Tensor
def paged_caching( def paged_caching(
k: Tensor,
v: Tensor,
k_cache: Tensor, k_cache: Tensor,
v_cache: Tensor, v_cache: Tensor,
k: Tensor,
v: Tensor,
slot_mapping: Tensor, slot_mapping: Tensor,
): ):
Tensor( Tensor(
_infinicore.paged_caching_( _infinicore.paged_caching_(
k._underlying,
v._underlying,
k_cache._underlying, k_cache._underlying,
v_cache._underlying, v_cache._underlying,
k._underlying,
v._underlying,
slot_mapping._underlying, slot_mapping._underlying,
) )
) )
......
...@@ -9,14 +9,14 @@ common::OpDispatcher<PagedCaching::schema> &PagedCaching::dispatcher() { ...@@ -9,14 +9,14 @@ common::OpDispatcher<PagedCaching::schema> &PagedCaching::dispatcher() {
return dispatcher_; return dispatcher_;
}; };
void PagedCaching::execute(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping) { void PagedCaching::execute(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k, v, k_cache, v_cache, slot_mapping); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k_cache, v_cache, k, v, slot_mapping);
infinicore::context::setDevice(k->device()); infinicore::context::setDevice(k_cache->device());
dispatcher().lookup(k->device().getType())(k, v, k_cache, v_cache, slot_mapping); dispatcher().lookup(k_cache->device().getType())(k_cache, v_cache, k, v, slot_mapping);
} }
void paged_caching_(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping) { void paged_caching_(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping) {
PagedCaching::execute(k, v, k_cache, v_cache, slot_mapping); PagedCaching::execute(k_cache, v_cache, k, v, slot_mapping);
} }
} // namespace infinicore::op } // namespace infinicore::op
...@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopPagedCachingDescriptor_t> caches( ...@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopPagedCachingDescriptor_t> caches(
} }
}); });
void calculate(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping) { void calculate(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping) {
size_t seed = hash_combine(k, v, k_cache, v_cache, slot_mapping); size_t seed = hash_combine(k_cache, v_cache, k, v, slot_mapping);
auto device = context::getDevice(); auto device = context::getDevice();
auto &cache = caches.getCache(device); auto &cache = caches.getCache(device);
...@@ -27,7 +27,7 @@ void calculate(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_m ...@@ -27,7 +27,7 @@ void calculate(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_m
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreatePagedCachingDescriptor( INFINICORE_CHECK_ERROR(infiniopCreatePagedCachingDescriptor(
context::getInfiniopHandle(device), &desc, context::getInfiniopHandle(device), &desc,
k->desc(), v->desc(), k_cache->desc(), v_cache->desc(), slot_mapping->desc())); k_cache->desc(), v_cache->desc(), k->desc(), v->desc(), slot_mapping->desc()));
cache.put(seed, desc); cache.put(seed, desc);
} else { } else {
desc = *desc_opt; desc = *desc_opt;
...@@ -39,7 +39,7 @@ void calculate(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_m ...@@ -39,7 +39,7 @@ void calculate(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_m
INFINICORE_CHECK_ERROR(infiniopPagedCaching( INFINICORE_CHECK_ERROR(infiniopPagedCaching(
desc, workspace->data(), workspace_size, desc, workspace->data(), workspace_size,
k->data(), v->data(), k_cache->data(), v_cache->data(), slot_mapping->data(), context::getStream())); k_cache->data(), v_cache->data(), k->data(), v->data(), slot_mapping->data(), context::getStream()));
} }
static bool registered = []() { static bool registered = []() {
......
...@@ -11,10 +11,10 @@ namespace infinicore::ops { ...@@ -11,10 +11,10 @@ namespace infinicore::ops {
inline void bind_paged_caching(py::module &m) { inline void bind_paged_caching(py::module &m) {
m.def("paged_caching_", m.def("paged_caching_",
&op::paged_caching_, &op::paged_caching_,
py::arg("k"),
py::arg("v"),
py::arg("k_cache"), py::arg("k_cache"),
py::arg("v_cache"), py::arg("v_cache"),
py::arg("k"),
py::arg("v"),
py::arg("slot_mapping"), py::arg("slot_mapping"),
R"doc(Paged caching of key and value tensors.)doc"); R"doc(Paged caching of key and value tensors.)doc");
} }
......
...@@ -67,11 +67,9 @@ public: ...@@ -67,11 +67,9 @@ public:
size_t num_heads = q_shape[1]; size_t num_heads = q_shape[1];
size_t head_size = q_shape[2]; size_t head_size = q_shape[2];
if (head_size != 128) { if (head_size != 16 && head_size != 32 && head_size != 64 && head_size != 128 && head_size != 256) {
// 输出具体的错误原因和当前的参数值 std::cerr << "[Error] Now only supports head_size = 16/32/64/128/256, but got "
std::cerr << "[Error] Now only supports head_size = 128, but got "
<< head_size << "." << std::endl; << head_size << "." << std::endl;
// 建议返回 SHAPE 相关的错误码
return INFINI_STATUS_BAD_TENSOR_SHAPE; return INFINI_STATUS_BAD_TENSOR_SHAPE;
} }
......
...@@ -98,37 +98,49 @@ infiniStatus_t Descriptor::calculate( ...@@ -98,37 +98,49 @@ infiniStatus_t Descriptor::calculate(
const void *block_tables, const void *seq_lens, const void *alibi_slopes, const void *block_tables, const void *seq_lens, const void *alibi_slopes,
void *stream_) const { void *stream_) const {
cudaStream_t stream = (cudaStream_t)stream_; cudaStream_t stream = (cudaStream_t)stream_;
#define LAUNCH_HEADSIZE_BLOCKSIZE(__H_SIZE, __B_SIZE) \
launchKernel<__H_SIZE, __B_SIZE>( \
out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes, \
_info.num_heads, _info.num_seqs, \
_info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size, \
_info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, \
stream);
#define SWITCH_HEAD_SIZE(__B_SIZE) \
switch (_info.head_size) { \
case 16: \
LAUNCH_HEADSIZE_BLOCKSIZE(16, __B_SIZE) \
break; \
case 32: \
LAUNCH_HEADSIZE_BLOCKSIZE(32, __B_SIZE) \
break; \
case 64: \
LAUNCH_HEADSIZE_BLOCKSIZE(64, __B_SIZE) \
break; \
case 128: \
LAUNCH_HEADSIZE_BLOCKSIZE(128, __B_SIZE) \
break; \
case 256: \
LAUNCH_HEADSIZE_BLOCKSIZE(256, __B_SIZE) \
break; \
default: \
return INFINI_STATUS_BAD_TENSOR_SHAPE; \
}
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
if (_info.head_size == 128) { SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_1024)
launchKernel<128, CUDA_BLOCK_SIZE_1024>(
out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes,
_info.num_heads, _info.num_seqs,
_info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size,
_info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride,
stream);
}
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) { } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
if (_info.head_size == 128) { SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_512)
launchKernel<128, CUDA_BLOCK_SIZE_512>(
out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes,
_info.num_heads, _info.num_seqs,
_info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size,
_info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride,
stream);
}
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) { } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
if (_info.head_size == 128) { SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_4096)
launchKernel<128, CUDA_BLOCK_SIZE_4096>(
out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes,
_info.num_heads, _info.num_seqs,
_info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size,
_info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride,
stream);
}
} else { } else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
} }
#undef LAUNCH_HEADSIZE_BLOCKSIZE
#undef SWITCH_HEAD_SIZE
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
......
...@@ -5,9 +5,9 @@ ...@@ -5,9 +5,9 @@
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
#include "nvidia/paged_attention_nvidia.cuh" #include "nvidia/paged_attention_nvidia.cuh"
#endif #endif
#ifdef ENABLE_METAX_API // #ifdef ENABLE_METAX_API
#include "metax/paged_attention_metax.h" // #include "metax/paged_attention_metax.h"
#endif // #endif
__C infiniStatus_t infiniopCreatePagedAttentionDescriptor( __C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
...@@ -34,11 +34,12 @@ __C infiniStatus_t infiniopCreatePagedAttentionDescriptor( ...@@ -34,11 +34,12 @@ __C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia) CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif #endif
#ifdef ENABLE_METAX_API // #ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax) // CREATE(INFINI_DEVICE_METAX, metax)
#endif // #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize( __C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
...@@ -54,11 +55,12 @@ __C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize( ...@@ -54,11 +55,12 @@ __C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia) GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif #endif
#ifdef ENABLE_METAX_API // #ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax) // GET(INFINI_DEVICE_METAX, metax)
#endif // #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopPagedAttention( __C infiniStatus_t infiniopPagedAttention(
...@@ -78,11 +80,12 @@ __C infiniStatus_t infiniopPagedAttention( ...@@ -78,11 +80,12 @@ __C infiniStatus_t infiniopPagedAttention(
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia) CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif #endif
#ifdef ENABLE_METAX_API // #ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax) // CALCULATE(INFINI_DEVICE_METAX, metax)
#endif // #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopDestroyPagedAttentionDescriptor( __C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
...@@ -97,9 +100,10 @@ __C infiniStatus_t infiniopDestroyPagedAttentionDescriptor( ...@@ -97,9 +100,10 @@ __C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia) DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif #endif
#ifdef ENABLE_METAX_API // #ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax) // DESTROY(INFINI_DEVICE_METAX, metax)
#endif // #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
...@@ -28,10 +28,10 @@ public: ...@@ -28,10 +28,10 @@ public:
ptrdiff_t v_cache_block_stride; ptrdiff_t v_cache_block_stride;
static utils::Result<PagedCachingInfo> create( static utils::Result<PagedCachingInfo> create(
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t k_cache_desc, infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc, infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t slot_mapping_desc) { infiniopTensorDescriptor_t slot_mapping_desc) {
auto dtype = k_desc->dtype(); auto dtype = k_desc->dtype();
......
...@@ -31,13 +31,13 @@ Descriptor::~Descriptor() { ...@@ -31,13 +31,13 @@ Descriptor::~Descriptor() {
infiniStatus_t Descriptor::create( infiniStatus_t Descriptor::create(
infiniopHandle_t handle, infiniopHandle_t handle,
Descriptor **desc_ptr, Descriptor **desc_ptr,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t k_cache_desc, infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc, infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t slot_mapping_desc) { infiniopTensorDescriptor_t slot_mapping_desc) {
auto info = PagedCachingInfo::create(k_desc, v_desc, k_cache_desc, v_cache_desc, slot_mapping_desc); auto info = PagedCachingInfo::create(k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc);
CHECK_RESULT(info); CHECK_RESULT(info);
// Create and return the Descriptor instance. // Create and return the Descriptor instance.
...@@ -121,8 +121,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -121,8 +121,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
// Execution method implementation // Execution method implementation
infiniStatus_t Descriptor::calculate( infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size, void *workspace, size_t workspace_size,
const void *k, const void *v,
void *k_cache, void *v_cache, void *k_cache, void *v_cache,
const void *k, const void *v,
const void *slot_mapping, const void *slot_mapping,
void *stream_) const { void *stream_) const {
......
...@@ -5,17 +5,17 @@ ...@@ -5,17 +5,17 @@
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
#include "nvidia/paged_caching_nvidia.cuh" #include "nvidia/paged_caching_nvidia.cuh"
#endif #endif
#ifdef ENABLE_METAX_API // #ifdef ENABLE_METAX_API
#include "metax/paged_caching_metax.h" // #include "metax/paged_caching_metax.h"
#endif // #endif
__C infiniStatus_t infiniopCreatePagedCachingDescriptor( __C infiniStatus_t infiniopCreatePagedCachingDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
infiniopPagedCachingDescriptor_t *desc_ptr, infiniopPagedCachingDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t k_cache_desc, infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc, infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t slot_mapping_desc) { infiniopTensorDescriptor_t slot_mapping_desc) {
#define CREATE(CASE, NAMESPACE) \ #define CREATE(CASE, NAMESPACE) \
...@@ -23,17 +23,18 @@ __C infiniStatus_t infiniopCreatePagedCachingDescriptor( ...@@ -23,17 +23,18 @@ __C infiniStatus_t infiniopCreatePagedCachingDescriptor(
return op::paged_caching::NAMESPACE::Descriptor::create( \ return op::paged_caching::NAMESPACE::Descriptor::create( \
handle, \ handle, \
reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor **>(desc_ptr), \ reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor **>(desc_ptr), \
k_desc, v_desc, k_cache_desc, v_cache_desc, slot_mapping_desc); k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc);
switch (handle->device) { switch (handle->device) {
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia) CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif #endif
#ifdef ENABLE_METAX_API // #ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax) // CREATE(INFINI_DEVICE_METAX, metax)
#endif // #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopGetPagedCachingWorkspaceSize( __C infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
...@@ -49,35 +50,37 @@ __C infiniStatus_t infiniopGetPagedCachingWorkspaceSize( ...@@ -49,35 +50,37 @@ __C infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia) GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif #endif
#ifdef ENABLE_METAX_API // #ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax) // GET(INFINI_DEVICE_METAX, metax)
#endif // #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopPagedCaching( __C infiniStatus_t infiniopPagedCaching(
infiniopPagedCachingDescriptor_t desc, infiniopPagedCachingDescriptor_t desc,
void *workspace, size_t workspace_size, void *workspace, size_t workspace_size,
const void *k, const void *v,
void *k_cache, void *v_cache, void *k_cache, void *v_cache,
const void *k, const void *v,
const void *slot_mapping, const void *slot_mapping,
void *stream) { void *stream) {
#define CALCULATE(CASE, NAMESPACE) \ #define CALCULATE(CASE, NAMESPACE) \
case CASE: \ case CASE: \
return reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor *>(desc)->calculate( \ return reinterpret_cast<op::paged_caching::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size, k, v, k_cache, v_cache, slot_mapping, stream); workspace, workspace_size, k_cache, v_cache, k, v, slot_mapping, stream);
switch (desc->device_type) { switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia) CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif #endif
#ifdef ENABLE_METAX_API // #ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax) // CALCULATE(INFINI_DEVICE_METAX, metax)
#endif // #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopDestroyPagedCachingDescriptor( __C infiniStatus_t infiniopDestroyPagedCachingDescriptor(
...@@ -92,9 +95,10 @@ __C infiniStatus_t infiniopDestroyPagedCachingDescriptor( ...@@ -92,9 +95,10 @@ __C infiniStatus_t infiniopDestroyPagedCachingDescriptor(
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia) DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif #endif
#ifdef ENABLE_METAX_API // #ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax) // DESTROY(INFINI_DEVICE_METAX, metax)
#endif // #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
...@@ -32,16 +32,16 @@ ...@@ -32,16 +32,16 @@
static infiniStatus_t create( \ static infiniStatus_t create( \
infiniopHandle_t handle, \ infiniopHandle_t handle, \
Descriptor **desc_ptr, \ Descriptor **desc_ptr, \
infiniopTensorDescriptor_t k_desc, \
infiniopTensorDescriptor_t v_desc, \
infiniopTensorDescriptor_t k_cache_desc, \ infiniopTensorDescriptor_t k_cache_desc, \
infiniopTensorDescriptor_t v_cache_desc, \ infiniopTensorDescriptor_t v_cache_desc, \
infiniopTensorDescriptor_t k_desc, \
infiniopTensorDescriptor_t v_desc, \
infiniopTensorDescriptor_t slot_mapping_desc); \ infiniopTensorDescriptor_t slot_mapping_desc); \
\ \
infiniStatus_t calculate( \ infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \ void *workspace, size_t workspace_size, \
const void *k, const void *v, \
void *k_cache, void *v_cache, \ void *k_cache, void *v_cache, \
const void *k, const void *v, \
const void *slot_mapping, \ const void *slot_mapping, \
void *stream) const; \ void *stream) const; \
}; \ }; \
......
...@@ -25,6 +25,7 @@ _TEST_CASES_DATA = [ ...@@ -25,6 +25,7 @@ _TEST_CASES_DATA = [
(4, 40, 40, 128, 16, 1024, False), (4, 40, 40, 128, 16, 1024, False),
(6, 40, 40, 128, 16, 1024, False), (6, 40, 40, 128, 16, 1024, False),
(3, 8, 8, 128, 16, 1024, False), (3, 8, 8, 128, 16, 1024, False),
(3, 8, 8, 64, 16, 1024, False),
(8, 64, 8, 128, 16, 2048, False), (8, 64, 8, 128, 16, 2048, False),
] ]
...@@ -68,8 +69,6 @@ def parse_test_cases(): ...@@ -68,8 +69,6 @@ def parse_test_cases():
0, num_seqs * max_blocks_per_seq, dtype=torch.int64 0, num_seqs * max_blocks_per_seq, dtype=torch.int64
).view(num_seqs, max_blocks_per_seq) ).view(num_seqs, max_blocks_per_seq)
print("block_tables.shape", block_tables.shape, block_tables)
q_shape = (num_seqs, num_heads, head_size) q_shape = (num_seqs, num_heads, head_size)
k_cache_shape = (num_blocks, num_kv_heads, block_size, head_size) k_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
v_cache_shape = (num_blocks, num_kv_heads, block_size, head_size) v_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
......
...@@ -28,9 +28,9 @@ _TEST_CASES_DATA = [ ...@@ -28,9 +28,9 @@ _TEST_CASES_DATA = [
# Tolerance configuration # Tolerance configuration
_TOLERANCE_MAP = { _TOLERANCE_MAP = {
infinicore.float16: {"atol": 0, "rtol": 1e-2}, infinicore.float16: {"atol": 0, "rtol": 1e-5},
infinicore.float32: {"atol": 1e-4, "rtol": 1e-3}, infinicore.float32: {"atol": 0, "rtol": 1e-5},
infinicore.bfloat16: {"atol": 0, "rtol": 5e-2}, infinicore.bfloat16: {"atol": 0, "rtol": 1e-5},
} }
# Data types to test # Data types to test
...@@ -40,15 +40,15 @@ _TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32] ...@@ -40,15 +40,15 @@ _TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# ============================================================================== # ==============================================================================
# Reference Implementation # Reference Implementation
# ============================================================================== # ==============================================================================
def ref_paged_caching(key, value, key_cache_pool, value_cache_pool, slot_mapping): def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping):
""" """
Reference implementation for paged_caching operator. Reference implementation for paged_caching operator.
Args: Args:
key (torch.Tensor): Keys, shape [ntok, nkvh, dh]
value (torch.Tensor): Values, shape [ntok, nkvh, dh]
key_cache_pool (torch.Tensor): K cache pool, shape [num_blocks, nkvh, block_size, dh] key_cache_pool (torch.Tensor): K cache pool, shape [num_blocks, nkvh, block_size, dh]
value_cache_pool (torch.Tensor): V cache pool, shape [num_blocks, nkvh, block_size, dh] value_cache_pool (torch.Tensor): V cache pool, shape [num_blocks, nkvh, block_size, dh]
key (torch.Tensor): Keys, shape [ntok, nkvh, dh]
value (torch.Tensor): Values, shape [ntok, nkvh, dh]
slot_mapping (torch.Tensor): Slot mapping, shape [ntok] slot_mapping (torch.Tensor): Slot mapping, shape [ntok]
""" """
ntok = key.shape[0] ntok = key.shape[0]
...@@ -56,8 +56,8 @@ def ref_paged_caching(key, value, key_cache_pool, value_cache_pool, slot_mapping ...@@ -56,8 +56,8 @@ def ref_paged_caching(key, value, key_cache_pool, value_cache_pool, slot_mapping
# This reference implementation operates on a cloned cache to avoid modifying the original input tensor, # This reference implementation operates on a cloned cache to avoid modifying the original input tensor,
# mimicking the behavior where the custom operator writes to its output tensor. # mimicking the behavior where the custom operator writes to its output tensor.
k_cache_ref = key_cache_pool.clone() k_cache_ref = key_cache_pool
v_cache_ref = value_cache_pool.clone() v_cache_ref = value_cache_pool
for i in range(ntok): for i in range(ntok):
slot = slot_mapping[i].item() slot = slot_mapping[i].item()
...@@ -98,9 +98,9 @@ def parse_test_cases(): ...@@ -98,9 +98,9 @@ def parse_test_cases():
current_slot += length.item() current_slot += length.item()
# Ensure we don't exceed the total number of slots in the cache # Ensure we don't exceed the total number of slots in the cache
assert ( assert current_slot <= num_blocks * block_size, (
current_slot <= num_blocks * block_size "Not enough blocks in the cache pool for this test case"
), "Not enough blocks in the cache pool for this test case" )
slot_mapping = torch.tensor(slot_mapping_list, dtype=torch.int64) slot_mapping = torch.tensor(slot_mapping_list, dtype=torch.int64)
...@@ -119,8 +119,12 @@ def parse_test_cases(): ...@@ -119,8 +119,12 @@ def parse_test_cases():
# Create typed tensor specs # Create typed tensor specs
k_spec = TensorSpec.from_tensor(k_shape, None, dtype) k_spec = TensorSpec.from_tensor(k_shape, None, dtype)
v_spec = TensorSpec.from_tensor(v_shape, None, dtype) v_spec = TensorSpec.from_tensor(v_shape, None, dtype)
k_cache_spec = TensorSpec.from_tensor(k_cache_shape, None, dtype) k_cache_spec = TensorSpec.from_tensor(
v_cache_spec = TensorSpec.from_tensor(v_cache_shape, None, dtype) k_cache_shape, None, dtype, init_mode=TensorInitializer.ZEROS
)
v_cache_spec = TensorSpec.from_tensor(
v_cache_shape, None, dtype, init_mode=TensorInitializer.ZEROS
)
slot_mapping_spec = TensorSpec.from_tensor( slot_mapping_spec = TensorSpec.from_tensor(
slot_mapping_shape, slot_mapping_shape,
init_mode=TensorInitializer.MANUAL, init_mode=TensorInitializer.MANUAL,
...@@ -132,10 +136,10 @@ def parse_test_cases(): ...@@ -132,10 +136,10 @@ def parse_test_cases():
test_cases.append( test_cases.append(
TestCase( TestCase(
inputs=[ inputs=[
k_spec,
v_spec,
k_cache_spec, k_cache_spec,
v_cache_spec, v_cache_spec,
k_spec,
v_spec,
slot_mapping_spec, slot_mapping_spec,
], ],
kwargs=None, kwargs=None,
......
...@@ -1066,10 +1066,10 @@ def paged_caching_(lib): ...@@ -1066,10 +1066,10 @@ def paged_caching_(lib):
lib.infiniopCreatePagedCachingDescriptor.argtypes = [ lib.infiniopCreatePagedCachingDescriptor.argtypes = [
infiniopHandle_t, infiniopHandle_t,
POINTER(infiniopOperatorDescriptor_t), POINTER(infiniopOperatorDescriptor_t),
infiniopTensorDescriptor_t, # k_desc
infiniopTensorDescriptor_t, # v_desc
infiniopTensorDescriptor_t, # k_cache_desc infiniopTensorDescriptor_t, # k_cache_desc
infiniopTensorDescriptor_t, # v_cache_desc infiniopTensorDescriptor_t, # v_cache_desc
infiniopTensorDescriptor_t, # k_desc
infiniopTensorDescriptor_t, # v_desc
infiniopTensorDescriptor_t, # slot_mapping_desc infiniopTensorDescriptor_t, # slot_mapping_desc
] ]
...@@ -1086,10 +1086,10 @@ def paged_caching_(lib): ...@@ -1086,10 +1086,10 @@ def paged_caching_(lib):
infiniopOperatorDescriptor_t, infiniopOperatorDescriptor_t,
c_void_p, # workspace c_void_p, # workspace
c_size_t, # workspace_size c_size_t, # workspace_size
c_void_p, # k
c_void_p, # v
c_void_p, # k_cache c_void_p, # k_cache
c_void_p, # v_cache c_void_p, # v_cache
c_void_p, # k
c_void_p, # v
c_void_p, # slot_mapping c_void_p, # slot_mapping
c_void_p, # stream c_void_p, # stream
] ]
......
...@@ -95,6 +95,7 @@ _TEST_CASES_ = [ ...@@ -95,6 +95,7 @@ _TEST_CASES_ = [
(4, 40, 40, 128, 16, 1024, False), (4, 40, 40, 128, 16, 1024, False),
(6, 40, 40, 128, 16, 1024, False), (6, 40, 40, 128, 16, 1024, False),
(3, 8, 8, 128, 16, 1024, False), (3, 8, 8, 128, 16, 1024, False),
(3, 8, 8, 64, 16, 1024, False),
(8, 64, 8, 128, 16, 2048, False), (8, 64, 8, 128, 16, 2048, False),
] ]
......
...@@ -22,15 +22,15 @@ from libinfiniop import ( ...@@ -22,15 +22,15 @@ from libinfiniop import (
# ============================================================================== # ==============================================================================
# Reference Implementation # Reference Implementation
# ============================================================================== # ==============================================================================
def ref_paged_caching(key, value, key_cache_pool, value_cache_pool, slot_mapping): def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping):
""" """
Reference implementation for paged_caching operator. Reference implementation for paged_caching operator.
Args: Args:
key (torch.Tensor): Keys, shape [ntok, nkvh, dh]
value (torch.Tensor): Values, shape [ntok, nkvh, dh]
key_cache_pool (torch.Tensor): K cache pool, shape [num_blocks, nkvh, block_size, dh] key_cache_pool (torch.Tensor): K cache pool, shape [num_blocks, nkvh, block_size, dh]
value_cache_pool (torch.Tensor): V cache pool, shape [num_blocks, nkvh, block_size, dh] value_cache_pool (torch.Tensor): V cache pool, shape [num_blocks, nkvh, block_size, dh]
key (torch.Tensor): Keys, shape [ntok, nkvh, dh]
value (torch.Tensor): Values, shape [ntok, nkvh, dh]
slot_mapping (torch.Tensor): Slot mapping, shape [ntok] slot_mapping (torch.Tensor): Slot mapping, shape [ntok]
""" """
ntok = key.shape[0] ntok = key.shape[0]
...@@ -71,9 +71,9 @@ _TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32] ...@@ -71,9 +71,9 @@ _TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32]
# Tolerance map for different data types # Tolerance map for different data types
_TOLERANCE_MAP = { _TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2}, InfiniDtype.F16: {"atol": 0, "rtol": 1e-5},
InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2}, InfiniDtype.BF16: {"atol": 0, "rtol": 1e-5},
InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5}, InfiniDtype.F32: {"atol": 0, "rtol": 1e-5},
} }
# Global flags for controlling test behavior # Global flags for controlling test behavior
...@@ -123,9 +123,9 @@ def test( ...@@ -123,9 +123,9 @@ def test(
current_slot += length.item() current_slot += length.item()
# Ensure we don't exceed the total number of slots in the cache # Ensure we don't exceed the total number of slots in the cache
assert ( assert current_slot <= num_blocks * block_size, (
current_slot <= num_blocks * block_size "Not enough blocks in the cache pool for this test case"
), "Not enough blocks in the cache pool for this test case" )
slot_mapping_torch = torch.tensor(slot_mapping_list, dtype=torch.int64) slot_mapping_torch = torch.tensor(slot_mapping_list, dtype=torch.int64)
...@@ -144,10 +144,10 @@ def test( ...@@ -144,10 +144,10 @@ def test(
# Run reference implementation # Run reference implementation
k_cache_ref, v_cache_ref = ref_paged_caching( k_cache_ref, v_cache_ref = ref_paged_caching(
k.torch_tensor(),
v.torch_tensor(),
k_cache_pool.torch_tensor(), k_cache_pool.torch_tensor(),
v_cache_pool.torch_tensor(), v_cache_pool.torch_tensor(),
k.torch_tensor(),
v.torch_tensor(),
slot_mapping.torch_tensor(), slot_mapping.torch_tensor(),
) )
...@@ -160,10 +160,10 @@ def test( ...@@ -160,10 +160,10 @@ def test(
LIBINFINIOP.infiniopCreatePagedCachingDescriptor( LIBINFINIOP.infiniopCreatePagedCachingDescriptor(
handle, handle,
ctypes.byref(descriptor), ctypes.byref(descriptor),
k.descriptor,
v.descriptor,
k_cache_pool.descriptor, k_cache_pool.descriptor,
v_cache_pool.descriptor, v_cache_pool.descriptor,
k.descriptor,
v.descriptor,
slot_mapping.descriptor, slot_mapping.descriptor,
) )
) )
...@@ -191,10 +191,10 @@ def test( ...@@ -191,10 +191,10 @@ def test(
descriptor, descriptor,
workspace.data(), workspace.data(),
workspace_size.value, workspace_size.value,
k.data(),
v.data(),
k_cache_pool.data(), k_cache_pool.data(),
v_cache_pool.data(), v_cache_pool.data(),
k.data(),
v.data(),
slot_mapping.data(), slot_mapping.data(),
None, None,
) )
......
...@@ -80,7 +80,7 @@ class SimpleCacheManager: ...@@ -80,7 +80,7 @@ class SimpleCacheManager:
return torch.tensor(slots, dtype=torch.int32) return torch.tensor(slots, dtype=torch.int32)
def ref_paged_caching(k_new, v_new, k_pool, v_pool, slots, block_size): def ref_paged_caching(k_pool, v_pool, k_new, v_new, slots, block_size):
"""Reference implementation for incremental caching.""" """Reference implementation for incremental caching."""
for i in range(k_new.shape[0]): for i in range(k_new.shape[0]):
slot = slots[i].item() slot = slots[i].item()
...@@ -152,10 +152,10 @@ def test( ...@@ -152,10 +152,10 @@ def test(
def torch_caching(): def torch_caching():
nonlocal k_pool_ref, v_pool_ref nonlocal k_pool_ref, v_pool_ref
return ref_paged_caching( return ref_paged_caching(
k_in.torch_tensor(),
v_in.torch_tensor(),
k_pool_ref, k_pool_ref,
v_pool_ref, v_pool_ref,
k_in.torch_tensor(),
v_in.torch_tensor(),
slots_torch, slots_torch,
block_size, block_size,
) )
...@@ -168,10 +168,10 @@ def test( ...@@ -168,10 +168,10 @@ def test(
LIBINFINIOP.infiniopCreatePagedCachingDescriptor( LIBINFINIOP.infiniopCreatePagedCachingDescriptor(
handle, handle,
ctypes.byref(descriptor), ctypes.byref(descriptor),
k_in.descriptor,
v_in.descriptor,
k_cache_pool.descriptor, k_cache_pool.descriptor,
v_cache_pool.descriptor, v_cache_pool.descriptor,
k_in.descriptor,
v_in.descriptor,
slot_mapping.descriptor, slot_mapping.descriptor,
) )
) )
...@@ -190,10 +190,10 @@ def test( ...@@ -190,10 +190,10 @@ def test(
descriptor, descriptor,
workspace.data(), workspace.data(),
workspace_size.value, workspace_size.value,
k_in.data(),
v_in.data(),
k_cache_pool.data(), k_cache_pool.data(),
v_cache_pool.data(), v_cache_pool.data(),
k_in.data(),
v_in.data(),
slot_mapping.data(), slot_mapping.data(),
None, None,
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment