Unverified commit 296f927f, authored by Chih-Chieh Yang, committed by GitHub
Browse files

[Model] RE: Mamba2 Prefill Performance Tweaks: Fixing Flurry of Unnecessary Memory Copies (#14857)


Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
parent 0032903a
@@ -470,10 +470,11 @@ class MambaMixer2(CustomOp):
 if has_prefill:
     initial_states = None
-    if has_initial_states is not None and any(has_initial_states):
-        for idx in mamba_cache_params.state_indices_tensor[
-                ~has_initial_states]:
-            mamba_cache_params.ssm_state[idx].zero_()
+    if has_initial_states is not None and torch.any(
+            has_initial_states):
+        zero_init_indices = mamba_cache_params.state_indices_tensor[
+            ~has_initial_states]
+        mamba_cache_params.ssm_state[zero_init_indices] = 0
         initial_states = mamba_cache_params.ssm_state[
             mamba_cache_params.state_indices_tensor]
@@ -499,8 +500,8 @@ class MambaMixer2(CustomOp):
 # update ssm states
 # - varlen state is a (batch, nheads, headdim, dstate) tensor
-for i, idx in enumerate(mamba_cache_params.state_indices_tensor):
-    mamba_cache_params.ssm_state[idx].copy_(varlen_state[i])
+mamba_cache_params.ssm_state[
+    mamba_cache_params.state_indices_tensor] = varlen_state
 # - reshape
 hidden_states = scan_output.view(seq_len, -1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment