Unverified Commit eb58f595 authored by Chengji Yao's avatar Chengji Yao Committed by GitHub
Browse files

[TPU][Bugfix] fix test_pallas (#20666)


Signed-off-by: default avatarChengji Yao <chengjiyao@google.com>
parent 4ac9c33f
...@@ -50,6 +50,7 @@ def test_ragged_paged_attention(): ...@@ -50,6 +50,7 @@ def test_ragged_paged_attention():
slot_mapping = torch.zeros((3, num_tokens), dtype=torch.int64) slot_mapping = torch.zeros((3, num_tokens), dtype=torch.int64)
max_num_reqs = 8 max_num_reqs = 8
max_num_blocks_per_req = 8 max_num_blocks_per_req = 8
num_kv_update_slices = torch.tensor([num_tokens], dtype=torch.int32)
block_tables = torch.zeros((max_num_reqs, max_num_blocks_per_req), block_tables = torch.zeros((max_num_reqs, max_num_blocks_per_req),
dtype=torch.int32) dtype=torch.int32)
context_lens = torch.ones((max_num_reqs, ), dtype=torch.int32) context_lens = torch.ones((max_num_reqs, ), dtype=torch.int32)
...@@ -65,6 +66,7 @@ def test_ragged_paged_attention(): ...@@ -65,6 +66,7 @@ def test_ragged_paged_attention():
context_lens=context_lens, context_lens=context_lens,
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
num_seqs=num_seqs, num_seqs=num_seqs,
num_kv_update_slices=num_kv_update_slices,
num_slices_per_kv_cache_update_block=8, num_slices_per_kv_cache_update_block=8,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment