Commit 8ab073b4 in jerrrrry/infinicore

issue/1033 fix mha_varlen test

Authored Mar 05, 2026 by PanZezhong; committed by wooway777, Mar 05, 2026
Parent: f6496d44
Showing 1 changed file with 26 additions and 18 deletions.

test/infinicore/ops/mha_varlen.py (+26, -18)
 import os
 import sys
-import infinicore
 import torch
+import infinicore

 sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

 from framework import (
...
@@ -14,13 +15,17 @@ from framework import (
     TestCase,
 )

-# Test Cases: (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds)
+# Test Cases: (num_heads, num_kv_heads, head_size, block_size, [request_batch])
 _TEST_CASES_DATA = [
-    (1, 1, 1, 128, 256, 16, 1),
-    (1, 4, 4, 128, 256, 16, 4),
-    (2, 8, 8, 128, 256, 16, 2),
+    (1, 1, 128, 256, [(250,), (7,)]),
+    (4, 4, 128, 256, [(250,), (7,)]),
+    (1, 1, 128, 256, [(260, 73), (1, 1)]),
+    (8, 2, 128, 256, [(250,), (7,)]),
+    (8, 2, 128, 256, [(260, 73), (1, 1)]),
 ]

+_MAX_SEQUENCE_LENGTH = 8192
+
 _TOLERANCE_MAP = {
     infinicore.float16: {"atol": 1e-2, "rtol": 1e-2},
     infinicore.bfloat16: {"atol": 2e-2, "rtol": 2e-2},
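Each entry in the new _TEST_CASES_DATA fixes the per-round query lengths instead of sampling them at random. A minimal sketch of how one request batch such as [(260, 73), (1, 1)] unrolls into two rounds over two sequences, mirroring the accumulation the test performs below (the print is illustrative only):

import torch

# Hypothetical walk-through of one test case: 2 sequences, 2 rounds.
request_batch = [(260, 73), (1, 1)]

kv_lens = torch.zeros(2, dtype=torch.int32)        # cached tokens per sequence
for r, req in enumerate(request_batch):
    q_lens = torch.tensor(req, dtype=torch.int32)  # new query tokens this round
    kv_lens = kv_lens + q_lens                     # cache grows by the new tokens
    print(f"round {r}: q_lens={q_lens.tolist()}, kv_lens={kv_lens.tolist()}")
# round 0: q_lens=[260, 73], kv_lens=[260, 73]
# round 1: q_lens=[1, 1],    kv_lens=[261, 74]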
...
@@ -58,24 +63,24 @@ def parse_test_cases():
     test_cases = []
     for (
-        num_seqs,
         num_heads,
         num_kv_heads,
         head_size,
         block_size,
-        max_step_len,
-        num_rounds,
+        request_batches,
     ) in _TEST_CASES_DATA:
         scale = head_size**-0.5
         num_blocks = 512
         manager = SimpleCacheManager(num_blocks, block_size)
+        num_seqs = len(request_batches[0])
         kv_lens = torch.zeros(num_seqs, dtype=torch.int32)
         persistent_k = torch.zeros((num_blocks, num_kv_heads, block_size, head_size))
         persistent_v = torch.zeros((num_blocks, num_kv_heads, block_size, head_size))

-        for r in range(num_rounds):
-            q_lens = torch.randint(1, max_step_len + 1, (num_seqs,), dtype=torch.int32)
+        for r, req in enumerate(request_batches):
+            assert len(req) == num_seqs, "All requests should have the same length"
+            q_lens = torch.tensor(req, dtype=torch.int32)
             kv_lens = kv_lens + q_lens
             total_q_tokens = q_lens.sum().item()
             cum_seqlens_q = torch.zeros(num_seqs + 1, dtype=torch.int32)
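Within each round the per-sequence query tokens are packed into one ragged batch, which is what cum_seqlens_q describes. The fill of that tensor is collapsed in this diff, so the following is a sketch of the usual construction (cumulative offsets) under that assumption, not the file's literal code:

import torch

q_lens = torch.tensor([260, 73], dtype=torch.int32)   # example round from above

total_q_tokens = q_lens.sum().item()                   # 333 packed query tokens
cum_seqlens_q = torch.zeros(q_lens.numel() + 1, dtype=torch.int32)
cum_seqlens_q[1:] = torch.cumsum(q_lens, dim=0)        # [0, 260, 333]

# Sequence i owns packed rows cum_seqlens_q[i] : cum_seqlens_q[i + 1].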
...
@@ -134,12 +139,6 @@ def parse_test_cases():
                     set_tensor=padded_tables.clone(),
                     dtype=infinicore.int32,
                 ),
-                # TensorSpec.from_tensor(
-                #     kv_lens.shape,
-                #     init_mode=TensorInitializer.MANUAL,
-                #     set_tensor=kv_lens.clone(),
-                #     dtype=infinicore.int64,
-                # ),
                 TensorSpec.from_tensor(
                     cum_seqlens_q.shape,
                     init_mode=TensorInitializer.MANUAL,
...
@@ -155,8 +154,8 @@ def parse_test_cases():
                 ],
                 kwargs={
                     "scale": scale,
-                    "max_seqlen_q": max_step_len + num_rounds,
-                    "max_seqlen_k": max_step_len + num_rounds,
+                    "max_seqlen_q": _MAX_SEQUENCE_LENGTH,
+                    "max_seqlen_k": _MAX_SEQUENCE_LENGTH,
                 },
                 tolerance=tolerance,
                 description=f"MHA_Varlen_Round_{r}_{str(dtype).split('.')[-1]}",
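With the random step lengths gone, the kwargs now pass the fixed upper bound _MAX_SEQUENCE_LENGTH = 8192 for both max_seqlen_q and max_seqlen_k. If tighter per-round bounds were ever wanted, they could be derived from the current lengths; this is an illustrative alternative, not what the test does:

max_seqlen_q = int(q_lens.max().item())   # longest new-token span in this round
max_seqlen_k = int(kv_lens.max().item())  # longest cached context in this round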
...
@@ -191,6 +190,15 @@ def ref_paged_attention_multi_turn(
         K = torch.stack(keys, dim=0)
         V = torch.stack(values, dim=0)
+        q_heads = cur_q.shape[1]
+        kv_heads = K.shape[1]
+        assert q_heads % kv_heads == 0
+        group_size = q_heads // kv_heads
+        if group_size > 1:
+            K = K.repeat_interleave(group_size, dim=1)
+            V = V.repeat_interleave(group_size, dim=1)
         scores = torch.einsum("qhd,khd->hqk", cur_q.float(), K.float()) * scale
         mask = torch.full((q_len, total_len), float("-inf"), device=query.device)
         for t in range(q_len):
...
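The added block expands the KV heads so that grouped-query cases (num_heads larger than num_kv_heads, such as the new 8/2 test entries) can be checked against a plain attention reference. A self-contained sketch of that expansion plus the causal-masked score computation, assuming the same [tokens, heads, head_size] layout the test uses (shapes and names here are illustrative):

import torch

q_len, total_len, q_heads, kv_heads, head_size = 4, 7, 8, 2, 16
scale = head_size ** -0.5

cur_q = torch.randn(q_len, q_heads, head_size)    # queries for the new tokens
K = torch.randn(total_len, kv_heads, head_size)   # gathered cached keys
V = torch.randn(total_len, kv_heads, head_size)   # gathered cached values

# Grouped-query attention: replicate each KV head across its query-head group.
assert q_heads % kv_heads == 0
group_size = q_heads // kv_heads
if group_size > 1:
    K = K.repeat_interleave(group_size, dim=1)
    V = V.repeat_interleave(group_size, dim=1)

scores = torch.einsum("qhd,khd->hqk", cur_q.float(), K.float()) * scale

# Causal mask: query t sits at absolute position total_len - q_len + t and may
# only attend to cached positions up to and including itself.
mask = torch.full((q_len, total_len), float("-inf"))
for t in range(q_len):
    mask[t, : total_len - q_len + t + 1] = 0.0
scores = scores + mask

probs = torch.softmax(scores, dim=-1)
out = torch.einsum("hqk,khd->qhd", probs, V.float())  # [q_len, q_heads, head_size]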