Commit 79cf758c authored by PanZezhong
Browse files

issue/847 fix tests

parent 38078981
...@@ -22,10 +22,10 @@ from framework import ( ...@@ -22,10 +22,10 @@ from framework import (
_TEST_CASES_DATA = [ _TEST_CASES_DATA = [
# (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_seq_len, use_alibi) # (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_seq_len, use_alibi)
(1, 1, 1, 128, 16, 15, False), (1, 1, 1, 128, 16, 15, False),
# (4, 40, 40, 128, 16, 1024, False), (4, 40, 40, 128, 16, 1024, False),
# (6, 40, 40, 128, 16, 1024, False), (6, 40, 40, 128, 16, 1024, False),
# (3, 8, 8, 128, 16, 1024, False), (3, 8, 8, 128, 16, 1024, False),
# (8, 64, 8, 128, 16, 2048, False), (8, 64, 8, 128, 16, 2048, False),
] ]
# Tolerance configuration # Tolerance configuration
...@@ -62,14 +62,10 @@ def parse_test_cases(): ...@@ -62,14 +62,10 @@ def parse_test_cases():
max_blocks_per_seq = (max_seq_len + block_size - 1) // block_size max_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
num_blocks = num_seqs * max_blocks_per_seq # A reasonable number for testing num_blocks = num_seqs * max_blocks_per_seq # A reasonable number for testing
seq_lens_torch = torch.randint(1, 1024, (num_seqs,), dtype=torch.int32) seq_lens_torch = torch.randint(1, max_seq_len, (num_seqs,), dtype=torch.int64)
# seq_lens_torch = torch.ones(
# (num_seqs,), dtype=torch.int32
# )
block_tables = torch.arange( block_tables = torch.arange(
0, num_seqs * max_blocks_per_seq, dtype=torch.int32 0, num_seqs * max_blocks_per_seq, dtype=torch.int64
).view(num_seqs, max_blocks_per_seq) ).view(num_seqs, max_blocks_per_seq)
print("block_tables.shape", block_tables.shape, block_tables) print("block_tables.shape", block_tables.shape, block_tables)
...@@ -93,13 +89,13 @@ def parse_test_cases(): ...@@ -93,13 +89,13 @@ def parse_test_cases():
block_tables_shape, block_tables_shape,
init_mode=TensorInitializer.MANUAL, init_mode=TensorInitializer.MANUAL,
set_tensor=block_tables, set_tensor=block_tables,
dtype=infinicore.int32, dtype=infinicore.int64,
) )
seq_lens_spec = TensorSpec.from_tensor( seq_lens_spec = TensorSpec.from_tensor(
seq_lens_shape, seq_lens_shape,
init_mode=TensorInitializer.MANUAL, init_mode=TensorInitializer.MANUAL,
set_tensor=seq_lens_torch, set_tensor=seq_lens_torch,
dtype=infinicore.int32, dtype=infinicore.int64,
) )
# Paged attention operation: returns output tensor # Paged attention operation: returns output tensor
......
...@@ -84,7 +84,7 @@ def parse_test_cases(): ...@@ -84,7 +84,7 @@ def parse_test_cases():
# Create metadata: variable context lengths for each sequence in the batch # Create metadata: variable context lengths for each sequence in the batch
context_lens_torch = torch.randint( context_lens_torch = torch.randint(
1, max_seq_len + 1, (num_seqs,), dtype=torch.int32 1, max_seq_len + 1, (num_seqs,), dtype=torch.int64
) )
ntok = torch.sum(context_lens_torch).item() ntok = torch.sum(context_lens_torch).item()
...@@ -98,11 +98,11 @@ def parse_test_cases(): ...@@ -98,11 +98,11 @@ def parse_test_cases():
current_slot += length.item() current_slot += length.item()
# Ensure we don't exceed the total number of slots in the cache # Ensure we don't exceed the total number of slots in the cache
assert current_slot <= num_blocks * block_size, ( assert (
"Not enough blocks in the cache pool for this test case" current_slot <= num_blocks * block_size
) ), "Not enough blocks in the cache pool for this test case"
slot_mapping = torch.tensor(slot_mapping_list, dtype=torch.int32) slot_mapping = torch.tensor(slot_mapping_list, dtype=torch.int64)
# print("slot_mapping", slot_mapping) # print("slot_mapping", slot_mapping)
slot_mapping_shape = slot_mapping.shape slot_mapping_shape = slot_mapping.shape
...@@ -125,7 +125,7 @@ def parse_test_cases(): ...@@ -125,7 +125,7 @@ def parse_test_cases():
slot_mapping_shape, slot_mapping_shape,
init_mode=TensorInitializer.MANUAL, init_mode=TensorInitializer.MANUAL,
set_tensor=slot_mapping, set_tensor=slot_mapping,
dtype=infinicore.int32, dtype=infinicore.int64,
) )
# In-place operation: modifies k_cache (index 2) and v_cache (index 3) # In-place operation: modifies k_cache (index 2) and v_cache (index 3)
......
...@@ -148,7 +148,7 @@ def test( ...@@ -148,7 +148,7 @@ def test(
(num_blocks, num_kv_heads, block_size, head_size), None, dtype, device (num_blocks, num_kv_heads, block_size, head_size), None, dtype, device
) )
seq_lens_torch = torch.randint(1, 1024, (num_seqs,), dtype=torch.int64) seq_lens_torch = torch.randint(1, max_seq_len, (num_seqs,), dtype=torch.int64)
seq_lens = TestTensor.from_torch(seq_lens_torch, InfiniDtype.I64, device) seq_lens = TestTensor.from_torch(seq_lens_torch, InfiniDtype.I64, device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment