Commit 67425576 authored by PanZezhong

issue/1036 paged caching support strides

parent d8176086
......@@ -38,7 +38,11 @@ __device__ void pagedCachingKernel(
const ptrdiff_t k_src_stride, // Stride between tokens in the source K tensor
const ptrdiff_t v_src_stride, // Stride between tokens in the source V tensor
const ptrdiff_t k_cache_block_stride, // Stride between blocks in the K cache pool
const ptrdiff_t v_cache_block_stride // Stride between blocks in the V cache pool
const ptrdiff_t v_cache_block_stride, // Stride between blocks in the V cache pool
const ptrdiff_t k_cache_head_stride, // Stride between heads in the K cache pool
const ptrdiff_t v_cache_head_stride, // Stride between heads in the V cache pool
const ptrdiff_t k_cache_slot_stride, // Stride between block slots in the K cache pool
const ptrdiff_t v_cache_slot_stride // Stride between block slots in the V cache pool
) {
//================================================================================
// 1. Identify Work Unit & Calculate Addresses
......@@ -66,13 +70,11 @@ __device__ void pagedCachingKernel(
// With explicit head and slot strides, the destination pointer calculation no longer hard-codes a cache layout.
// We point to the beginning of the memory region for this token's slot.
const ptrdiff_t cache_head_stride = block_size * head_size;
Tdata *k_cache_block_base_ptr = k_cache_ptr + physical_block_idx * k_cache_block_stride;
Tdata *k_dst_head_ptr = k_cache_block_base_ptr + head_idx * cache_head_stride + block_offset * head_size;
Tdata *k_dst_head_ptr = k_cache_block_base_ptr + head_idx * k_cache_head_stride + block_offset * k_cache_slot_stride;
Tdata *v_cache_block_base_ptr = v_cache_ptr + physical_block_idx * v_cache_block_stride;
Tdata *v_dst_head_ptr = v_cache_block_base_ptr + head_idx * cache_head_stride + block_offset * head_size;
Tdata *v_dst_head_ptr = v_cache_block_base_ptr + head_idx * v_cache_head_stride + block_offset * v_cache_slot_stride;
//================================================================================
// 2. Perform Element-wise Data Copy (Safe, Non-Vectorized)
......
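Taken together, these kernel hunks replace the hard-coded `cache_head_stride = block_size * head_size` and `head_size` offsets with the head and slot strides supplied by the cache descriptor, so the same addressing works whether the cache pool is laid out as `[num_blocks, num_heads, block_size, head_size]` or permuted to `[num_blocks, block_size, num_heads, head_size]`. Below is a standalone C++ sketch of that offset arithmetic; the `cache_offset` helper and the concrete sizes are illustrative only, not part of the kernel:

```cpp
#include <cstddef>
#include <cstdio>

// Illustrative only: element offset of a (head, slot) pair inside one cache block,
// computed from explicit strides instead of a hard-coded layout (mirrors the kernel's
// k_dst_head_ptr / v_dst_head_ptr arithmetic).
std::ptrdiff_t cache_offset(std::ptrdiff_t head_idx, std::ptrdiff_t block_offset,
                            std::ptrdiff_t head_stride, std::ptrdiff_t slot_stride) {
    return head_idx * head_stride + block_offset * slot_stride;
}

int main() {
    const std::ptrdiff_t num_heads = 8, block_size = 16, head_size = 128;

    // Layout [num_blocks, num_heads, block_size, head_size]:
    //   head_stride = block_size * head_size, slot_stride = head_size
    std::ptrdiff_t off_a = cache_offset(/*head_idx=*/3, /*block_offset=*/5,
                                        block_size * head_size, head_size);

    // Permuted layout [num_blocks, block_size, num_heads, head_size]:
    //   head_stride = head_size, slot_stride = num_heads * head_size
    std::ptrdiff_t off_b = cache_offset(3, 5, head_size, num_heads * head_size);

    std::printf("offset in layout A: %td, offset in layout B: %td\n", off_a, off_b);
    return 0;
}
```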
......@@ -26,6 +26,10 @@ public:
ptrdiff_t v_src_stride;
ptrdiff_t k_cache_block_stride;
ptrdiff_t v_cache_block_stride;
ptrdiff_t k_cache_head_stride;
ptrdiff_t v_cache_head_stride;
ptrdiff_t k_cache_slot_stride;
ptrdiff_t v_cache_slot_stride;
static utils::Result<PagedCachingInfo> create(
infiniopTensorDescriptor_t k_cache_desc,
......@@ -63,6 +67,10 @@ public:
ptrdiff_t v_src_stride = v_desc->stride(0);
ptrdiff_t k_cache_block_stride = k_cache_desc->stride(0);
ptrdiff_t v_cache_block_stride = v_cache_desc->stride(0);
ptrdiff_t k_cache_head_stride = k_cache_desc->stride(1);
ptrdiff_t v_cache_head_stride = v_cache_desc->stride(1);
ptrdiff_t k_cache_slot_stride = k_cache_desc->stride(2);
ptrdiff_t v_cache_slot_stride = v_cache_desc->stride(2);
return utils::Result<PagedCachingInfo>(PagedCachingInfo{
dtype,
......@@ -73,7 +81,11 @@ public:
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride});
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride});
}
};
......
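For reference, the strides that `PagedCachingInfo::create` now captures from `stride(0)`, `stride(1)`, and `stride(2)` of the cache descriptors take the following values in the two layouts the tests exercise. This is a standalone illustration with made-up sizes, not part of `PagedCachingInfo`:

```cpp
#include <array>
#include <cstddef>
#include <cstdio>

// Illustrative sketch: what k_cache_desc->stride(0..2) evaluates to for the two
// cache layouts the stride-aware kernel supports.
int main() {
    const std::ptrdiff_t num_heads = 8, block_size = 16, head_size = 128;

    // Contiguous pool [num_blocks, num_heads, block_size, head_size]:
    std::array<std::ptrdiff_t, 3> contiguous = {
        num_heads * block_size * head_size, // stride(0): between blocks
        block_size * head_size,             // stride(1): between heads
        head_size,                          // stride(2): between block slots
    };

    // Pool stored as [num_blocks, block_size, num_heads, head_size] and exposed through
    // a (0, 2, 1, 3) permuted view, so the logical dim order matches the contiguous case:
    std::array<std::ptrdiff_t, 3> permuted = {
        block_size * num_heads * head_size, // stride(0): between blocks
        head_size,                          // stride(1): between heads
        num_heads * head_size,              // stride(2): between block slots
    };

    std::printf("contiguous: %td %td %td\n", contiguous[0], contiguous[1], contiguous[2]);
    std::printf("permuted:   %td %td %td\n", permuted[0], permuted[1], permuted[2]);
    return 0;
}
```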
......@@ -10,10 +10,13 @@ INFINIOP_METAX_KERNEL pagedCaching(
const int64_t *slot_mapping,
const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride,
const ptrdiff_t k_cache_head_stride, const ptrdiff_t v_cache_head_stride,
const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
block_size, k_src_stride, v_src_stride,
k_cache_block_stride, v_cache_block_stride, k_cache_head_stride, v_cache_head_stride, k_cache_slot_stride, v_cache_slot_stride);
}
namespace op::paged_caching::metax {
......@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
ptrdiff_t k_cache_head_stride, ptrdiff_t v_cache_head_stride,
ptrdiff_t k_cache_slot_stride, ptrdiff_t v_cache_slot_stride,
hcStream_t stream) {
// Grid dimension is 1D, with one block per token, as we decided.
......@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<cuda_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
......@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
......@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......@@ -138,6 +155,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream);
} else if (max_threads >= METAX_BLOCK_SIZE_512) {
launchKernel<METAX_BLOCK_SIZE_512>(
......@@ -145,6 +164,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream);
} else {
// If the device supports fewer threads, return an error.
......
......@@ -10,10 +10,13 @@ INFINIOP_MOORE_KERNEL pagedCaching(
const int64_t *slot_mapping,
const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride,
const ptrdiff_t k_cache_head_stride, const ptrdiff_t v_cache_head_stride,
const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
block_size, k_src_stride, v_src_stride,
k_cache_block_stride, v_cache_block_stride, k_cache_head_stride, v_cache_head_stride, k_cache_slot_stride, v_cache_slot_stride);
}
namespace op::paged_caching::moore {
......@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
ptrdiff_t k_cache_head_stride, ptrdiff_t v_cache_head_stride,
ptrdiff_t k_cache_slot_stride, ptrdiff_t v_cache_slot_stride,
musaStream_t stream) {
// Grid dimension is 1D, with one block per token, as we decided.
......@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<__mt_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
......@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
......@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......@@ -137,6 +154,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream);
} else if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_512) {
launchKernel<MOORE_BLOCK_SIZE_512>(
......@@ -144,6 +163,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream);
} else {
// If the GPU is older and supports fewer threads, return an error.
......
......@@ -10,10 +10,13 @@ INFINIOP_CUDA_KERNEL pagedCaching(
const int64_t *slot_mapping,
const size_t head_size, const size_t block_size,
const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride,
const ptrdiff_t k_cache_head_stride, const ptrdiff_t v_cache_head_stride,
const ptrdiff_t k_cache_slot_stride, const ptrdiff_t v_cache_slot_stride) {
op::paged_caching::cuda::pagedCachingKernel<Tdata, NUM_THREADS>(
k_cache, v_cache, k, v, slot_mapping, head_size,
block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
block_size, k_src_stride, v_src_stride,
k_cache_block_stride, v_cache_block_stride, k_cache_head_stride, v_cache_head_stride, k_cache_slot_stride, v_cache_slot_stride);
}
namespace op::paged_caching::nvidia {
......@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
ptrdiff_t k_cache_head_stride, ptrdiff_t v_cache_head_stride,
ptrdiff_t k_cache_slot_stride, ptrdiff_t v_cache_slot_stride,
cudaStream_t stream) {
// Grid dimension is 1D, with one block per token, as we decided.
......@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
pagedCaching<__nv_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
......@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_F32) {
pagedCaching<float, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
......@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_src_stride,
v_src_stride,
k_cache_block_stride,
v_cache_block_stride);
v_cache_block_stride,
k_cache_head_stride,
v_cache_head_stride,
k_cache_slot_stride,
v_cache_slot_stride);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......@@ -137,6 +154,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream);
} else if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_512) {
launchKernel<CUDA_BLOCK_SIZE_512>(
......@@ -144,6 +163,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream);
} else if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_4096) {
launchKernel<CUDA_BLOCK_SIZE_4096>(
......@@ -151,6 +172,8 @@ infiniStatus_t Descriptor::calculate(
_info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
_info.k_src_stride, _info.v_src_stride,
_info.k_cache_block_stride, _info.v_cache_block_stride,
_info.k_cache_head_stride, _info.v_cache_head_stride,
_info.k_cache_slot_stride, _info.v_cache_slot_stride,
stream);
} else {
// If the GPU is older and supports fewer threads, return an error.
......
......@@ -18,12 +18,16 @@ from framework import (
# Operator-specific configuration
# ==============================================================================
# Test cases format: (num_seqs, max_seq_len, num_kv_heads, head_size, block_size)
# Test cases format: (num_seqs, max_seq_len, num_kv_heads, head_size, block_size, permute_dim_1_2)
_TEST_CASES_DATA = [
(1, 128, 8, 128, 16),
(5, 512, 40, 128, 16),
(16, 1024, 8, 64, 32),
(10, 1024, 40, 64, 32),
(1, 128, 8, 128, 16, False),
(1, 128, 8, 128, 16, True),
(5, 512, 40, 256, 16, False),
(5, 512, 40, 256, 16, True),
(16, 1024, 8, 64, 32, False),
(16, 1024, 8, 64, 32, True),
(10, 1024, 40, 64, 32, False),
(10, 1024, 40, 64, 32, True),
]
# Tolerance configuration
......@@ -40,7 +44,9 @@ _TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# ==============================================================================
# Reference Implementation
# ==============================================================================
def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping):
def ref_paged_caching(
key_cache_pool, value_cache_pool, key, value, slot_mapping, permute_dim_1_2
):
"""
Reference implementation for paged_caching operator.
......@@ -52,7 +58,7 @@ def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping
slot_mapping (torch.Tensor): Slot mapping, shape [ntok]
"""
ntok = key.shape[0]
block_size = key_cache_pool.shape[2]
block_size = key_cache_pool.shape[1] if permute_dim_1_2 else key_cache_pool.shape[2]
# This reference implementation operates on a cloned cache to avoid modifying the original input tensor,
# mimicking the behavior where the custom operator writes to its output tensor.
......@@ -67,8 +73,12 @@ def ref_paged_caching(key_cache_pool, value_cache_pool, key, value, slot_mapping
key_token = key[i]
value_token = value[i]
k_cache_ref[block_idx, :, block_offset, :] = key_token
v_cache_ref[block_idx, :, block_offset, :] = value_token
if permute_dim_1_2:
k_cache_ref[block_idx, block_offset, :, :] = key_token
v_cache_ref[block_idx, block_offset, :, :] = value_token
else:
k_cache_ref[block_idx, :, block_offset, :] = key_token
v_cache_ref[block_idx, :, block_offset, :] = value_token
return k_cache_ref, v_cache_ref
......@@ -79,7 +89,14 @@ def parse_test_cases():
Each test case contains all necessary information for execution and validation.
"""
test_cases = []
for num_seqs, max_seq_len, num_kv_heads, head_size, block_size in _TEST_CASES_DATA:
for (
num_seqs,
max_seq_len,
num_kv_heads,
head_size,
block_size,
permute_dim_1_2,
) in _TEST_CASES_DATA:
num_blocks = 4096 # A reasonably large cache pool for testing
# Create metadata: variable context lengths for each sequence in the batch
......@@ -111,6 +128,9 @@ def parse_test_cases():
v_shape = (ntok, num_kv_heads, head_size)
k_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
v_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)
if permute_dim_1_2:
k_cache_shape = (num_blocks, block_size, num_kv_heads, head_size)
v_cache_shape = (num_blocks, block_size, num_kv_heads, head_size)
# Generate test cases for all data types
for dtype in _TENSOR_DTYPES:
......@@ -142,7 +162,7 @@ def parse_test_cases():
v_spec,
slot_mapping_spec,
],
kwargs=None,
kwargs={"permute_dim_1_2": permute_dim_1_2},
output_spec=None,
comparison_target=0, # Only compare k_cache
tolerance=tolerance,
......@@ -162,13 +182,22 @@ class OpTest(BaseOperatorTest):
def get_test_cases(self):
return parse_test_cases()
def torch_operator(self, *args, **kwargs):
def torch_operator(
self, k_cache, v_cache, key, value, slot_mapping, permute_dim_1_2=False
):
"""PyTorch paged_caching implementation"""
return ref_paged_caching(*args, **kwargs)
return ref_paged_caching(
k_cache, v_cache, key, value, slot_mapping, permute_dim_1_2
)
def infinicore_operator(self, *args, **kwargs):
def infinicore_operator(
self, k_cache, v_cache, key, value, slot_mapping, permute_dim_1_2=False
):
"""InfiniCore paged_caching implementation"""
return infinicore.paged_caching(*args, **kwargs)
if permute_dim_1_2:
k_cache = k_cache.permute([0, 2, 1, 3])
v_cache = v_cache.permute([0, 2, 1, 3])
return infinicore.paged_caching(k_cache, v_cache, key, value, slot_mapping)
def main():
......
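The test drives the new stride parameters by permuting the cache view before calling `infinicore.paged_caching`. A small standalone PyTorch check (sizes are arbitrary, independent of the test framework) shows why this exercises non-trivial head and slot strides: the permuted view keeps the logical `[num_blocks, num_kv_heads, block_size, head_size]` dim order but carries the strides of the underlying `[num_blocks, block_size, num_kv_heads, head_size]` storage.

```python
import torch

# Arbitrary illustrative sizes.
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 128

# Contiguous storage in the permuted layout used by the new test cases.
pool = torch.empty(num_blocks, block_size, num_kv_heads, head_size)

# View with the logical [num_blocks, num_kv_heads, block_size, head_size] dim order.
view = pool.permute(0, 2, 1, 3)

print(view.shape)            # torch.Size([4, 8, 16, 128])
print(view.is_contiguous())  # False
print(view.stride())         # (16384, 128, 1024, 1): block, head, slot, element strides
```

Here `stride(1) = head_size` and `stride(2) = num_kv_heads * head_size`, which is exactly the permuted case covered by the new `k_cache_head_stride` / `k_cache_slot_stride` parameters.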