Unverified Commit 90cb1b54 authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Merge pull request #1037 from InfiniTensor/issue/1036

Issue/1036 paged caching support strides
parents d8176086 3e1ef507
...@@ -72,8 +72,10 @@ __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor( ...@@ -72,8 +72,10 @@ __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore) CREATE(INFINI_DEVICE_MOORE, moore)
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; #undef CREATE
} }
__C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDescriptor_t desc, size_t *size) { __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDescriptor_t desc, size_t *size) {
...@@ -117,8 +119,10 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe ...@@ -117,8 +119,10 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore) GET(INFINI_DEVICE_MOORE, moore)
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; #undef GET
} }
__C infiniStatus_t infiniopCausalSoftmax( __C infiniStatus_t infiniopCausalSoftmax(
...@@ -167,8 +171,10 @@ __C infiniStatus_t infiniopCausalSoftmax( ...@@ -167,8 +171,10 @@ __C infiniStatus_t infiniopCausalSoftmax(
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore) CALCULATE(INFINI_DEVICE_MOORE, moore)
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; #undef CALCULATE
} }
__C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxDescriptor_t desc) { __C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxDescriptor_t desc) {
...@@ -212,6 +218,8 @@ __C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxD ...@@ -212,6 +218,8 @@ __C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxD
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, moore) DESTROY(INFINI_DEVICE_MOORE, moore)
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; #undef DESTROY
} }
...@@ -91,11 +91,11 @@ __C infiniStatus_t infiniopGetClipWorkspaceSize(infiniopClipDescriptor_t desc, s ...@@ -91,11 +91,11 @@ __C infiniStatus_t infiniopGetClipWorkspaceSize(infiniopClipDescriptor_t desc, s
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun) GET(INFINI_DEVICE_KUNLUN, kunlun)
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
#undef GET #undef GET
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopClip( __C infiniStatus_t infiniopClip(
......
...@@ -51,8 +51,9 @@ __C infiniStatus_t infiniopCreateLogSoftmaxDescriptor( ...@@ -51,8 +51,9 @@ __C infiniStatus_t infiniopCreateLogSoftmaxDescriptor(
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
// CREATE(INFINI_DEVICE_ASCEND, ascend) // CREATE(INFINI_DEVICE_ASCEND, ascend)
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopGetLogSoftmaxWorkspaceSize(infiniopLogSoftmaxDescriptor_t desc, size_t *size) { __C infiniStatus_t infiniopGetLogSoftmaxWorkspaceSize(infiniopLogSoftmaxDescriptor_t desc, size_t *size) {
...@@ -84,8 +85,9 @@ __C infiniStatus_t infiniopGetLogSoftmaxWorkspaceSize(infiniopLogSoftmaxDescript ...@@ -84,8 +85,9 @@ __C infiniStatus_t infiniopGetLogSoftmaxWorkspaceSize(infiniopLogSoftmaxDescript
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
// GET(INFINI_DEVICE_ASCEND, ascend) // GET(INFINI_DEVICE_ASCEND, ascend)
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopLogSoftmax( __C infiniStatus_t infiniopLogSoftmax(
...@@ -122,8 +124,9 @@ __C infiniStatus_t infiniopLogSoftmax( ...@@ -122,8 +124,9 @@ __C infiniStatus_t infiniopLogSoftmax(
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
// CALCULATE(INFINI_DEVICE_ASCEND, ascend) // CALCULATE(INFINI_DEVICE_ASCEND, ascend)
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopDestroyLogSoftmaxDescriptor(infiniopLogSoftmaxDescriptor_t desc) { __C infiniStatus_t infiniopDestroyLogSoftmaxDescriptor(infiniopLogSoftmaxDescriptor_t desc) {
...@@ -155,6 +158,7 @@ __C infiniStatus_t infiniopDestroyLogSoftmaxDescriptor(infiniopLogSoftmaxDescrip ...@@ -155,6 +158,7 @@ __C infiniStatus_t infiniopDestroyLogSoftmaxDescriptor(infiniopLogSoftmaxDescrip
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
// DESTROY(INFINI_DEVICE_ASCEND, ascend) // DESTROY(INFINI_DEVICE_ASCEND, ascend)
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
...@@ -38,7 +38,11 @@ __device__ void pagedCachingKernel( ...@@ -38,7 +38,11 @@ __device__ void pagedCachingKernel(
const ptrdiff_t k_src_stride, // Stride between tokens in the source K tensor const ptrdiff_t k_src_stride, // Stride between tokens in the source K tensor
const ptrdiff_t v_src_stride, // Stride between tokens in the source V tensor const ptrdiff_t v_src_stride, // Stride between tokens in the source V tensor
const ptrdiff_t k_cache_block_stride, // Stride between blocks in the K cache pool const ptrdiff_t k_cache_block_stride, // Stride between blocks in the K cache pool
const ptrdiff_t v_cache_block_stride // Stride between blocks in the V cache pool const ptrdiff_t v_cache_block_stride, // Stride between blocks in the V cache pool
const ptrdiff_t k_cache_head_stride, // Stride between heads in the K cache pool
const ptrdiff_t v_cache_head_stride, // Stride between heads in the V cache pool
const ptrdiff_t k_cache_slot_stride, // Stride between block slots in the K cache pool
const ptrdiff_t v_cache_slot_stride // Stride between block slots in the V cache pool
) { ) {
//================================================================================ //================================================================================
// 1. Identify Work Unit & Calculate Addresses // 1. Identify Work Unit & Calculate Addresses
...@@ -66,13 +70,11 @@ __device__ void pagedCachingKernel( ...@@ -66,13 +70,11 @@ __device__ void pagedCachingKernel(
// Destination pointer calculation assumes a [num_blocks, num_heads, block_size, head_size] layout. // Destination pointer calculation assumes a [num_blocks, num_heads, block_size, head_size] layout.
// We point to the beginning of the memory region for this token's slot. // We point to the beginning of the memory region for this token's slot.
const ptrdiff_t cache_head_stride = block_size * head_size;
Tdata *k_cache_block_base_ptr = k_cache_ptr + physical_block_idx * k_cache_block_stride; Tdata *k_cache_block_base_ptr = k_cache_ptr + physical_block_idx * k_cache_block_stride;
Tdata *k_dst_head_ptr = k_cache_block_base_ptr + head_idx * cache_head_stride + block_offset * head_size; Tdata *k_dst_head_ptr = k_cache_block_base_ptr + head_idx * k_cache_head_stride + block_offset * k_cache_slot_stride;
Tdata *v_cache_block_base_ptr = v_cache_ptr + physical_block_idx * v_cache_block_stride; Tdata *v_cache_block_base_ptr = v_cache_ptr + physical_block_idx * v_cache_block_stride;
Tdata *v_dst_head_ptr = v_cache_block_base_ptr + head_idx * cache_head_stride + block_offset * head_size; Tdata *v_dst_head_ptr = v_cache_block_base_ptr + head_idx * v_cache_head_stride + block_offset * v_cache_slot_stride;
//================================================================================ //================================================================================
// 2. Perform Element-wise Data Copy (Safe, Non-Vectorized) // 2. Perform Element-wise Data Copy (Safe, Non-Vectorized)
......
...@@ -26,6 +26,10 @@ public: ...@@ -26,6 +26,10 @@ public:
ptrdiff_t v_src_stride; ptrdiff_t v_src_stride;
ptrdiff_t k_cache_block_stride; ptrdiff_t k_cache_block_stride;
ptrdiff_t v_cache_block_stride; ptrdiff_t v_cache_block_stride;
ptrdiff_t k_cache_head_stride;
ptrdiff_t v_cache_head_stride;
ptrdiff_t k_cache_slot_stride;
ptrdiff_t v_cache_slot_stride;
static utils::Result<PagedCachingInfo> create( static utils::Result<PagedCachingInfo> create(
infiniopTensorDescriptor_t k_cache_desc, infiniopTensorDescriptor_t k_cache_desc,
...@@ -63,6 +67,10 @@ public: ...@@ -63,6 +67,10 @@ public:
ptrdiff_t v_src_stride = v_desc->stride(0); ptrdiff_t v_src_stride = v_desc->stride(0);
ptrdiff_t k_cache_block_stride = k_cache_desc->stride(0); ptrdiff_t k_cache_block_stride = k_cache_desc->stride(0);
ptrdiff_t v_cache_block_stride = v_cache_desc->stride(0); ptrdiff_t v_cache_block_stride = v_cache_desc->stride(0);
ptrdiff_t k_cache_head_stride = k_cache_desc->stride(1);
ptrdiff_t v_cache_head_stride = v_cache_desc->stride(1);
ptrdiff_t k_cache_slot_stride = k_cache_desc->stride(2);
ptrdiff_t v_cache_slot_stride = v_cache_desc->stride(2);
return utils::Result<PagedCachingInfo>(PagedCachingInfo{ return utils::Result<PagedCachingInfo>(PagedCachingInfo{
dtype, dtype,
...@@ -73,7 +81,11 @@ public: ...@@ -73,7 +81,11 @@ public:
k_src_stride, k_src_stride,
v_src_stride, v_src_stride,
k_cache_block_stride, k_cache_block_stride,
v_cache_block_stride}); v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride});
} }
}; };
......
...@@ -10,10 +10,13 @@ INFINIOP_METAX_KERNEL pagedCaching( ...@@ -10,10 +10,13 @@ INFINIOP_METAX_KERNEL pagedCaching(
const int64_t *slot_mapping, const int64_t *slot_mapping,
const size_t head_size, const size_t block_size, const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride, const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) { const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride,
const ptrdiff_t k_cache_head_stride, const ptrdiff_t v_cache_head_stride,
const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>( op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size, k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride); block_size, k_src_stride, v_src_stride,
k_cache_block_stride, v_cache_block_stride, k_cache_head_stride, v_cache_head_stride, k_cache_slot_stride, v_cache_slot_stride);
} }
namespace op::paged_caching::metax { namespace op::paged_caching::metax {
...@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size, size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride, ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride, ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
ptrdiff_t k_cache_head_stride, ptrdiff_t v_cache_head_stride,
ptrdiff_t k_cache_slot_stride, ptrdiff_t v_cache_slot_stride,
hcStream_t stream) { hcStream_t stream) {
// Grid dimension is 1D, with one block per token, as we decided. // Grid dimension is 1D, with one block per token, as we decided.
...@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride, k_src_stride,
v_src_stride, v_src_stride,
k_cache_block_stride, k_cache_block_stride,
v_cache_block_stride); v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_BF16) { } else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<cuda_bfloat16, NUM_THREADS> pagedCaching<cuda_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>( <<<grid, block, shared_mem_size, stream>>>(
...@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride, k_src_stride,
v_src_stride, v_src_stride,
k_cache_block_stride, k_cache_block_stride,
v_cache_block_stride); v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_F32) { } else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS> pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>( <<<grid, block, shared_mem_size, stream>>>(
...@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride, k_src_stride,
v_src_stride, v_src_stride,
k_cache_block_stride, k_cache_block_stride,
v_cache_block_stride); v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else { } else {
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
...@@ -138,6 +155,8 @@ infiniStatus_t Descriptor::calculate( ...@@ -138,6 +155,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size, _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride, _info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride, _info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream); stream);
} else if (max_threads >= METAX_BLOCK_SIZE_512) { } else if (max_threads >= METAX_BLOCK_SIZE_512) {
launchKernel<METAX_BLOCK_SIZE_512>( launchKernel<METAX_BLOCK_SIZE_512>(
...@@ -145,6 +164,8 @@ infiniStatus_t Descriptor::calculate( ...@@ -145,6 +164,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size, _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride, _info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride, _info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream); stream);
} else { } else {
// If the device supports fewer threads, return an error. // If the device supports fewer threads, return an error.
......
...@@ -10,10 +10,13 @@ INFINIOP_MOORE_KERNEL pagedCaching( ...@@ -10,10 +10,13 @@ INFINIOP_MOORE_KERNEL pagedCaching(
const int64_t *slot_mapping, const int64_t *slot_mapping,
const size_t head_size, const size_t block_size, const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride, const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) { const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride,
const ptrdiff_t k_cache_head_stride, const ptrdiff_t v_cache_head_stride,
const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>( op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size, k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride); block_size, k_src_stride, v_src_stride,
k_cache_block_stride, v_cache_block_stride, k_cache_head_stride, v_cache_head_stride, k_cache_slot_stride, v_cache_slot_stride);
} }
namespace op::paged_caching::moore { namespace op::paged_caching::moore {
...@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size, size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride, ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride, ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
ptrdiff_t k_cache_head_stride, ptrdiff_t v_cache_head_stride,
ptrdiff_t k_cache_slot_stride, ptrdiff_t v_cache_slot_stride,
musaStream_t stream) { musaStream_t stream) {
// Grid dimension is 1D, with one block per token, as we decided. // Grid dimension is 1D, with one block per token, as we decided.
...@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride, k_src_stride,
v_src_stride, v_src_stride,
k_cache_block_stride, k_cache_block_stride,
v_cache_block_stride); v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_BF16) { } else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<__mt_bfloat16, NUM_THREADS> pagedCaching<__mt_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>( <<<grid, block, shared_mem_size, stream>>>(
...@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride, k_src_stride,
v_src_stride, v_src_stride,
k_cache_block_stride, k_cache_block_stride,
v_cache_block_stride); v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_F32) { } else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS> pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>( <<<grid, block, shared_mem_size, stream>>>(
...@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride, k_src_stride,
v_src_stride, v_src_stride,
k_cache_block_stride, k_cache_block_stride,
v_cache_block_stride); v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else { } else {
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
...@@ -137,6 +154,8 @@ infiniStatus_t Descriptor::calculate( ...@@ -137,6 +154,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size, _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride, _info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride, _info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream); stream);
} else if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_512) { } else if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_512) {
launchKernel<MOORE_BLOCK_SIZE_512>( launchKernel<MOORE_BLOCK_SIZE_512>(
...@@ -144,6 +163,8 @@ infiniStatus_t Descriptor::calculate( ...@@ -144,6 +163,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size, _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride, _info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride, _info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream); stream);
} else { } else {
// If the GPU is older and supports fewer threads, return an error. // If the GPU is older and supports fewer threads, return an error.
......
...@@ -10,10 +10,13 @@ INFINIOP_CUDA_KERNEL pagedCaching( ...@@ -10,10 +10,13 @@ INFINIOP_CUDA_KERNEL pagedCaching(
const int64_t *slot_mapping, const int64_t *slot_mapping,
const size_t head_size, const size_t block_size, const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride, const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) { const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride,
const ptrdiff_t k_cache_head_stride, const ptrdiff_t v_cache_head_stride,
const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>( op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size, k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride); block_size, k_src_stride, v_src_stride,
k_cache_block_stride, v_cache_block_stride, k_cache_head_stride, v_cache_head_stride, k_cache_slot_stride, v_cache_slot_stride);
} }
namespace op::paged_caching::nvidia { namespace op::paged_caching::nvidia {
...@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size, size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride, ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride, ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
ptrdiff_t k_cache_head_stride, ptrdiff_t v_cache_head_stride,
ptrdiff_t k_cache_slot_stride, ptrdiff_t v_cache_slot_stride,
cudaStream_t stream) { cudaStream_t stream) {
// Grid dimension is 1D, with one block per token, as we decided. // Grid dimension is 1D, with one block per token, as we decided.
...@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride, k_src_stride,
v_src_stride, v_src_stride,
k_cache_block_stride, k_cache_block_stride,
v_cache_block_stride); v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_BF16) { } else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<__nv_bfloat16, NUM_THREADS> pagedCaching<__nv_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>( <<<grid, block, shared_mem_size, stream>>>(
...@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride, k_src_stride,
v_src_stride, v_src_stride,
k_cache_block_stride, k_cache_block_stride,
v_cache_block_stride); v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_F32) { } else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS> pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>( <<<grid, block, shared_mem_size, stream>>>(
...@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride, k_src_stride,
v_src_stride, v_src_stride,
k_cache_block_stride, k_cache_block_stride,
v_cache_block_stride); v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else { } else {
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
...@@ -137,6 +154,8 @@ infiniStatus_t Descriptor::calculate( ...@@ -137,6 +154,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size, _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride, _info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride, _info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream); stream);
} else if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_512) { } else if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_512) {
launchKernel<CUDA_BLOCK_SIZE_512>( launchKernel<CUDA_BLOCK_SIZE_512>(
...@@ -144,6 +163,8 @@ infiniStatus_t Descriptor::calculate( ...@@ -144,6 +163,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size, _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride, _info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride, _info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream); stream);
} else if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_4096) { } else if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_4096) {
launchKernel<CUDA_BLOCK_SIZE_4096>( launchKernel<CUDA_BLOCK_SIZE_4096>(
...@@ -151,6 +172,8 @@ infiniStatus_t Descriptor::calculate( ...@@ -151,6 +172,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size, _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride, _info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride, _info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream); stream);
} else { } else {
// If the GPU is older and supports fewer threads, return an error. // If the GPU is older and supports fewer threads, return an error.
......
...@@ -76,11 +76,11 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor( ...@@ -76,11 +76,11 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore); CREATE(INFINI_DEVICE_MOORE, moore);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
#undef CREATE #undef CREATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t desc, size_t *size) { __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t desc, size_t *size) {
...@@ -124,11 +124,11 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d ...@@ -124,11 +124,11 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore); GET(INFINI_DEVICE_MOORE, moore);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
#undef GET #undef GET
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *workspace, size_t workspace_size, __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *workspace, size_t workspace_size,
...@@ -173,11 +173,11 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works ...@@ -173,11 +173,11 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore); CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
#undef CALCULATE #undef CALCULATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t desc) { __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t desc) {
...@@ -221,9 +221,9 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t ...@@ -221,9 +221,9 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, moore); DESTROY(INFINI_DEVICE_MOORE, moore);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
#undef DESTROY #undef DESTROY
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
...@@ -80,11 +80,11 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor( ...@@ -80,11 +80,11 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor(
#ifdef ENABLE_CAMBRICON_API #ifdef ENABLE_CAMBRICON_API
CREATE(INFINI_DEVICE_CAMBRICON, bang); CREATE(INFINI_DEVICE_CAMBRICON, bang);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
#undef CREATE #undef CREATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc, __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
...@@ -128,11 +128,11 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc, ...@@ -128,11 +128,11 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
GET(INFINI_DEVICE_ASCEND, ascend); GET(INFINI_DEVICE_ASCEND, ascend);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
#undef GET #undef GET
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopRoPE( __C infiniStatus_t infiniopRoPE(
...@@ -185,11 +185,11 @@ __C infiniStatus_t infiniopRoPE( ...@@ -185,11 +185,11 @@ __C infiniStatus_t infiniopRoPE(
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
CALCULATE(INFINI_DEVICE_ASCEND, ascend); CALCULATE(INFINI_DEVICE_ASCEND, ascend);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
#undef CALCULATE #undef CALCULATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t __C infiniStatus_t
...@@ -234,9 +234,9 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) { ...@@ -234,9 +234,9 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) {
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
DELETE(INFINI_DEVICE_ASCEND, ascend); DELETE(INFINI_DEVICE_ASCEND, ascend);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
#undef DELETE #undef DELETE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
...@@ -24,8 +24,11 @@ __C infiniStatus_t infiniopCreateSiluAndMulDescriptor( ...@@ -24,8 +24,11 @@ __C infiniStatus_t infiniopCreateSiluAndMulDescriptor(
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore); CREATE(INFINI_DEVICE_MOORE, moore);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#undef CREATE
} }
__C infiniStatus_t infiniopGetSiluAndMulWorkspaceSize(infiniopSiluAndMulDescriptor_t desc, size_t *size) { __C infiniStatus_t infiniopGetSiluAndMulWorkspaceSize(infiniopSiluAndMulDescriptor_t desc, size_t *size) {
...@@ -39,8 +42,11 @@ __C infiniStatus_t infiniopGetSiluAndMulWorkspaceSize(infiniopSiluAndMulDescript ...@@ -39,8 +42,11 @@ __C infiniStatus_t infiniopGetSiluAndMulWorkspaceSize(infiniopSiluAndMulDescript
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore); GET(INFINI_DEVICE_MOORE, moore);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#undef GET
} }
__C infiniStatus_t infiniopSiluAndMul( __C infiniStatus_t infiniopSiluAndMul(
...@@ -59,8 +65,11 @@ __C infiniStatus_t infiniopSiluAndMul( ...@@ -59,8 +65,11 @@ __C infiniStatus_t infiniopSiluAndMul(
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore); CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#undef CALCULATE
} }
__C infiniStatus_t infiniopDestroySiluAndMulDescriptor(infiniopSiluAndMulDescriptor_t desc) { __C infiniStatus_t infiniopDestroySiluAndMulDescriptor(infiniopSiluAndMulDescriptor_t desc) {
...@@ -74,6 +83,9 @@ __C infiniStatus_t infiniopDestroySiluAndMulDescriptor(infiniopSiluAndMulDescrip ...@@ -74,6 +83,9 @@ __C infiniStatus_t infiniopDestroySiluAndMulDescriptor(infiniopSiluAndMulDescrip
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, moore); DESTROY(INFINI_DEVICE_MOORE, moore);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
#undef DESTROY
} }
...@@ -37,8 +37,9 @@ __C infiniStatus_t infiniopCreateSoftmaxDescriptor( ...@@ -37,8 +37,9 @@ __C infiniStatus_t infiniopCreateSoftmaxDescriptor(
#ifdef ENABLE_ALI_API #ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia); CREATE(INFINI_DEVICE_ALI, nvidia);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopGetSoftmaxWorkspaceSize(infiniopSoftmaxDescriptor_t desc, size_t *size) { __C infiniStatus_t infiniopGetSoftmaxWorkspaceSize(infiniopSoftmaxDescriptor_t desc, size_t *size) {
...@@ -64,8 +65,9 @@ __C infiniStatus_t infiniopGetSoftmaxWorkspaceSize(infiniopSoftmaxDescriptor_t d ...@@ -64,8 +65,9 @@ __C infiniStatus_t infiniopGetSoftmaxWorkspaceSize(infiniopSoftmaxDescriptor_t d
#ifdef ENABLE_ALI_API #ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia); GET(INFINI_DEVICE_ALI, nvidia);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopSoftmax( __C infiniStatus_t infiniopSoftmax(
...@@ -96,8 +98,9 @@ __C infiniStatus_t infiniopSoftmax( ...@@ -96,8 +98,9 @@ __C infiniStatus_t infiniopSoftmax(
#ifdef ENABLE_ALI_API #ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia); CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t desc) { __C infiniStatus_t infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t desc) {
...@@ -123,6 +126,7 @@ __C infiniStatus_t infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t ...@@ -123,6 +126,7 @@ __C infiniStatus_t infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t
#ifdef ENABLE_ALI_API #ifdef ENABLE_ALI_API
DESTROY(INFINI_DEVICE_ALI, nvidia); DESTROY(INFINI_DEVICE_ALI, nvidia);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
...@@ -42,11 +42,11 @@ __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, i ...@@ -42,11 +42,11 @@ __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, i
#ifdef ENABLE_ALI_API #ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia); CREATE(INFINI_DEVICE_ALI, nvidia);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
#undef CREATE #undef CREATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescriptor_t desc, size_t *size) { __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescriptor_t desc, size_t *size) {
...@@ -74,11 +74,11 @@ __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescript ...@@ -74,11 +74,11 @@ __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescript
#ifdef ENABLE_ALI_API #ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia); GET(INFINI_DEVICE_ALI, nvidia);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
#undef GET #undef GET
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void *workspace, size_t workspace_size, __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void *workspace, size_t workspace_size,
...@@ -109,11 +109,11 @@ __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void ...@@ -109,11 +109,11 @@ __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void
#ifdef ENABLE_ALI_API #ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia); CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
#undef CALCULATE #undef CALCULATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescriptor_t desc) { __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescriptor_t desc) {
...@@ -141,9 +141,9 @@ __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescrip ...@@ -141,9 +141,9 @@ __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescrip
#ifdef ENABLE_ALI_API #ifdef ENABLE_ALI_API
DESTROY(INFINI_DEVICE_ALI, nvidia); DESTROY(INFINI_DEVICE_ALI, nvidia);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
#undef DESTROY #undef DESTROY
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
...@@ -40,11 +40,11 @@ __C infiniStatus_t infiniopCreateTopksoftmaxDescriptor(infiniopHandle_t handle, ...@@ -40,11 +40,11 @@ __C infiniStatus_t infiniopCreateTopksoftmaxDescriptor(infiniopHandle_t handle,
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia); CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
#undef CREATE #undef CREATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopGetTopksoftmaxWorkspaceSize(infiniopTopksoftmaxDescriptor_t desc, size_t *size) { __C infiniStatus_t infiniopGetTopksoftmaxWorkspaceSize(infiniopTopksoftmaxDescriptor_t desc, size_t *size) {
...@@ -73,11 +73,11 @@ __C infiniStatus_t infiniopGetTopksoftmaxWorkspaceSize(infiniopTopksoftmaxDescri ...@@ -73,11 +73,11 @@ __C infiniStatus_t infiniopGetTopksoftmaxWorkspaceSize(infiniopTopksoftmaxDescri
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia); GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
#undef GET #undef GET
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopTopksoftmax(infiniopTopksoftmaxDescriptor_t desc, void *workspace, size_t workspace_size, __C infiniStatus_t infiniopTopksoftmax(infiniopTopksoftmaxDescriptor_t desc, void *workspace, size_t workspace_size,
...@@ -111,11 +111,11 @@ __C infiniStatus_t infiniopTopksoftmax(infiniopTopksoftmaxDescriptor_t desc, voi ...@@ -111,11 +111,11 @@ __C infiniStatus_t infiniopTopksoftmax(infiniopTopksoftmaxDescriptor_t desc, voi
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
#undef CALCULATE #undef CALCULATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopDestroyTopksoftmaxDescriptor(infiniopTopksoftmaxDescriptor_t desc) { __C infiniStatus_t infiniopDestroyTopksoftmaxDescriptor(infiniopTopksoftmaxDescriptor_t desc) {
...@@ -144,9 +144,9 @@ __C infiniStatus_t infiniopDestroyTopksoftmaxDescriptor(infiniopTopksoftmaxDescr ...@@ -144,9 +144,9 @@ __C infiniStatus_t infiniopDestroyTopksoftmaxDescriptor(infiniopTopksoftmaxDescr
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
DESTROY(INFINI_DEVICE_ILUVATAR, nvidia); DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
#endif #endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
#undef DESTROY #undef DESTROY
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
...@@ -18,12 +18,16 @@ from framework import ( ...@@ -18,12 +18,16 @@ from framework import (
# Operator-specific configuration # Operator-specific configuration
# ============================================================================== # ==============================================================================
# Test cases format: (num_seqs, max_seq_len, num_kv_heads, head_size, block_size) # Test cases format: (num_seqs, max_seq_len, num_kv_heads, head_size, block_size, permute_dim_1_2)
_TEST_CASES_DATA = [ _TEST_CASES_DATA = [
(1, 128, 8, 128, 16), (1, 128, 8, 128, 16, False),
(5, 512, 40, 128, 16), (1, 128, 8, 128, 16, True),
(16, 1024, 8, 64, 32), (5, 512, 40, 256, 16, False),
(10, 1024, 40, 64, 32), (5, 512, 40, 256, 16, True),
(16, 1024, 8, 64, 32, False),
(16, 1024, 8, 64, 32, True),
(10, 1024, 40, 64, 32, False),
(10, 1024, 40, 64, 32, True),
] ]
# Tolerance configuration # Tolerance configuration
...@@ -40,7 +44,9 @@ _TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32] ...@@ -40,7 +44,9 @@ _TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# ============================================================================== # ==============================================================================
# Reference Implementation # Reference Implementation
# ============================================================================== # ==============================================================================
def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping): def ref_paged_caching(
key_cache_pool, value_cache_pool, key, value, slot_mapping, permute_dim_1_2
):
""" """
Reference implementation for paged_caching operator. Reference implementation for paged_caching operator.
...@@ -52,7 +58,7 @@ def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping ...@@ -52,7 +58,7 @@ def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping
slot_mapping (torch.Tensor): Slot mapping, shape [ntok] slot_mapping (torch.Tensor): Slot mapping, shape [ntok]
""" """
ntok = key.shape[0] ntok = key.shape[0]
block_size = key_cache_pool.shape[2] block_size = key_cache_pool.shape[1] if permute_dim_1_2 else key_cache_pool.shape[2]
# This reference implementation operates on a cloned cache to avoid modifying the original input tensor, # This reference implementation operates on a cloned cache to avoid modifying the original input tensor,
# mimicking the behavior where the custom operator writes to its output tensor. # mimicking the behavior where the custom operator writes to its output tensor.
...@@ -67,8 +73,12 @@ def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping ...@@ -67,8 +73,12 @@ def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping
key_token = key[i] key_token = key[i]
value_token = value[i] value_token = value[i]
k_cache_ref[block_idx, :, block_offset, :] = key_token if permute_dim_1_2:
v_cache_ref[block_idx, :, block_offset, :] = value_token k_cache_ref[block_idx, block_offset, :, :] = key_token
v_cache_ref[block_idx, block_offset, :, :] = value_token
else:
k_cache_ref[block_idx, :, block_offset, :] = key_token
v_cache_ref[block_idx, :, block_offset, :] = value_token
return k_cache_ref, v_cache_ref return k_cache_ref, v_cache_ref
...@@ -79,7 +89,14 @@ def parse_test_cases(): ...@@ -79,7 +89,14 @@ def parse_test_cases():
Each test case contains all necessary information for execution and validation. Each test case contains all necessary information for execution and validation.
""" """
test_cases = [] test_cases = []
for num_seqs, max_seq_len, num_kv_heads, head_size, block_size in _TEST_CASES_DATA: for (
num_seqs,
max_seq_len,
num_kv_heads,
head_size,
block_size,
permute_dim_1_2,
) in _TEST_CASES_DATA:
num_blocks = 4096 # A reasonably large cache pool for testing num_blocks = 4096 # A reasonably large cache pool for testing
# Create metadata: variable context lengths for each sequence in the batch # Create metadata: variable context lengths for each sequence in the batch
...@@ -111,6 +128,9 @@ def parse_test_cases(): ...@@ -111,6 +128,9 @@ def parse_test_cases():
v_shape = (ntok, num_kv_heads, head_size) v_shape = (ntok, num_kv_heads, head_size)
k_cache_shape = (num_blocks, num_kv_heads, block_size, head_size) k_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
v_cache_shape = (num_blocks, num_kv_heads, block_size, head_size) v_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
if permute_dim_1_2:
k_cache_shape = (num_blocks, block_size, num_kv_heads, head_size)
v_cache_shape = (num_blocks, block_size, num_kv_heads, head_size)
# Generate test cases for all data types # Generate test cases for all data types
for dtype in _TENSOR_DTYPES: for dtype in _TENSOR_DTYPES:
...@@ -142,7 +162,7 @@ def parse_test_cases(): ...@@ -142,7 +162,7 @@ def parse_test_cases():
v_spec, v_spec,
slot_mapping_spec, slot_mapping_spec,
], ],
kwargs=None, kwargs={"permute_dim_1_2": permute_dim_1_2},
output_spec=None, output_spec=None,
comparison_target=0, # Only compare k_cache comparison_target=0, # Only compare k_cache
tolerance=tolerance, tolerance=tolerance,
...@@ -162,13 +182,22 @@ class OpTest(BaseOperatorTest): ...@@ -162,13 +182,22 @@ class OpTest(BaseOperatorTest):
def get_test_cases(self): def get_test_cases(self):
return parse_test_cases() return parse_test_cases()
def torch_operator(self, *args, **kwargs): def torch_operator(
self, k_cache, v_cache, key, value, slot_mapping, permute_dim_1_2=False
):
"""PyTorch paged_caching implementation""" """PyTorch paged_caching implementation"""
return ref_paged_caching(*args, **kwargs) return ref_paged_caching(
k_cache, v_cache, key, value, slot_mapping, permute_dim_1_2
)
def infinicore_operator(self, *args, **kwargs): def infinicore_operator(
self, k_cache, v_cache, key, value, slot_mapping, permute_dim_1_2=False
):
"""InfiniCore paged_caching implementation""" """InfiniCore paged_caching implementation"""
return infinicore.paged_caching(*args, **kwargs) if permute_dim_1_2:
k_cache = k_cache.permute([0, 2, 1, 3])
v_cache = v_cache.permute([0, 2, 1, 3])
return infinicore.paged_caching(k_cache, v_cache, key, value, slot_mapping)
def main(): def main():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment