Unverified Commit 296f927f authored by Chih-Chieh Yang's avatar Chih-Chieh Yang Committed by GitHub
Browse files

[Model] RE: Mamba2 Prefill Performance Tweaks: Fixing Flurry of Unnecessary Memory Copies (#14857)


Signed-off-by: default avatarChih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
parent 0032903a
......@@ -470,10 +470,11 @@ class MambaMixer2(CustomOp):
if has_prefill:
initial_states = None
if has_initial_states is not None and any(has_initial_states):
for idx in mamba_cache_params.state_indices_tensor[
~has_initial_states]:
mamba_cache_params.ssm_state[idx].zero_()
if has_initial_states is not None and torch.any(
has_initial_states):
zero_init_indices = mamba_cache_params.state_indices_tensor[
~has_initial_states]
mamba_cache_params.ssm_state[zero_init_indices] = 0
initial_states = mamba_cache_params.ssm_state[
mamba_cache_params.state_indices_tensor]
......@@ -499,8 +500,8 @@ class MambaMixer2(CustomOp):
# update ssm states
# - varlen state is a (batch, nheads, headdim, dstate) tensor
for i, idx in enumerate(mamba_cache_params.state_indices_tensor):
mamba_cache_params.ssm_state[idx].copy_(varlen_state[i])
mamba_cache_params.ssm_state[
mamba_cache_params.state_indices_tensor] = varlen_state
# - reshape
hidden_states = scan_output.view(seq_len, -1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment