Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
79cf758c
Commit
79cf758c
authored
Dec 29, 2025
by
PanZezhong
Browse files
issue/847 fix tests
parent
38078981
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
19 deletions
+15
-19
test/infinicore/ops/paged_attention.py
test/infinicore/ops/paged_attention.py
+8
-12
test/infinicore/ops/paged_caching.py
test/infinicore/ops/paged_caching.py
+6
-6
test/infiniop/paged_attention.py
test/infiniop/paged_attention.py
+1
-1
No files found.
test/infinicore/ops/paged_attention.py
View file @
79cf758c
...
...
@@ -22,10 +22,10 @@ from framework import (
_TEST_CASES_DATA
=
[
# (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_seq_len, use_alibi)
(
1
,
1
,
1
,
128
,
16
,
15
,
False
),
#
(4, 40, 40, 128, 16, 1024, False),
#
(6, 40, 40, 128, 16, 1024, False),
#
(3, 8, 8, 128, 16, 1024, False),
#
(8, 64, 8, 128, 16, 2048, False),
(
4
,
40
,
40
,
128
,
16
,
1024
,
False
),
(
6
,
40
,
40
,
128
,
16
,
1024
,
False
),
(
3
,
8
,
8
,
128
,
16
,
1024
,
False
),
(
8
,
64
,
8
,
128
,
16
,
2048
,
False
),
]
# Tolerance configuration
...
...
@@ -62,14 +62,10 @@ def parse_test_cases():
max_blocks_per_seq
=
(
max_seq_len
+
block_size
-
1
)
//
block_size
num_blocks
=
num_seqs
*
max_blocks_per_seq
# A reasonable number for testing
seq_lens_torch
=
torch
.
randint
(
1
,
1024
,
(
num_seqs
,),
dtype
=
torch
.
int32
)
# seq_lens_torch = torch.ones(
# (num_seqs,), dtype=torch.int32
# )
seq_lens_torch
=
torch
.
randint
(
1
,
max_seq_len
,
(
num_seqs
,),
dtype
=
torch
.
int64
)
block_tables
=
torch
.
arange
(
0
,
num_seqs
*
max_blocks_per_seq
,
dtype
=
torch
.
int
32
0
,
num_seqs
*
max_blocks_per_seq
,
dtype
=
torch
.
int
64
).
view
(
num_seqs
,
max_blocks_per_seq
)
print
(
"block_tables.shape"
,
block_tables
.
shape
,
block_tables
)
...
...
@@ -93,13 +89,13 @@ def parse_test_cases():
block_tables_shape
,
init_mode
=
TensorInitializer
.
MANUAL
,
set_tensor
=
block_tables
,
dtype
=
infinicore
.
int
32
,
dtype
=
infinicore
.
int
64
,
)
seq_lens_spec
=
TensorSpec
.
from_tensor
(
seq_lens_shape
,
init_mode
=
TensorInitializer
.
MANUAL
,
set_tensor
=
seq_lens_torch
,
dtype
=
infinicore
.
int
32
,
dtype
=
infinicore
.
int
64
,
)
# Paged attention operation: returns output tensor
...
...
test/infinicore/ops/paged_caching.py
View file @
79cf758c
...
...
@@ -84,7 +84,7 @@ def parse_test_cases():
# Create metadata: variable context lengths for each sequence in the batch
context_lens_torch
=
torch
.
randint
(
1
,
max_seq_len
+
1
,
(
num_seqs
,),
dtype
=
torch
.
int
32
1
,
max_seq_len
+
1
,
(
num_seqs
,),
dtype
=
torch
.
int
64
)
ntok
=
torch
.
sum
(
context_lens_torch
).
item
()
...
...
@@ -98,11 +98,11 @@ def parse_test_cases():
current_slot
+=
length
.
item
()
# Ensure we don't exceed the total number of slots in the cache
assert
current_slot
<=
num_blocks
*
block_size
,
(
"Not enough blocks in the cache pool for this test case"
)
assert
(
current_slot
<=
num_blocks
*
block_size
)
,
"Not enough blocks in the cache pool for this test case"
slot_mapping
=
torch
.
tensor
(
slot_mapping_list
,
dtype
=
torch
.
int
32
)
slot_mapping
=
torch
.
tensor
(
slot_mapping_list
,
dtype
=
torch
.
int
64
)
# print("slot_mapping", slot_mapping)
slot_mapping_shape
=
slot_mapping
.
shape
...
...
@@ -125,7 +125,7 @@ def parse_test_cases():
slot_mapping_shape
,
init_mode
=
TensorInitializer
.
MANUAL
,
set_tensor
=
slot_mapping
,
dtype
=
infinicore
.
int
32
,
dtype
=
infinicore
.
int
64
,
)
# In-place operation: modifies k_cache (index 2) and v_cache (index 3)
...
...
test/infiniop/paged_attention.py
View file @
79cf758c
...
...
@@ -148,7 +148,7 @@ def test(
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
),
None
,
dtype
,
device
)
seq_lens_torch
=
torch
.
randint
(
1
,
1024
,
(
num_seqs
,),
dtype
=
torch
.
int64
)
seq_lens_torch
=
torch
.
randint
(
1
,
max_seq_len
,
(
num_seqs
,),
dtype
=
torch
.
int64
)
seq_lens
=
TestTensor
.
from_torch
(
seq_lens_torch
,
InfiniDtype
.
I64
,
device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment