Unverified Commit 811ffab3 authored by spike-zhu's avatar spike-zhu Committed by GitHub
Browse files

Merge pull request #1045 from InfiniTensor/issue/1041

issue/1041 - feat: use template to replace int64_t in paged_attention_prefill kernel with test pass
parents b2f915cb f06d6465
......@@ -16,7 +16,8 @@
namespace op::paged_attention_prefill::cuda {
__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *cu_seqlens_q, size_t num_seqs) {
template <typename Tindex>
__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const Tindex *cu_seqlens_q, size_t num_seqs) {
size_t low = 0, high = (num_seqs == 0) ? 0 : (num_seqs - 1);
while (low <= high) {
size_t mid = (low + high) >> 1;
......@@ -43,8 +44,8 @@ __device__ void PagedAttentionPrefillWarpKernel(
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const Tindex *total_kv_lens_,
const Tindex *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
......@@ -73,8 +74,8 @@ __device__ void PagedAttentionPrefillWarpKernel(
const int seq_idx = static_cast<int>(blockIdx.y);
const int q_token_local = static_cast<int>(blockIdx.z);
const int64_t q_start = cu_seqlens_q_[seq_idx];
const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
const Tindex q_start = cu_seqlens_q_[seq_idx];
const Tindex q_end = cu_seqlens_q_[seq_idx + 1];
const int q_len = static_cast<int>(q_end - q_start);
if (q_token_local >= q_len) {
return;
......@@ -256,8 +257,8 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel(
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const Tindex *total_kv_lens_,
const Tindex *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_heads,
size_t num_seqs,
......@@ -291,9 +292,9 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel(
return;
}
const size_t seq_idx = find_seq_id(global_token_idx, cu_seqlens_q_, num_seqs);
const int64_t q_start = cu_seqlens_q_[seq_idx];
const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
const size_t seq_idx = find_seq_id<Tindex>(global_token_idx, cu_seqlens_q_, num_seqs);
const Tindex q_start = cu_seqlens_q_[seq_idx];
const Tindex q_end = cu_seqlens_q_[seq_idx + 1];
const int q_len = static_cast<int>(q_end - q_start);
const int q_token_local = static_cast<int>(global_token_idx - static_cast<size_t>(q_start));
......@@ -477,8 +478,8 @@ __global__ void PagedAttentionPrefillReferenceKernel(
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const Tindex *total_kv_lens_,
const Tindex *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_heads,
size_t num_kv_heads,
......@@ -506,7 +507,7 @@ __global__ void PagedAttentionPrefillReferenceKernel(
return;
}
const size_t seq_idx = find_seq_id(global_token_idx, cu_seqlens_q_, num_seqs);
const size_t seq_idx = find_seq_id<Tindex>(global_token_idx, cu_seqlens_q_, num_seqs);
const size_t q_token_idx = global_token_idx - static_cast<size_t>(cu_seqlens_q_[seq_idx]);
const size_t q_len = static_cast<size_t>(cu_seqlens_q_[seq_idx + 1] - cu_seqlens_q_[seq_idx]);
......@@ -595,8 +596,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernel(
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const Tindex *total_kv_lens_,
const Tindex *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
......@@ -632,8 +633,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernel(
const int seq_idx = static_cast<int>(blockIdx.y);
const int m_block = static_cast<int>(blockIdx.z);
const int64_t q_start = cu_seqlens_q_[seq_idx];
const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
const Tindex q_start = cu_seqlens_q_[seq_idx];
const Tindex q_end = cu_seqlens_q_[seq_idx + 1];
const int q_len = static_cast<int>(q_end - q_start);
if (q_len <= 0) {
return;
......@@ -865,8 +866,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelined(
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const Tindex *total_kv_lens_,
const Tindex *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
......@@ -904,8 +905,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelined(
const int seq_idx = static_cast<int>(blockIdx.y);
const int m_block = static_cast<int>(blockIdx.z);
const int64_t q_start = cu_seqlens_q_[seq_idx];
const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
const Tindex q_start = cu_seqlens_q_[seq_idx];
const Tindex q_end = cu_seqlens_q_[seq_idx + 1];
const int q_len = static_cast<int>(q_end - q_start);
if (q_len <= 0) {
return;
......@@ -1312,8 +1313,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv(
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const Tindex *total_kv_lens_,
const Tindex *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
......@@ -1350,8 +1351,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv(
const int head_idx = static_cast<int>(blockIdx.x);
const int seq_idx = static_cast<int>(blockIdx.y);
const int64_t q_start = cu_seqlens_q_[seq_idx];
const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
const Tindex q_start = cu_seqlens_q_[seq_idx];
const Tindex q_end = cu_seqlens_q_[seq_idx + 1];
const int q_len = static_cast<int>(q_end - q_start);
if (q_len <= 0) {
return;
......@@ -1778,8 +1779,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly(
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const Tindex *total_kv_lens_,
const Tindex *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
......@@ -1815,8 +1816,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly(
const int seq_idx = static_cast<int>(blockIdx.y);
const int m_block = static_cast<int>(blockIdx.z);
const int64_t q_start = cu_seqlens_q_[seq_idx];
const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
const Tindex q_start = cu_seqlens_q_[seq_idx];
const Tindex q_end = cu_seqlens_q_[seq_idx + 1];
const int q_len = static_cast<int>(q_end - q_start);
if (q_len <= 0) {
return;
......@@ -2115,12 +2116,12 @@ __device__ __forceinline__ void PagedAttentionPrefillMmaScoreUpdateRow(
}
}
template <int kWarpSize, int kHeadDim, int kDimsPerThread>
template <typename Tindex, int kWarpSize, int kHeadDim, int kDimsPerThread>
__device__ __forceinline__ void PagedAttentionPrefillMmaScoreWriteRow(
int lane,
bool active,
int q_token_local,
int64_t q_start,
Tindex q_start,
int head_idx,
half *out_,
ptrdiff_t o_stride,
......@@ -2153,8 +2154,8 @@ __device__ void PagedAttentionPrefillWarpCta8MmaHd128Kernel(
const half *k_cache_,
const half *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const Tindex *total_kv_lens_,
const Tindex *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
......@@ -2198,8 +2199,8 @@ __device__ void PagedAttentionPrefillWarpCta8MmaHd128Kernel(
const int seq_idx = static_cast<int>(blockIdx.y);
const int m_block = static_cast<int>(blockIdx.z);
const int64_t q_start = cu_seqlens_q_[seq_idx];
const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
const Tindex q_start = cu_seqlens_q_[seq_idx];
const Tindex q_end = cu_seqlens_q_[seq_idx + 1];
const int q_len = static_cast<int>(q_end - q_start);
if (q_len <= 0) {
return;
......@@ -2353,11 +2354,11 @@ __device__ void PagedAttentionPrefillWarpCta8MmaHd128Kernel(
// Write outputs.
if (row0 < kBlockM) {
PagedAttentionPrefillMmaScoreWriteRow<kWarpSize, kHeadDim, kDimsPerThread>(
PagedAttentionPrefillMmaScoreWriteRow<Tindex, kWarpSize, kHeadDim, kDimsPerThread>(
lane, active0, m_start + row0, q_start, head_idx, out_, o_stride, o_head_stride, l0, acc0);
}
if (row1 < kBlockM) {
PagedAttentionPrefillMmaScoreWriteRow<kWarpSize, kHeadDim, kDimsPerThread>(
PagedAttentionPrefillMmaScoreWriteRow<Tindex, kWarpSize, kHeadDim, kDimsPerThread>(
lane, active1, m_start + row1, q_start, head_idx, out_, o_stride, o_head_stride, l1, acc1);
}
}
......
......@@ -80,9 +80,13 @@ public:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
// Keep it simple: require total_kv_lens + cum_seqlens_q to be int64 for now
// (matches current paged_attention_prefill signature). We will convert to int32 internally later.
if (total_kv_lens_desc->dtype() != INFINI_DTYPE_I64 || cum_seqlens_q_desc->dtype() != INFINI_DTYPE_I64) {
// Index tensors (total_kv_lens, cum_seqlens_q) may be int32/uint32 — matching mainstream
// paged-attention implementations such as vLLM / FlashAttention-2 — or int64 for backward compatibility.
if (!((total_kv_lens_desc->dtype() == INFINI_DTYPE_I64) || (total_kv_lens_desc->dtype() == INFINI_DTYPE_I32) || (total_kv_lens_desc->dtype() == INFINI_DTYPE_U32))) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (!((cum_seqlens_q_desc->dtype() == INFINI_DTYPE_I64) || (cum_seqlens_q_desc->dtype() == INFINI_DTYPE_I32) || (cum_seqlens_q_desc->dtype() == INFINI_DTYPE_U32))) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......
......@@ -31,6 +31,8 @@ _TOLERANCE_MAP = {
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16]
_INDEX_DTYPES = [infinicore.int32, infinicore.int64]
class SimpleCacheManager:
def __init__(self, num_blocks, block_size):
......@@ -72,16 +74,16 @@ def parse_test_cases():
scale = head_size**-0.5
num_blocks = 8192
manager = SimpleCacheManager(num_blocks, block_size)
kv_lens = torch.zeros(num_seqs, dtype=torch.int64)
kv_lens = torch.zeros(num_seqs, dtype=torch.int32)
persistent_k = torch.zeros((num_blocks, num_kv_heads, block_size, head_size))
persistent_v = torch.zeros((num_blocks, num_kv_heads, block_size, head_size))
for r in range(num_rounds):
q_lens = torch.randint(1, max_step_len + 1, (num_seqs,), dtype=torch.int64)
q_lens = torch.randint(1, max_step_len + 1, (num_seqs,), dtype=torch.int32)
kv_lens = kv_lens + q_lens
total_q_tokens = q_lens.sum().item()
cum_seqlens_q = torch.zeros(num_seqs + 1, dtype=torch.int64)
cum_seqlens_q = torch.zeros(num_seqs + 1, dtype=torch.int32)
cum_seqlens_q[1:] = torch.cumsum(q_lens, dim=0)
query_base = torch.randn((total_q_tokens, num_heads, head_size))
......@@ -106,53 +108,53 @@ def parse_test_cases():
)
for dtype in _TENSOR_DTYPES:
tolerance = _TOLERANCE_MAP.get(dtype)
test_cases.append(
TestCase(
inputs=[
TensorSpec.from_tensor(
query_base.shape,
init_mode=TensorInitializer.MANUAL,
set_tensor=query_base.clone(),
dtype=dtype,
),
TensorSpec.from_tensor(
persistent_k.shape,
init_mode=TensorInitializer.MANUAL,
set_tensor=persistent_k.clone(),
dtype=dtype,
),
TensorSpec.from_tensor(
persistent_v.shape,
init_mode=TensorInitializer.MANUAL,
set_tensor=persistent_v.clone(),
dtype=dtype,
),
TensorSpec.from_tensor(
padded_tables.shape,
init_mode=TensorInitializer.MANUAL,
set_tensor=padded_tables.clone(),
dtype=infinicore.int64,
),
TensorSpec.from_tensor(
kv_lens.shape,
init_mode=TensorInitializer.MANUAL,
set_tensor=kv_lens.clone(),
dtype=infinicore.int64,
),
TensorSpec.from_tensor(
cum_seqlens_q.shape,
init_mode=TensorInitializer.MANUAL,
set_tensor=cum_seqlens_q.clone(),
dtype=infinicore.int64,
),
],
kwargs={"scale": scale},
tolerance=tolerance,
description=f"PagedAttentionPrefill_Round_{r}_{str(dtype).split('.')[-1]}",
for idx_dtype in _INDEX_DTYPES: # Loop through both I32 and I64
tolerance = _TOLERANCE_MAP.get(dtype)
test_cases.append(
TestCase(
inputs=[
TensorSpec.from_tensor(
query_base.shape,
init_mode=TensorInitializer.MANUAL,
set_tensor=query_base.clone(),
dtype=dtype,
),
TensorSpec.from_tensor(
persistent_k.shape,
init_mode=TensorInitializer.MANUAL,
set_tensor=persistent_k.clone(),
dtype=dtype,
),
TensorSpec.from_tensor(
persistent_v.shape,
init_mode=TensorInitializer.MANUAL,
set_tensor=persistent_v.clone(),
dtype=dtype,
),
TensorSpec.from_tensor(
padded_tables.shape,
init_mode=TensorInitializer.MANUAL,
set_tensor=padded_tables.clone(),
dtype=idx_dtype,
),
TensorSpec.from_tensor(
kv_lens.shape,
init_mode=TensorInitializer.MANUAL,
set_tensor=kv_lens.clone(),
dtype=idx_dtype,
),
TensorSpec.from_tensor(
cum_seqlens_q.shape,
init_mode=TensorInitializer.MANUAL,
set_tensor=cum_seqlens_q.clone(),
dtype=idx_dtype,
),
],
kwargs={"scale": scale},
tolerance=tolerance,
description=f"PagedAttentionPrefill_Round_{r}_{str(dtype).split('.')[-1]}",
)
)
)
return test_cases
......
......@@ -23,13 +23,20 @@ from libinfiniop import (
# Configuration (Internal Use Only)
# ==============================================================================
_TEST_CASES = [
# num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds
(1, 1, 1, 128, 8, 16, 1),
(1, 4, 4, 128, 8, 16, 4),
(2, 8, 8, 128, 16, 32, 2),
(4, 16, 16, 128, 8, 64, 3),
(8, 64, 64, 128, 8, 16, 5),
(16, 128, 128, 128, 8, 16, 4),
# num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds, index_dtypes
# index_dtype: The data type used for memory indexing of block_tables, cum_seq_lens and seq_lens
(1, 1, 1, 128, 8, 16, 1, InfiniDtype.I32),
(1, 1, 1, 128, 8, 16, 1, InfiniDtype.I64),
(1, 4, 4, 128, 8, 16, 4, InfiniDtype.I32),
(1, 4, 4, 128, 8, 16, 4, InfiniDtype.I64),
(2, 8, 8, 128, 16, 32, 2, InfiniDtype.I32),
(2, 8, 8, 128, 16, 32, 2, InfiniDtype.I64),
(4, 16, 16, 128, 8, 64, 3, InfiniDtype.I32),
(4, 16, 16, 128, 8, 64, 3, InfiniDtype.I64),
(8, 64, 64, 128, 8, 16, 5, InfiniDtype.I32),
(8, 64, 64, 128, 8, 16, 5, InfiniDtype.I64),
(16, 128, 128, 128, 8, 16, 4, InfiniDtype.I32),
(16, 128, 128, 128, 8, 16, 4, InfiniDtype.I64),
]
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16]
......@@ -124,13 +131,15 @@ def test(
block_size,
max_step_len,
num_rounds,
index_dtype=InfiniDtype.I64,
dtype=InfiniDtype.F16,
sync=None,
):
print(
f"Testing PagedAttentionPrefill on {InfiniDeviceNames[device]} with "
f"seqs:{num_seqs}, heads:{num_heads}, head_size:{head_size}, "
f"block:{block_size}, max_step_len:{max_step_len}, num_rounds:{num_rounds}, dtype:{InfiniDtypeNames[dtype]}"
f"block:{block_size}, max_step_len:{max_step_len}, num_rounds:{num_rounds}, dtype:{InfiniDtypeNames[dtype]}, "
f"index_dtype:{InfiniDtypeNames[index_dtype]}"
)
# 1. Initialize persistent resources
......@@ -194,23 +203,26 @@ def test(
out = TestTensor.from_torch(q_packed_tensors, dtype, device)
out.actual_tensor().zero_()
# 3. Map index_dtype to the corresponding torch integer dtype
torch_idx_type = torch.int32 if index_dtype == InfiniDtype.I32 else torch.int64
seq_lens = TestTensor.from_torch(
torch.tensor(seq_lens_list, dtype=torch.int64), InfiniDtype.I64, device
torch.tensor(seq_lens_list, dtype=torch_idx_type), index_dtype, device
)
cum_seq_lens_q = TestTensor.from_torch(
torch.tensor(cum_seq_lens_q_list, dtype=torch.int64),
InfiniDtype.I64,
torch.tensor(cum_seq_lens_q_list, dtype=torch_idx_type),
index_dtype,
device,
)
max_blocks = max(len(t) for t in all_block_tables)
padded_tables = [t + [0] * (max_blocks - len(t)) for t in all_block_tables]
block_tables = TestTensor.from_torch(
torch.tensor(padded_tables, dtype=torch.int64), InfiniDtype.I64, device
torch.tensor(padded_tables, dtype=torch_idx_type), index_dtype, device
)
# 3. Reference Calculation
# 4. Reference Calculation
def torch_paged_attention_multi_turn():
return ref_paged_attention_multi_turn(
q_new.torch_tensor(),
......@@ -224,7 +236,7 @@ def test(
ans = torch_paged_attention_multi_turn()
# 4. Infiniop Operator Execution
# 5. Infiniop Operator Execution
descriptor = infiniopOperatorDescriptor_t()
check_error(
LIBINFINIOP.infiniopCreatePagedAttentionPrefillDescriptor(
......@@ -272,7 +284,7 @@ def test(
if sync:
sync()
# 5. Validation
# 6. Validation
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug(out.actual_tensor(), ans, atol=atol, rtol=rtol)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment