Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
f06d6465
Commit
f06d6465
authored
Mar 03, 2026
by
zhushuang
Browse files
issue/1041 - feat: use template to replace int64_t in paged_attention_prefill kernel with test pass
parent
abd45713
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
185 additions
and
163 deletions
+185
-163
src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh
src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh
+38
-37
src/infiniop/ops/paged_attention_prefill/info.h
src/infiniop/ops/paged_attention_prefill/info.h
+7
-3
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu
...ttention_prefill/nvidia/paged_attention_prefill_nvidia.cu
+62
-59
test/infinicore/ops/paged_attention_prefill.py
test/infinicore/ops/paged_attention_prefill.py
+51
-49
test/infiniop/paged_attention_prefill.py
test/infiniop/paged_attention_prefill.py
+27
-15
No files found.
src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh
View file @
f06d6465
...
...
@@ -16,7 +16,8 @@
namespace
op
::
paged_attention_prefill
::
cuda
{
__device__
__forceinline__
size_t
find_seq_id
(
size_t
token_idx
,
const
int64_t
*
cu_seqlens_q
,
size_t
num_seqs
)
{
template
<
typename
Tindex
>
__device__
__forceinline__
size_t
find_seq_id
(
size_t
token_idx
,
const
Tindex
*
cu_seqlens_q
,
size_t
num_seqs
)
{
size_t
low
=
0
,
high
=
(
num_seqs
==
0
)
?
0
:
(
num_seqs
-
1
);
while
(
low
<=
high
)
{
size_t
mid
=
(
low
+
high
)
>>
1
;
...
...
@@ -43,8 +44,8 @@ __device__ void PagedAttentionPrefillWarpKernel(
const
Tdata
*
k_cache_
,
const
Tdata
*
v_cache_
,
const
Tindex
*
block_tables_
,
const
in
t64_t
*
total_kv_lens_
,
const
in
t64_t
*
cu_seqlens_q_
,
const
T
in
dex
*
total_kv_lens_
,
const
T
in
dex
*
cu_seqlens_q_
,
const
float
*
alibi_slopes_
,
size_t
num_kv_heads
,
float
scale
,
...
...
@@ -73,8 +74,8 @@ __device__ void PagedAttentionPrefillWarpKernel(
const
int
seq_idx
=
static_cast
<
int
>
(
blockIdx
.
y
);
const
int
q_token_local
=
static_cast
<
int
>
(
blockIdx
.
z
);
const
in
t64_t
q_start
=
cu_seqlens_q_
[
seq_idx
];
const
in
t64_t
q_end
=
cu_seqlens_q_
[
seq_idx
+
1
];
const
T
in
dex
q_start
=
cu_seqlens_q_
[
seq_idx
];
const
T
in
dex
q_end
=
cu_seqlens_q_
[
seq_idx
+
1
];
const
int
q_len
=
static_cast
<
int
>
(
q_end
-
q_start
);
if
(
q_token_local
>=
q_len
)
{
return
;
...
...
@@ -256,8 +257,8 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel(
const
Tdata
*
k_cache_
,
const
Tdata
*
v_cache_
,
const
Tindex
*
block_tables_
,
const
in
t64_t
*
total_kv_lens_
,
const
in
t64_t
*
cu_seqlens_q_
,
const
T
in
dex
*
total_kv_lens_
,
const
T
in
dex
*
cu_seqlens_q_
,
const
float
*
alibi_slopes_
,
size_t
num_heads
,
size_t
num_seqs
,
...
...
@@ -291,9 +292,9 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel(
return
;
}
const
size_t
seq_idx
=
find_seq_id
(
global_token_idx
,
cu_seqlens_q_
,
num_seqs
);
const
in
t64_t
q_start
=
cu_seqlens_q_
[
seq_idx
];
const
in
t64_t
q_end
=
cu_seqlens_q_
[
seq_idx
+
1
];
const
size_t
seq_idx
=
find_seq_id
<
Tindex
>
(
global_token_idx
,
cu_seqlens_q_
,
num_seqs
);
const
T
in
dex
q_start
=
cu_seqlens_q_
[
seq_idx
];
const
T
in
dex
q_end
=
cu_seqlens_q_
[
seq_idx
+
1
];
const
int
q_len
=
static_cast
<
int
>
(
q_end
-
q_start
);
const
int
q_token_local
=
static_cast
<
int
>
(
global_token_idx
-
static_cast
<
size_t
>
(
q_start
));
...
...
@@ -477,8 +478,8 @@ __global__ void PagedAttentionPrefillReferenceKernel(
const
Tdata
*
k_cache_
,
const
Tdata
*
v_cache_
,
const
Tindex
*
block_tables_
,
const
in
t64_t
*
total_kv_lens_
,
const
in
t64_t
*
cu_seqlens_q_
,
const
T
in
dex
*
total_kv_lens_
,
const
T
in
dex
*
cu_seqlens_q_
,
const
float
*
alibi_slopes_
,
size_t
num_heads
,
size_t
num_kv_heads
,
...
...
@@ -506,7 +507,7 @@ __global__ void PagedAttentionPrefillReferenceKernel(
return
;
}
const
size_t
seq_idx
=
find_seq_id
(
global_token_idx
,
cu_seqlens_q_
,
num_seqs
);
const
size_t
seq_idx
=
find_seq_id
<
Tindex
>
(
global_token_idx
,
cu_seqlens_q_
,
num_seqs
);
const
size_t
q_token_idx
=
global_token_idx
-
static_cast
<
size_t
>
(
cu_seqlens_q_
[
seq_idx
]);
const
size_t
q_len
=
static_cast
<
size_t
>
(
cu_seqlens_q_
[
seq_idx
+
1
]
-
cu_seqlens_q_
[
seq_idx
]);
...
...
@@ -595,8 +596,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernel(
const
Tdata
*
k_cache_
,
const
Tdata
*
v_cache_
,
const
Tindex
*
block_tables_
,
const
in
t64_t
*
total_kv_lens_
,
const
in
t64_t
*
cu_seqlens_q_
,
const
T
in
dex
*
total_kv_lens_
,
const
T
in
dex
*
cu_seqlens_q_
,
const
float
*
alibi_slopes_
,
size_t
num_kv_heads
,
float
scale
,
...
...
@@ -632,8 +633,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernel(
const
int
seq_idx
=
static_cast
<
int
>
(
blockIdx
.
y
);
const
int
m_block
=
static_cast
<
int
>
(
blockIdx
.
z
);
const
in
t64_t
q_start
=
cu_seqlens_q_
[
seq_idx
];
const
in
t64_t
q_end
=
cu_seqlens_q_
[
seq_idx
+
1
];
const
T
in
dex
q_start
=
cu_seqlens_q_
[
seq_idx
];
const
T
in
dex
q_end
=
cu_seqlens_q_
[
seq_idx
+
1
];
const
int
q_len
=
static_cast
<
int
>
(
q_end
-
q_start
);
if
(
q_len
<=
0
)
{
return
;
...
...
@@ -865,8 +866,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelined(
const
Tdata
*
k_cache_
,
const
Tdata
*
v_cache_
,
const
Tindex
*
block_tables_
,
const
in
t64_t
*
total_kv_lens_
,
const
in
t64_t
*
cu_seqlens_q_
,
const
T
in
dex
*
total_kv_lens_
,
const
T
in
dex
*
cu_seqlens_q_
,
const
float
*
alibi_slopes_
,
size_t
num_kv_heads
,
float
scale
,
...
...
@@ -904,8 +905,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelined(
const
int
seq_idx
=
static_cast
<
int
>
(
blockIdx
.
y
);
const
int
m_block
=
static_cast
<
int
>
(
blockIdx
.
z
);
const
in
t64_t
q_start
=
cu_seqlens_q_
[
seq_idx
];
const
in
t64_t
q_end
=
cu_seqlens_q_
[
seq_idx
+
1
];
const
T
in
dex
q_start
=
cu_seqlens_q_
[
seq_idx
];
const
T
in
dex
q_end
=
cu_seqlens_q_
[
seq_idx
+
1
];
const
int
q_len
=
static_cast
<
int
>
(
q_end
-
q_start
);
if
(
q_len
<=
0
)
{
return
;
...
...
@@ -1312,8 +1313,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv(
const
Tdata
*
k_cache_
,
const
Tdata
*
v_cache_
,
const
Tindex
*
block_tables_
,
const
in
t64_t
*
total_kv_lens_
,
const
in
t64_t
*
cu_seqlens_q_
,
const
T
in
dex
*
total_kv_lens_
,
const
T
in
dex
*
cu_seqlens_q_
,
const
float
*
alibi_slopes_
,
size_t
num_kv_heads
,
float
scale
,
...
...
@@ -1350,8 +1351,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv(
const
int
head_idx
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
int
seq_idx
=
static_cast
<
int
>
(
blockIdx
.
y
);
const
in
t64_t
q_start
=
cu_seqlens_q_
[
seq_idx
];
const
in
t64_t
q_end
=
cu_seqlens_q_
[
seq_idx
+
1
];
const
T
in
dex
q_start
=
cu_seqlens_q_
[
seq_idx
];
const
T
in
dex
q_end
=
cu_seqlens_q_
[
seq_idx
+
1
];
const
int
q_len
=
static_cast
<
int
>
(
q_end
-
q_start
);
if
(
q_len
<=
0
)
{
return
;
...
...
@@ -1778,8 +1779,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly(
const
Tdata
*
k_cache_
,
const
Tdata
*
v_cache_
,
const
Tindex
*
block_tables_
,
const
in
t64_t
*
total_kv_lens_
,
const
in
t64_t
*
cu_seqlens_q_
,
const
T
in
dex
*
total_kv_lens_
,
const
T
in
dex
*
cu_seqlens_q_
,
const
float
*
alibi_slopes_
,
size_t
num_kv_heads
,
float
scale
,
...
...
@@ -1815,8 +1816,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly(
const
int
seq_idx
=
static_cast
<
int
>
(
blockIdx
.
y
);
const
int
m_block
=
static_cast
<
int
>
(
blockIdx
.
z
);
const
in
t64_t
q_start
=
cu_seqlens_q_
[
seq_idx
];
const
in
t64_t
q_end
=
cu_seqlens_q_
[
seq_idx
+
1
];
const
T
in
dex
q_start
=
cu_seqlens_q_
[
seq_idx
];
const
T
in
dex
q_end
=
cu_seqlens_q_
[
seq_idx
+
1
];
const
int
q_len
=
static_cast
<
int
>
(
q_end
-
q_start
);
if
(
q_len
<=
0
)
{
return
;
...
...
@@ -2115,12 +2116,12 @@ __device__ __forceinline__ void PagedAttentionPrefillMmaScoreUpdateRow(
}
}
template
<
int
kWarpSize
,
int
kHeadDim
,
int
kDimsPerThread
>
template
<
typename
Tindex
,
int
kWarpSize
,
int
kHeadDim
,
int
kDimsPerThread
>
__device__
__forceinline__
void
PagedAttentionPrefillMmaScoreWriteRow
(
int
lane
,
bool
active
,
int
q_token_local
,
in
t64_t
q_start
,
T
in
dex
q_start
,
int
head_idx
,
half
*
out_
,
ptrdiff_t
o_stride
,
...
...
@@ -2153,8 +2154,8 @@ __device__ void PagedAttentionPrefillWarpCta8MmaHd128Kernel(
const
half
*
k_cache_
,
const
half
*
v_cache_
,
const
Tindex
*
block_tables_
,
const
in
t64_t
*
total_kv_lens_
,
const
in
t64_t
*
cu_seqlens_q_
,
const
T
in
dex
*
total_kv_lens_
,
const
T
in
dex
*
cu_seqlens_q_
,
const
float
*
alibi_slopes_
,
size_t
num_kv_heads
,
float
scale
,
...
...
@@ -2198,8 +2199,8 @@ __device__ void PagedAttentionPrefillWarpCta8MmaHd128Kernel(
const
int
seq_idx
=
static_cast
<
int
>
(
blockIdx
.
y
);
const
int
m_block
=
static_cast
<
int
>
(
blockIdx
.
z
);
const
in
t64_t
q_start
=
cu_seqlens_q_
[
seq_idx
];
const
in
t64_t
q_end
=
cu_seqlens_q_
[
seq_idx
+
1
];
const
T
in
dex
q_start
=
cu_seqlens_q_
[
seq_idx
];
const
T
in
dex
q_end
=
cu_seqlens_q_
[
seq_idx
+
1
];
const
int
q_len
=
static_cast
<
int
>
(
q_end
-
q_start
);
if
(
q_len
<=
0
)
{
return
;
...
...
@@ -2353,11 +2354,11 @@ __device__ void PagedAttentionPrefillWarpCta8MmaHd128Kernel(
// Write outputs.
if
(
row0
<
kBlockM
)
{
PagedAttentionPrefillMmaScoreWriteRow
<
kWarpSize
,
kHeadDim
,
kDimsPerThread
>
(
PagedAttentionPrefillMmaScoreWriteRow
<
Tindex
,
kWarpSize
,
kHeadDim
,
kDimsPerThread
>
(
lane
,
active0
,
m_start
+
row0
,
q_start
,
head_idx
,
out_
,
o_stride
,
o_head_stride
,
l0
,
acc0
);
}
if
(
row1
<
kBlockM
)
{
PagedAttentionPrefillMmaScoreWriteRow
<
kWarpSize
,
kHeadDim
,
kDimsPerThread
>
(
PagedAttentionPrefillMmaScoreWriteRow
<
Tindex
,
kWarpSize
,
kHeadDim
,
kDimsPerThread
>
(
lane
,
active1
,
m_start
+
row1
,
q_start
,
head_idx
,
out_
,
o_stride
,
o_head_stride
,
l1
,
acc1
);
}
}
...
...
src/infiniop/ops/paged_attention_prefill/info.h
View file @
f06d6465
...
...
@@ -80,9 +80,13 @@ public:
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
// Keep it simple: require total_kv_lens + cum_seqlens_q to be int64 for now
// (matches current paged_attention_prefill signature). We will convert to int32 internally later.
if
(
total_kv_lens_desc
->
dtype
()
!=
INFINI_DTYPE_I64
||
cum_seqlens_q_desc
->
dtype
()
!=
INFINI_DTYPE_I64
)
{
// Index tensors use int32_t to match mainstream paged-attention implementations
// (e.g., vLLM / FlashAttention2). 32-bit indices needed, but now we also support int64_t.
if
(
!
((
total_kv_lens_desc
->
dtype
()
==
INFINI_DTYPE_I64
)
||
(
total_kv_lens_desc
->
dtype
()
==
INFINI_DTYPE_I32
)
||
(
total_kv_lens_desc
->
dtype
()
==
INFINI_DTYPE_U32
)))
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
if
(
!
((
cum_seqlens_q_desc
->
dtype
()
==
INFINI_DTYPE_I64
)
||
(
cum_seqlens_q_desc
->
dtype
()
==
INFINI_DTYPE_I32
)
||
(
cum_seqlens_q_desc
->
dtype
()
==
INFINI_DTYPE_U32
)))
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
...
...
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu
View file @
f06d6465
This diff is collapsed.
Click to expand it.
test/infinicore/ops/paged_attention_prefill.py
View file @
f06d6465
...
...
@@ -31,6 +31,8 @@ _TOLERANCE_MAP = {
_TENSOR_DTYPES
=
[
infinicore
.
float16
,
infinicore
.
bfloat16
]
_INDEX_DTYPES
=
[
infinicore
.
int32
,
infinicore
.
int64
]
class
SimpleCacheManager
:
def
__init__
(
self
,
num_blocks
,
block_size
):
...
...
@@ -72,16 +74,16 @@ def parse_test_cases():
scale
=
head_size
**-
0.5
num_blocks
=
8192
manager
=
SimpleCacheManager
(
num_blocks
,
block_size
)
kv_lens
=
torch
.
zeros
(
num_seqs
,
dtype
=
torch
.
int
64
)
kv_lens
=
torch
.
zeros
(
num_seqs
,
dtype
=
torch
.
int
32
)
persistent_k
=
torch
.
zeros
((
num_blocks
,
num_kv_heads
,
block_size
,
head_size
))
persistent_v
=
torch
.
zeros
((
num_blocks
,
num_kv_heads
,
block_size
,
head_size
))
for
r
in
range
(
num_rounds
):
q_lens
=
torch
.
randint
(
1
,
max_step_len
+
1
,
(
num_seqs
,),
dtype
=
torch
.
int
64
)
q_lens
=
torch
.
randint
(
1
,
max_step_len
+
1
,
(
num_seqs
,),
dtype
=
torch
.
int
32
)
kv_lens
=
kv_lens
+
q_lens
total_q_tokens
=
q_lens
.
sum
().
item
()
cum_seqlens_q
=
torch
.
zeros
(
num_seqs
+
1
,
dtype
=
torch
.
int
64
)
cum_seqlens_q
=
torch
.
zeros
(
num_seqs
+
1
,
dtype
=
torch
.
int
32
)
cum_seqlens_q
[
1
:]
=
torch
.
cumsum
(
q_lens
,
dim
=
0
)
query_base
=
torch
.
randn
((
total_q_tokens
,
num_heads
,
head_size
))
...
...
@@ -106,53 +108,53 @@ def parse_test_cases():
)
for
dtype
in
_TENSOR_DTYPES
:
tolerance
=
_TOLERANCE_MAP
.
get
(
dtype
)
test_cases
.
append
(
TestCase
(
inputs
=
[
TensorSpec
.
from_tensor
(
query_base
.
shape
,
init_mode
=
TensorInitializer
.
MANUAL
,
set_tensor
=
query_base
.
clone
(),
dtype
=
dtype
,
),
TensorSpec
.
from_tensor
(
persistent_k
.
shape
,
init_mode
=
TensorInitializer
.
MANUAL
,
set_tensor
=
persistent_k
.
clone
(),
dtype
=
dtype
,
),
TensorSpec
.
from_tensor
(
persistent_v
.
shape
,
init_mode
=
TensorInitializer
.
MANUAL
,
set_tensor
=
persistent_v
.
clone
(),
dtype
=
dtype
,
),
TensorSpec
.
from_tensor
(
padded_tables
.
shape
,
init_mode
=
TensorInitializer
.
MANUAL
,
set_tensor
=
padded_tables
.
clone
(),
dtype
=
infinicore
.
int64
,
),
TensorSpec
.
from_tensor
(
kv_lens
.
shape
,
init_mode
=
TensorInitializer
.
MANUAL
,
set_tensor
=
kv_lens
.
clone
(),
dtype
=
infinicore
.
int64
,
),
TensorSpec
.
from_tensor
(
cum_seqlens_q
.
shape
,
init_mode
=
TensorInitializer
.
MANUAL
,
set_tensor
=
cum_seqlens_q
.
clone
(),
dtype
=
infinicore
.
int64
,
),
],
kwargs
=
{
"scale"
:
scale
},
tolerance
=
tolerance
,
description
=
f
"PagedAttentionPrefill_Round_
{
r
}
_
{
str
(
dtype
).
split
(
'.'
)[
-
1
]
}
"
,
for
idx_dtype
in
_INDEX_DTYPES
:
# Loop through both I32 and I64
tolerance
=
_TOLERANCE_MAP
.
get
(
dtype
)
test_cases
.
append
(
TestCase
(
inputs
=
[
TensorSpec
.
from_tensor
(
query_base
.
shape
,
init_mode
=
TensorInitializer
.
MANUAL
,
set_tensor
=
query_base
.
clone
(),
dtype
=
dtype
,
),
TensorSpec
.
from_tensor
(
persistent_k
.
shape
,
init_mode
=
TensorInitializer
.
MANUAL
,
set_tensor
=
persistent_k
.
clone
(),
dtype
=
dtype
,
),
TensorSpec
.
from_tensor
(
persistent_v
.
shape
,
init_mode
=
TensorInitializer
.
MANUAL
,
set_tensor
=
persistent_v
.
clone
(),
dtype
=
dtype
,
),
TensorSpec
.
from_tensor
(
padded_tables
.
shape
,
init_mode
=
TensorInitializer
.
MANUAL
,
set_tensor
=
padded_tables
.
clone
(),
dtype
=
idx_dtype
,
),
TensorSpec
.
from_tensor
(
kv_lens
.
shape
,
init_mode
=
TensorInitializer
.
MANUAL
,
set_tensor
=
kv_lens
.
clone
(),
dtype
=
idx_dtype
,
),
TensorSpec
.
from_tensor
(
cum_seqlens_q
.
shape
,
init_mode
=
TensorInitializer
.
MANUAL
,
set_tensor
=
cum_seqlens_q
.
clone
(),
dtype
=
idx_dtype
,
),
],
kwargs
=
{
"scale"
:
scale
},
tolerance
=
tolerance
,
description
=
f
"PagedAttentionPrefill_Round_
{
r
}
_
{
str
(
dtype
).
split
(
'.'
)[
-
1
]
}
"
,
)
)
)
return
test_cases
...
...
test/infiniop/paged_attention_prefill.py
View file @
f06d6465
...
...
@@ -23,13 +23,20 @@ from libinfiniop import (
# Configuration (Internal Use Only)
# ==============================================================================
_TEST_CASES
=
[
# num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds
(
1
,
1
,
1
,
128
,
8
,
16
,
1
),
(
1
,
4
,
4
,
128
,
8
,
16
,
4
),
(
2
,
8
,
8
,
128
,
16
,
32
,
2
),
(
4
,
16
,
16
,
128
,
8
,
64
,
3
),
(
8
,
64
,
64
,
128
,
8
,
16
,
5
),
(
16
,
128
,
128
,
128
,
8
,
16
,
4
),
# num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds, index_dtypes
# index_dtype: The data type used for memory indexing of block_tables, cum_seq_lens and seq_lens
(
1
,
1
,
1
,
128
,
8
,
16
,
1
,
InfiniDtype
.
I32
),
(
1
,
1
,
1
,
128
,
8
,
16
,
1
,
InfiniDtype
.
I64
),
(
1
,
4
,
4
,
128
,
8
,
16
,
4
,
InfiniDtype
.
I32
),
(
1
,
4
,
4
,
128
,
8
,
16
,
4
,
InfiniDtype
.
I64
),
(
2
,
8
,
8
,
128
,
16
,
32
,
2
,
InfiniDtype
.
I32
),
(
2
,
8
,
8
,
128
,
16
,
32
,
2
,
InfiniDtype
.
I64
),
(
4
,
16
,
16
,
128
,
8
,
64
,
3
,
InfiniDtype
.
I32
),
(
4
,
16
,
16
,
128
,
8
,
64
,
3
,
InfiniDtype
.
I64
),
(
8
,
64
,
64
,
128
,
8
,
16
,
5
,
InfiniDtype
.
I32
),
(
8
,
64
,
64
,
128
,
8
,
16
,
5
,
InfiniDtype
.
I64
),
(
16
,
128
,
128
,
128
,
8
,
16
,
4
,
InfiniDtype
.
I32
),
(
16
,
128
,
128
,
128
,
8
,
16
,
4
,
InfiniDtype
.
I64
),
]
_TENSOR_DTYPES
=
[
InfiniDtype
.
BF16
,
InfiniDtype
.
F16
]
...
...
@@ -124,13 +131,15 @@ def test(
block_size
,
max_step_len
,
num_rounds
,
index_dtype
=
InfiniDtype
.
I64
,
dtype
=
InfiniDtype
.
F16
,
sync
=
None
,
):
print
(
f
"Testing PagedAttentionPrefill on
{
InfiniDeviceNames
[
device
]
}
with "
f
"seqs:
{
num_seqs
}
, heads:
{
num_heads
}
, head_size:
{
head_size
}
, "
f
"block:
{
block_size
}
, max_step_len:
{
max_step_len
}
, num_rounds:
{
num_rounds
}
, dtype:
{
InfiniDtypeNames
[
dtype
]
}
"
f
"block:
{
block_size
}
, max_step_len:
{
max_step_len
}
, num_rounds:
{
num_rounds
}
, dtype:
{
InfiniDtypeNames
[
dtype
]
}
, "
f
"index_dtype:
{
InfiniDtypeNames
[
index_dtype
]
}
"
)
# 1. Initialize persistent resources
...
...
@@ -194,23 +203,26 @@ def test(
out
=
TestTensor
.
from_torch
(
q_packed_tensors
,
dtype
,
device
)
out
.
actual_tensor
().
zero_
()
# 3. Referencing index_dtype to set torch dtype
torch_idx_type
=
torch
.
int32
if
index_dtype
==
InfiniDtype
.
I32
else
torch
.
int64
seq_lens
=
TestTensor
.
from_torch
(
torch
.
tensor
(
seq_lens_list
,
dtype
=
torch
.
int64
),
InfiniDtype
.
I64
,
device
torch
.
tensor
(
seq_lens_list
,
dtype
=
torch
_idx_type
),
index_dtype
,
device
)
cum_seq_lens_q
=
TestTensor
.
from_torch
(
torch
.
tensor
(
cum_seq_lens_q_list
,
dtype
=
torch
.
int64
),
InfiniDtype
.
I64
,
torch
.
tensor
(
cum_seq_lens_q_list
,
dtype
=
torch
_idx_type
),
index_dtype
,
device
,
)
max_blocks
=
max
(
len
(
t
)
for
t
in
all_block_tables
)
padded_tables
=
[
t
+
[
0
]
*
(
max_blocks
-
len
(
t
))
for
t
in
all_block_tables
]
block_tables
=
TestTensor
.
from_torch
(
torch
.
tensor
(
padded_tables
,
dtype
=
torch
.
int64
),
InfiniDtype
.
I64
,
device
torch
.
tensor
(
padded_tables
,
dtype
=
torch
_idx_type
),
index_dtype
,
device
)
#
3
. Reference Calculation
#
4
. Reference Calculation
def
torch_paged_attention_multi_turn
():
return
ref_paged_attention_multi_turn
(
q_new
.
torch_tensor
(),
...
...
@@ -224,7 +236,7 @@ def test(
ans
=
torch_paged_attention_multi_turn
()
#
4
. Infiniop Operator Execution
#
5
. Infiniop Operator Execution
descriptor
=
infiniopOperatorDescriptor_t
()
check_error
(
LIBINFINIOP
.
infiniopCreatePagedAttentionPrefillDescriptor
(
...
...
@@ -272,7 +284,7 @@ def test(
if
sync
:
sync
()
#
5
. Validation
#
6
. Validation
atol
,
rtol
=
get_tolerance
(
_TOLERANCE_MAP
,
dtype
)
if
DEBUG
:
debug
(
out
.
actual_tensor
(),
ans
,
atol
=
atol
,
rtol
=
rtol
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment