Commit 1fe46216, authored by Huamin Li, committed by GitHub

[perf] Avoid dtype promotion sync in mamba_get_block_table_tensor (#34870)


Signed-off-by: Huamin Li <3ericli@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
parent ed31a020
@@ -855,8 +855,12 @@ def mamba_get_block_table_tensor(
         (seq_lens - 1) // kv_cache_spec.block_size,
         min=0,
     )
+    # Use int32 for arithmetic to avoid dtype promotion overhead,
+    # then convert to int64 for gather (which requires Long indices)
     offsets = torch.arange(
-        1 + kv_cache_spec.num_speculative_blocks, device=block_table.device
+        1 + kv_cache_spec.num_speculative_blocks,
+        device=block_table.device,
+        dtype=torch.int32,
     )
-    indices_to_gather = start_indices.unsqueeze(1) + offsets
+    indices_to_gather = (start_indices.unsqueeze(1) + offsets).to(torch.int64)
     return torch.gather(block_table, 1, indices_to_gather)
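The pattern in this hunk can be tried in isolation. The following is a minimal sketch, not the vLLM function itself: `gather_block_ids`, its parameters, and the toy tensors are hypothetical, and it assumes `start_indices` is an int32 tensor, which the diff suggests but does not show. It illustrates that adding an int32 tensor to a default `torch.arange` (int64) would promote the result to int64, whereas keeping both operands in int32 and casting once before `torch.gather` yields the same indices.

```python
import torch


def gather_block_ids(block_table, start_indices, num_extra_blocks):
    """Hypothetical stand-in for the tail of mamba_get_block_table_tensor."""
    # Build offsets in int32 so the addition below stays in int32
    # (assumes start_indices is int32; this is an assumption, not shown in the diff).
    offsets = torch.arange(
        1 + num_extra_blocks,
        device=block_table.device,
        dtype=torch.int32,
    )
    # torch.gather requires int64 (Long) indices, so cast once at the end.
    indices = (start_indices.unsqueeze(1) + offsets).to(torch.int64)
    return torch.gather(block_table, 1, indices)


# Toy data: 3 requests, 4 block slots each.
block_table = torch.arange(12, dtype=torch.int32).reshape(3, 4)
start_indices = torch.tensor([0, 1, 2], dtype=torch.int32)

# With a default arange (int64), the sum would be promoted to int64.
promoted = torch.result_type(start_indices, torch.arange(2))

result = gather_block_ids(block_table, start_indices, num_extra_blocks=1)
print(promoted)         # torch.int64
print(result.tolist())  # [[0, 1], [5, 6], [10, 11]]
```

The gathered values are the same either way; the change only affects which dtype the intermediate addition runs in.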