"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "2f4e6548efec402b913ffddc8726230d9311948d"
Unverified Commit 1fe46216 authored by Huamin Li's avatar Huamin Li Committed by GitHub
Browse files

[perf] Avoid dtype promotion sync in mamba_get_block_table_tensor (#34870)


Signed-off-by: default avatarHuamin Li <3ericli@gmail.com>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
parent ed31a020
......@@ -855,8 +855,12 @@ def mamba_get_block_table_tensor(
(seq_lens - 1) // kv_cache_spec.block_size,
min=0,
)
# Use int32 for arithmetic to avoid dtype promotion overhead,
# then convert to int64 for gather (which requires Long indices)
offsets = torch.arange(
1 + kv_cache_spec.num_speculative_blocks, device=block_table.device
1 + kv_cache_spec.num_speculative_blocks,
device=block_table.device,
dtype=torch.int32,
)
indices_to_gather = start_indices.unsqueeze(1) + offsets
indices_to_gather = (start_indices.unsqueeze(1) + offsets).to(torch.int64)
return torch.gather(block_table, 1, indices_to_gather)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment