Unverified commit 4e6c4923, authored by Binyao Jiang and committed by GitHub
Browse files

[Performance] Qwen3-Next: speed up update_mamba_state_after_mtp_verify by 10x;...

[Performance] Qwen3-Next: speed up update_mamba_state_after_mtp_verify by 10x; e2e up to 3.54% faster (#10586)
parent b91cb67e
......@@ -583,36 +583,15 @@ class HybridLinearAttnBackend(AttentionBackend):
# Compute common indices once to avoid duplication
last_steps_all = (accepted_length - 1).to(torch.int64)
valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)
last_steps = last_steps_all[valid_mask].to(torch.int64)
if valid_state_indices.numel() > 0:
chunk = 256
num_valid = valid_state_indices.numel()
# SSM state updates
for i in range(0, num_valid, chunk):
idx = valid_state_indices[i : i + chunk]
steps = last_steps[i : i + chunk]
# per (cache line, step)
for j in range(idx.numel()):
ci = idx[j].item()
st = steps[j].item()
ssm_states[:, ci, :].copy_(
intermediate_state_cache[:, ci, st].to(
ssm_states.dtype, copy=False
)
)
# Conv window updates
for i in range(0, num_valid, chunk):
idx = valid_state_indices[i : i + chunk]
steps = last_steps[i : i + chunk]
for j in range(idx.numel()):
ci = idx[j].item()
st = steps[j].item()
conv_states[:, ci, :, :].copy_(
intermediate_conv_window_cache[:, ci, st].to(
conv_states.dtype, copy=False
)
)
valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64) # [N]
last_steps = last_steps_all[valid_mask].to(torch.int64) # [N]
# scatter into ssm_states at the chosen cache lines
ssm_states[:, valid_state_indices, :] = intermediate_state_cache[
:, valid_state_indices, last_steps
].to(ssm_states.dtype, copy=False)
# Scatter into conv_states at the chosen cache lines
conv_states[:, valid_state_indices, :, :] = intermediate_conv_window_cache[
:, valid_state_indices, last_steps
].to(conv_states.dtype, copy=False)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment