Commit 4c33d673 authored by Hui Liu, committed by GitHub

[Bugfix] fix tmp_out and exp_sums dimensions (#17438)


Signed-off-by: Hui Liu <96135754+hliuca@users.noreply.github.com>
parent cb234955
@@ -289,7 +289,7 @@ def chunked_prefill_paged_decode(
         max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
                               _PARTITION_SIZE_ROCM)
         assert _PARTITION_SIZE_ROCM % block_size == 0
-        total_num_seq = query.shape[0]
+        total_num_seq = block_table.shape[0]
         tmp_output = torch.empty(
             size=(total_num_seq, num_query_heads, max_num_partitions,
                   head_size),
......
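
For illustration only, here is a minimal sketch of how the leading dimension of tmp_output follows from the hunk above. The helper name `tmp_output_shape` and all numeric values are hypothetical and not part of the commit. The sketch assumes that `block_table` has one row per sequence, whereas `query` may have one row per token, so the two leading dimensions can differ.

```python
# Hypothetical helper, not part of the commit: it mirrors the shape
# arithmetic visible in the hunk, using plain integers instead of tensors.
def tmp_output_shape(num_seqs, num_query_heads, max_seq_len, head_size,
                     partition_size):
    # Ceiling division, as in the max_num_partitions expression of the diff.
    max_num_partitions = (max_seq_len + partition_size - 1) // partition_size
    # After the fix, the first dimension is the number of sequences
    # (block_table.shape[0]) rather than query.shape[0].
    return (num_seqs, num_query_heads, max_num_partitions, head_size)


# Illustrative values only: 4 sequences, 32 query heads, longest sequence
# of 1000 tokens, head size 128, and an assumed partition size of 256.
shape = tmp_output_shape(4, 32, 1000, 128, 256)
print(shape)  # (4, 32, 4, 128)
```

With these assumed values, a 1000-token sequence spans four partitions of 256 tokens, so the buffer holds one slot per sequence, head, and partition.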