Unverified Commit 68d72319 authored by Ryan Rock's avatar Ryan Rock Committed by GitHub
Browse files

[CI/Build] Fix test_prefix_prefill for AMD (#28905)


Signed-off-by: default avatarRyan Rock <ryan.rock@amd.com>
parent 2fd893b4
...@@ -174,11 +174,11 @@ def test_contexted_kv_attention( ...@@ -174,11 +174,11 @@ def test_contexted_kv_attention(
block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request) block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
b_seq_len = torch.tensor(seq_lens, dtype=torch.int32) b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0) b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)
max_input_len = MAX_SEQ_LEN max_input_len = MAX_SEQ_LEN
# copy kv to cache # copy kv to cache
b_seq_start_loc = torch.cumsum( b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to(
torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0 torch.int32
) )
for i in range(BS): for i in range(BS):
for j in range(query_lens[i]): for j in range(query_lens[i]):
...@@ -417,11 +417,11 @@ def test_contexted_kv_attention_alibi( ...@@ -417,11 +417,11 @@ def test_contexted_kv_attention_alibi(
block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request) block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
b_seq_len = torch.tensor(seq_lens, dtype=torch.int32) b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0) b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)
max_input_len = MAX_SEQ_LEN max_input_len = MAX_SEQ_LEN
# copy kv to cache # copy kv to cache
b_seq_start_loc = torch.cumsum( b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to(
torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0 torch.int32
) )
for i in range(BS): for i in range(BS):
for j in range(query_lens[i]): for j in range(query_lens[i]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment