Unverified commit 90cb1b54 authored by thatPepe, committed by GitHub

Merge pull request #1037 from InfiniTensor/issue/1036

Issue/1036 paged caching support strides
parents d8176086 3e1ef507
......@@ -72,8 +72,10 @@ __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#undef CREATE
}
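
The same refactor repeats in every operator wrapper touched by this PR: the unsupported-device return moves from after the switch into an explicit default: case, so every exit from the dispatch is visible inside the switch itself. A minimal, self-contained sketch of the pattern, assuming stand-in enum values, macro body, and backend function (none of these are the real InfiniCore definitions):

    #include <cstdio>

    enum Status { STATUS_SUCCESS, STATUS_DEVICE_TYPE_NOT_SUPPORTED };
    enum Device { DEVICE_CPU, DEVICE_NVIDIA, DEVICE_OTHER };

    static Status createCpu() { return STATUS_SUCCESS; } // stand-in backend call

    // Each backend case in the real file sits behind an #ifdef ENABLE_*_API guard.
    #define CREATE(CASE, FN) \
        case CASE:           \
            return FN();

    static Status dispatchCreate(Device device) {
        switch (device) {
            CREATE(DEVICE_CPU, createCpu)
        default: // unsupported devices now return from inside the switch
            return STATUS_DEVICE_TYPE_NOT_SUPPORTED;
        }
    #undef CREATE
    }

    int main() {
        std::printf("%d\n", dispatchCreate(DEVICE_OTHER)); // prints 1
    }

With the default: case in place, the trailing return that previously followed the closing brace becomes redundant and is deleted.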
__C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDescriptor_t desc, size_t *size) {
......@@ -117,8 +119,10 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#undef GET
}
__C infiniStatus_t infiniopCausalSoftmax(
......@@ -167,8 +171,10 @@ __C infiniStatus_t infiniopCausalSoftmax(
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#undef CALCULATE
}
__C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxDescriptor_t desc) {
......@@ -212,6 +218,8 @@ __C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxD
#ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#undef DESTROY
}
......@@ -91,11 +91,11 @@ __C infiniStatus_t infiniopGetClipWorkspaceSize(infiniopClipDescriptor_t desc, s
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef GET
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopClip(
......
......@@ -51,8 +51,9 @@ __C infiniStatus_t infiniopCreateLogSoftmaxDescriptor(
#ifdef ENABLE_ASCEND_API
// CREATE(INFINI_DEVICE_ASCEND, ascend)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopGetLogSoftmaxWorkspaceSize(infiniopLogSoftmaxDescriptor_t desc, size_t *size) {
......@@ -84,8 +85,9 @@ __C infiniStatus_t infiniopGetLogSoftmaxWorkspaceSize(infiniopLogSoftmaxDescript
#ifdef ENABLE_ASCEND_API
// GET(INFINI_DEVICE_ASCEND, ascend)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopLogSoftmax(
......@@ -122,8 +124,9 @@ __C infiniStatus_t infiniopLogSoftmax(
#ifdef ENABLE_ASCEND_API
// CALCULATE(INFINI_DEVICE_ASCEND, ascend)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopDestroyLogSoftmaxDescriptor(infiniopLogSoftmaxDescriptor_t desc) {
......@@ -155,6 +158,7 @@ __C infiniStatus_t infiniopDestroyLogSoftmaxDescriptor(infiniopLogSoftmaxDescrip
#ifdef ENABLE_ASCEND_API
// DESTROY(INFINI_DEVICE_ASCEND, ascend)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -38,7 +38,11 @@ __device__ void pagedCachingKernel(
const ptrdiff_t k_src_stride, // Stride between tokens in the source K tensor
const ptrdiff_t v_src_stride, // Stride between tokens in the source V tensor
const ptrdiff_t k_cache_block_stride, // Stride between blocks in the K cache pool
const ptrdiff_t v_cache_block_stride // Stride between blocks in the V cache pool
const ptrdiff_t v_cache_block_stride, // Stride between blocks in the V cache pool
const ptrdiff_t k_cache_head_stride, // Stride between heads in the K cache pool
const ptrdiff_t v_cache_head_stride, // Stride between heads in the V cache pool
const ptrdiff_t k_cache_slot_stride, // Stride between block slots in the K cache pool
const ptrdiff_t v_cache_slot_stride // Stride between block slots in the V cache pool
) {
//================================================================================
// 1. Identify Work Unit & Calculate Addresses
......@@ -66,13 +70,11 @@ __device__ void pagedCachingKernel(
// Destination pointer calculation uses the cache tensor's explicit strides, so it
// handles both the default [num_blocks, num_heads, block_size, head_size] layout
// and views with dims 1 and 2 permuted.
// We point to the beginning of the memory region for this token's slot.
const ptrdiff_t cache_head_stride = block_size * head_size;
Tdata *k_cache_block_base_ptr = k_cache_ptr + physical_block_idx * k_cache_block_stride;
Tdata *k_dst_head_ptr = k_cache_block_base_ptr + head_idx * cache_head_stride + block_offset * head_size;
Tdata *k_dst_head_ptr = k_cache_block_base_ptr + head_idx * k_cache_head_stride + block_offset * k_cache_slot_stride;
Tdata *v_cache_block_base_ptr = v_cache_ptr + physical_block_idx * v_cache_block_stride;
Tdata *v_dst_head_ptr = v_cache_block_base_ptr + head_idx * cache_head_stride + block_offset * head_size;
Tdata *v_dst_head_ptr = v_cache_block_base_ptr + head_idx * v_cache_head_stride + block_offset * v_cache_slot_stride;
//================================================================================
// 2. Perform Element-wise Data Copy (Safe, Non-Vectorized)
......
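
For intuition about the four new parameters: with the default contiguous [num_blocks, num_heads, block_size, head_size] cache, the strides the kernel now receives equal the quantities it previously hard-coded (head stride = block_size * head_size, slot stride = head_size), while a cache whose dims 1 and 2 are permuted reports a different pair. A small sketch of the offset arithmetic, with illustrative sizes:

    #include <cstddef>
    #include <cstdio>

    int main() {
        const std::ptrdiff_t num_heads = 8, block_size = 16, head_size = 128;

        // Contiguous [num_blocks, num_heads, block_size, head_size]:
        const std::ptrdiff_t head_stride = block_size * head_size; // old hard-coded value
        const std::ptrdiff_t slot_stride = head_size;

        // View with dims 1 and 2 permuted, i.e. storage laid out as
        // [num_blocks, block_size, num_heads, head_size]:
        const std::ptrdiff_t p_head_stride = head_size;
        const std::ptrdiff_t p_slot_stride = num_heads * head_size;

        const std::ptrdiff_t head_idx = 3, block_offset = 5;
        std::printf("contiguous offset: %td\n",
                    head_idx * head_stride + block_offset * slot_stride);
        std::printf("permuted offset:   %td\n",
                    head_idx * p_head_stride + block_offset * p_slot_stride);
    }

Because the kernel multiplies by whatever strides the descriptor reports, both layouts place each token in the same logical (block, head, slot) cell.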
......@@ -26,6 +26,10 @@ public:
ptrdiff_t v_src_stride;
ptrdiff_t k_cache_block_stride;
ptrdiff_t v_cache_block_stride;
ptrdiff_t k_cache_head_stride;
ptrdiff_t v_cache_head_stride;
ptrdiff_t k_cache_slot_stride;
ptrdiff_t v_cache_slot_stride;
static utils::Result<PagedCachingInfo> create(
infiniopTensorDescriptor_t k_cache_desc,
......@@ -63,6 +67,10 @@ public:
ptrdiff_t v_src_stride = v_desc->stride(0);
ptrdiff_t k_cache_block_stride = k_cache_desc->stride(0);
ptrdiff_t v_cache_block_stride = v_cache_desc->stride(0);
ptrdiff_t k_cache_head_stride = k_cache_desc->stride(1);
ptrdiff_t v_cache_head_stride = v_cache_desc->stride(1);
ptrdiff_t k_cache_slot_stride = k_cache_desc->stride(2);
ptrdiff_t v_cache_slot_stride = v_cache_desc->stride(2);
return utils::Result<PagedCachingInfo>(PagedCachingInfo{
dtype,
......@@ -73,7 +81,11 @@ public:
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride});
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride});
}
};
......
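
PagedCachingInfo::create reads these values straight from the tensor descriptors, so the kernel no longer has to assume contiguity. For a contiguous row-major cache, stride(i) is just the product of the trailing dimensions; a short sketch of that relationship (the helper is illustrative, not an InfiniCore API):

    #include <array>
    #include <cstddef>
    #include <cstdio>

    // Row-major strides: stride(i) is the product of all dimensions after i.
    static std::array<std::ptrdiff_t, 4> rowMajorStrides(
        const std::array<std::ptrdiff_t, 4> &shape) {
        std::array<std::ptrdiff_t, 4> s{};
        s[3] = 1;
        for (int i = 2; i >= 0; --i) {
            s[i] = s[i + 1] * shape[i + 1];
        }
        return s;
    }

    int main() {
        // [num_blocks, num_heads, block_size, head_size]; sizes illustrative.
        const auto s = rowMajorStrides({4096, 8, 16, 128});
        // stride(0) = block stride, stride(1) = head stride, stride(2) = slot stride
        std::printf("block=%td head=%td slot=%td\n", s[0], s[1], s[2]);
    }

A permuted view keeps the same storage but reports swapped strides, which is exactly the case the new stride(1) and stride(2) fields capture.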
......@@ -10,10 +10,13 @@ INFINIOP_METAX_KERNEL pagedCaching(
const int64_t *slot_mapping,
const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride,
const ptrdiff_t k_cache_head_stride, const ptrdiff_t v_cache_head_stride,
const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
block_size, k_src_stride, v_src_stride,
k_cache_block_stride, v_cache_block_stride, k_cache_head_stride, v_cache_head_stride, k_cache_slot_stride, v_cache_slot_stride);
}
namespace op::paged_caching::metax {
......@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
ptrdiff_t k_cache_head_stride, ptrdiff_t v_cache_head_stride,
ptrdiff_t k_cache_slot_stride, ptrdiff_t v_cache_slot_stride,
hcStream_t stream) {
// Grid dimension is 1D, with one thread block per token.
......@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<cuda_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
......@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
......@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......@@ -138,6 +155,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream);
} else if (max_threads >= METAX_BLOCK_SIZE_512) {
launchKernel<METAX_BLOCK_SIZE_512>(
......@@ -145,6 +164,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream);
} else {
// If the device supports fewer threads, return an error.
......
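
Descriptor::calculate above selects the kernel instantiation at runtime: launchKernel is templated on the thread-block size, and the host picks the largest variant the device supports. A simplified sketch of that dispatch shape (names and sizes illustrative, kernel launch elided):

    #include <cstdio>

    template <unsigned NUM_THREADS>
    static int launchKernel(/* kernel arguments elided */) {
        // A real backend would launch:
        // kernel<T, NUM_THREADS><<<grid, block, shmem, stream>>>(...);
        std::printf("launching with %u threads per block\n", NUM_THREADS);
        return 0;
    }

    static int dispatch(unsigned maxThreadsPerBlock) {
        if (maxThreadsPerBlock >= 1024) {
            return launchKernel<1024>();
        }
        if (maxThreadsPerBlock >= 512) {
            return launchKernel<512>();
        }
        return -1; // device supports too few threads per block
    }

    int main() { return dispatch(1024) == 0 ? 0 : 1; }
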
......@@ -10,10 +10,13 @@ INFINIOP_MOORE_KERNEL pagedCaching(
const int64_t *slot_mapping,
const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride,
const ptrdiff_t k_cache_head_stride, const ptrdiff_t v_cache_head_stride,
const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
block_size, k_src_stride, v_src_stride,
k_cache_block_stride, v_cache_block_stride, k_cache_head_stride, v_cache_head_stride, k_cache_slot_stride, v_cache_slot_stride);
}
namespace op::paged_caching::moore {
......@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
ptrdiff_t k_cache_head_stride, ptrdiff_t v_cache_head_stride,
ptrdiff_t k_cache_slot_stride, ptrdiff_t v_cache_slot_stride,
musaStream_t stream) {
// Grid dimension is 1D, with one thread block per token.
......@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<__mt_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
......@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
......@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......@@ -137,6 +154,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream);
} else if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_512) {
launchKernel<MOORE_BLOCK_SIZE_512>(
......@@ -144,6 +163,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream);
} else {
// If the GPU is older and supports fewer threads, return an error.
......
......@@ -10,10 +10,13 @@ INFINIOP_CUDA_KERNEL pagedCaching(
const int64_t *slot_mapping,
const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride,
const ptrdiff_t k_cache_head_stride, const ptrdiff_t v_cache_head_stride,
const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
block_size, k_src_stride, v_src_stride,
k_cache_block_stride, v_cache_block_stride, k_cache_head_stride, v_cache_head_stride, k_cache_slot_stride, v_cache_slot_stride);
}
namespace op::paged_caching::nvidia {
......@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
ptrdiff_t k_cache_head_stride, ptrdiff_t v_cache_head_stride,
ptrdiff_t k_cache_slot_stride, ptrdiff_t v_cache_slot_stride,
cudaStream_t stream) {
// Grid dimension is 1D, with one thread block per token.
......@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<__nv_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
......@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
......@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......@@ -137,6 +154,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream);
} else if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_512) {
launchKernel<CUDA_BLOCK_SIZE_512>(
......@@ -144,6 +163,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream);
} else if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_4096) {
launchKernel<CUDA_BLOCK_SIZE_4096>(
......@@ -151,6 +172,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream);
} else {
// If the GPU is older and supports fewer threads, return an error.
......
......@@ -76,11 +76,11 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CREATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t desc, size_t *size) {
......@@ -124,11 +124,11 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef GET
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *workspace, size_t workspace_size,
......@@ -173,11 +173,11 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CALCULATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t desc) {
......@@ -221,9 +221,9 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
#ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef DESTROY
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -80,11 +80,11 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor(
#ifdef ENABLE_CAMBRICON_API
CREATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CREATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
......@@ -128,11 +128,11 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
#ifdef ENABLE_ASCEND_API
GET(INFINI_DEVICE_ASCEND, ascend);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef GET
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopRoPE(
......@@ -185,11 +185,11 @@ __C infiniStatus_t infiniopRoPE(
#ifdef ENABLE_ASCEND_API
CALCULATE(INFINI_DEVICE_ASCEND, ascend);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CALCULATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t
......@@ -234,9 +234,9 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) {
#ifdef ENABLE_ASCEND_API
DELETE(INFINI_DEVICE_ASCEND, ascend);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef DELETE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -24,8 +24,11 @@ __C infiniStatus_t infiniopCreateSiluAndMulDescriptor(
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#undef CREATE
}
__C infiniStatus_t infiniopGetSiluAndMulWorkspaceSize(infiniopSiluAndMulDescriptor_t desc, size_t *size) {
......@@ -39,8 +42,11 @@ __C infiniStatus_t infiniopGetSiluAndMulWorkspaceSize(infiniopSiluAndMulDescript
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#undef GET
}
__C infiniStatus_t infiniopSiluAndMul(
......@@ -59,8 +65,11 @@ __C infiniStatus_t infiniopSiluAndMul(
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#undef CALCULATE
}
__C infiniStatus_t infiniopDestroySiluAndMulDescriptor(infiniopSiluAndMulDescriptor_t desc) {
......@@ -74,6 +83,9 @@ __C infiniStatus_t infiniopDestroySiluAndMulDescriptor(infiniopSiluAndMulDescrip
#ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#undef DESTROY
}
......@@ -37,8 +37,9 @@ __C infiniStatus_t infiniopCreateSoftmaxDescriptor(
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopGetSoftmaxWorkspaceSize(infiniopSoftmaxDescriptor_t desc, size_t *size) {
......@@ -64,8 +65,9 @@ __C infiniStatus_t infiniopGetSoftmaxWorkspaceSize(infiniopSoftmaxDescriptor_t d
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopSoftmax(
......@@ -96,8 +98,9 @@ __C infiniStatus_t infiniopSoftmax(
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t desc) {
......@@ -123,6 +126,7 @@ __C infiniStatus_t infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t
#ifdef ENABLE_ALI_API
DESTROY(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -42,11 +42,11 @@ __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, i
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CREATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescriptor_t desc, size_t *size) {
......@@ -74,11 +74,11 @@ __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescript
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef GET
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void *workspace, size_t workspace_size,
......@@ -109,11 +109,11 @@ __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CALCULATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescriptor_t desc) {
......@@ -141,9 +141,9 @@ __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescrip
#ifdef ENABLE_ALI_API
DESTROY(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef DESTROY
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -40,11 +40,11 @@ __C infiniStatus_t infiniopCreateTopksoftmaxDescriptor(infiniopHandle_t handle,
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CREATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopGetTopksoftmaxWorkspaceSize(infiniopTopksoftmaxDescriptor_t desc, size_t *size) {
......@@ -73,11 +73,11 @@ __C infiniStatus_t infiniopGetTopksoftmaxWorkspaceSize(infiniopTopksoftmaxDescri
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef GET
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopTopksoftmax(infiniopTopksoftmaxDescriptor_t desc, void *workspace, size_t workspace_size,
......@@ -111,11 +111,11 @@ __C infiniStatus_t infiniopTopksoftmax(infiniopTopksoftmaxDescriptor_t desc, voi
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CALCULATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopDestroyTopksoftmaxDescriptor(infiniopTopksoftmaxDescriptor_t desc) {
......@@ -144,9 +144,9 @@ __C infiniStatus_t infiniopDestroyTopksoftmaxDescriptor(infiniopTopksoftmaxDescr
#ifdef ENABLE_ILUVATAR_API
DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef DESTROY
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -18,12 +18,16 @@ from framework import (
# Operator-specific configuration
# ==============================================================================
# Test cases format: (num_seqs, max_seq_len, num_kv_heads, head_size, block_size)
# Test cases format: (num_seqs, max_seq_len, num_kv_heads, head_size, block_size, permute_dim_1_2)
_TEST_CASES_DATA = [
(1, 128, 8, 128, 16),
(5, 512, 40, 128, 16),
(16, 1024, 8, 64, 32),
(10, 1024, 40, 64, 32),
(1, 128, 8, 128, 16, False),
(1, 128, 8, 128, 16, True),
(5, 512, 40, 256, 16, False),
(5, 512, 40, 256, 16, True),
(16, 1024, 8, 64, 32, False),
(16, 1024, 8, 64, 32, True),
(10, 1024, 40, 64, 32, False),
(10, 1024, 40, 64, 32, True),
]
# Tolerance configuration
......@@ -40,7 +44,9 @@ _TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# ==============================================================================
# Reference Implementation
# ==============================================================================
def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping):
def ref_paged_caching(
key_cache_pool, value_cache_pool, key, value, slot_mapping, permute_dim_1_2
):
"""
Reference implementation for paged_caching operator.
......@@ -52,7 +58,7 @@ def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping
slot_mapping (torch.Tensor): Slot mapping, shape [ntok]
"""
ntok = key.shape[0]
block_size = key_cache_pool.shape[2]
block_size = key_cache_pool.shape[1] if permute_dim_1_2 else key_cache_pool.shape[2]
# This reference implementation operates on a cloned cache to avoid modifying the original input tensor,
# mimicking the behavior where the custom operator writes to its output tensor.
......@@ -67,8 +73,12 @@ def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping
key_token = key[i]
value_token = value[i]
k_cache_ref[block_idx, :, block_offset, :] = key_token
v_cache_ref[block_idx, :, block_offset, :] = value_token
if permute_dim_1_2:
k_cache_ref[block_idx, block_offset, :, :] = key_token
v_cache_ref[block_idx, block_offset, :, :] = value_token
else:
k_cache_ref[block_idx, :, block_offset, :] = key_token
v_cache_ref[block_idx, :, block_offset, :] = value_token
return k_cache_ref, v_cache_ref
......@@ -79,7 +89,14 @@ def parse_test_cases():
Each test case contains all necessary information for execution and validation.
"""
test_cases = []
for num_seqs, max_seq_len, num_kv_heads, head_size, block_size in _TEST_CASES_DATA:
for (
num_seqs,
max_seq_len,
num_kv_heads,
head_size,
block_size,
permute_dim_1_2,
) in _TEST_CASES_DATA:
num_blocks = 4096 # A reasonably large cache pool for testing
# Create metadata: variable context lengths for each sequence in the batch
......@@ -111,6 +128,9 @@ def parse_test_cases():
v_shape = (ntok, num_kv_heads, head_size)
k_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
v_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
if permute_dim_1_2:
k_cache_shape = (num_blocks, block_size, num_kv_heads, head_size)
v_cache_shape = (num_blocks, block_size, num_kv_heads, head_size)
# Generate test cases for all data types
for dtype in _TENSOR_DTYPES:
......@@ -142,7 +162,7 @@ def parse_test_cases():
v_spec,
slot_mapping_spec,
],
kwargs=None,
kwargs={"permute_dim_1_2": permute_dim_1_2},
output_spec=None,
comparison_target=0, # Only compare k_cache
tolerance=tolerance,
......@@ -162,13 +182,22 @@ class OpTest(BaseOperatorTest):
def get_test_cases(self):
return parse_test_cases()
def torch_operator(self, *args, **kwargs):
def torch_operator(
self, k_cache, v_cache, key, value, slot_mapping, permute_dim_1_2=False
):
"""PyTorch paged_caching implementation"""
return ref_paged_caching(*args, **kwargs)
return ref_paged_caching(
k_cache, v_cache, key, value, slot_mapping, permute_dim_1_2
)
def infinicore_operator(self, *args, **kwargs):
def infinicore_operator(
self, k_cache, v_cache, key, value, slot_mapping, permute_dim_1_2=False
):
"""InfiniCore paged_caching implementation"""
return infinicore.paged_caching(*args, **kwargs)
if permute_dim_1_2:
k_cache = k_cache.permute([0, 2, 1, 3])
v_cache = v_cache.permute([0, 2, 1, 3])
return infinicore.paged_caching(k_cache, v_cache, key, value, slot_mapping)
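
The permuted test path relies on permute returning a strided view over the same storage: the operator writes through the view's strides, and the writes land in the original buffer that the test later compares. A small C++ sketch of that effect (sizes illustrative):

    #include <cstdio>
    #include <vector>

    int main() {
        const int num_heads = 2, block_size = 3, head_size = 4;
        std::vector<float> buf(num_heads * block_size * head_size, 0.0f);

        // Storage layout [block_size, num_heads, head_size]; the logical view
        // after permuting dims 0 and 1 is [num_heads, block_size, head_size].
        const int head_stride = head_size;             // view stride for heads
        const int slot_stride = num_heads * head_size; // view stride for slots

        // Write slot 1 of head 0 (element 2) through the view's strides...
        buf[0 * head_stride + 1 * slot_stride + 2] = 7.0f;

        // ...and it appears at storage index [slot=1][head=0][elem=2].
        std::printf("%f\n", buf[(1 * num_heads + 0) * head_size + 2]); // 7.000000
    }

This is also why the reference implementation indexes [block_idx, block_offset, :, :] for the permuted layout: the same cells, reached in storage order.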
def main():
......