Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ccf02fcb
Unverified
Commit
ccf02fcb
authored
Mar 14, 2025
by
Tyler Michael Smith
Committed by
GitHub
Mar 14, 2025
Browse files
Revert "[Model] Mamba2 Prefill Performance Tweaks: Fixing Flurry of U… (#14848)
parent
acaea3bb
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
22 deletions
+8
-22
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+8
-22
No files found.
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
ccf02fcb
...
@@ -466,17 +466,10 @@ class MambaMixer2(CustomOp):
...
@@ -466,17 +466,10 @@ class MambaMixer2(CustomOp):
if
has_prefill
:
if
has_prefill
:
initial_states
=
None
initial_states
=
None
if
has_initial_states
is
not
None
and
any
(
has_initial_states
):
if
has_initial_states
is
not
None
and
torch
.
any
(
for
idx
in
mamba_cache_params
.
state_indices_tensor
[
has_initial_states
):
~
has_initial_states
]:
mamba_cache_params
.
ssm_state
[
idx
].
zero_
()
# vectorized ssm_state zero init
batched_zero_init_func
=
torch
.
vmap
(
lambda
idx
:
mamba_cache_params
.
ssm_state
[
idx
].
zero_
())
batched_zero_init_func
(
mamba_cache_params
.
state_indices_tensor
[
~
has_initial_states
].
unsqueeze
(
dim
=-
1
),
)
initial_states
=
mamba_cache_params
.
ssm_state
[
initial_states
=
mamba_cache_params
.
ssm_state
[
mamba_cache_params
.
state_indices_tensor
]
mamba_cache_params
.
state_indices_tensor
]
...
@@ -500,17 +493,10 @@ class MambaMixer2(CustomOp):
...
@@ -500,17 +493,10 @@ class MambaMixer2(CustomOp):
dt_limit
=
(
0.0
,
float
(
"inf"
)),
dt_limit
=
(
0.0
,
float
(
"inf"
)),
)
)
# vectorized ssm state update using vmap
# update ssm states
# the 1d state_indices_tensor needs to be unsqueezed to avoid vmap
# - varlen state is a (batch, nheads, headdim, dstate) tensor
# limitation which doesn't allow use of `item()`
for
i
,
idx
in
enumerate
(
mamba_cache_params
.
state_indices_tensor
):
# Note: the lambda capture can happen where ssm_state is initialized
mamba_cache_params
.
ssm_state
[
idx
].
copy_
(
varlen_state
[
i
])
# instead of here
batched_copy
=
torch
.
vmap
(
lambda
idx
,
source_state
:
mamba_cache_params
.
ssm_state
[
idx
].
copy_
(
source_state
))
batched_copy
(
mamba_cache_params
.
state_indices_tensor
.
unsqueeze
(
dim
=-
1
),
varlen_state
)
# - reshape
# - reshape
hidden_states
=
scan_output
.
view
(
seq_len
,
-
1
)
hidden_states
=
scan_output
.
view
(
seq_len
,
-
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment