Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ce9b3cd3
Unverified
Commit
ce9b3cd3
authored
Feb 07, 2026
by
whx
Committed by
GitHub
Feb 07, 2026
Browse files
[PluggableLayer][3/N] Apply PluggableLayer to mamba layers. (#33660)
Signed-off-by:
whx-sjtu
<
2952154980@qq.com
>
parent
db4ede97
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
13 additions
and
31 deletions
+13
-31
vllm/model_executor/layers/mamba/mamba_mixer.py
vllm/model_executor/layers/mamba/mamba_mixer.py
+5
-8
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+3
-10
vllm/model_executor/models/plamo2.py
vllm/model_executor/models/plamo2.py
+5
-13
No files found.
vllm/model_executor/layers/mamba/mamba_mixer.py
View file @
ce9b3cd3
...
@@ -13,7 +13,7 @@ from vllm.distributed.parallel_state import (
...
@@ -13,7 +13,7 @@ from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
)
)
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
PluggableLayer
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
ColumnParallelLinear
,
...
@@ -41,8 +41,8 @@ from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
...
@@ -41,8 +41,8 @@ from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
# --8<-- [start:mamba_mixer]
# --8<-- [start:mamba_mixer]
@
CustomOp
.
register
(
"mamba_mixer"
)
@
PluggableLayer
.
register
(
"mamba_mixer"
)
class
MambaMixer
(
MambaBase
,
CustomOp
):
class
MambaMixer
(
MambaBase
,
PluggableLayer
):
"""
"""
Compute ∆, A, B, C, and D the state space parameters and compute
Compute ∆, A, B, C, and D the state space parameters and compute
the `contextualized_states`. A, D are input independent
the `contextualized_states`. A, D are input independent
...
@@ -230,10 +230,7 @@ class MambaMixer(MambaBase, CustomOp):
...
@@ -230,10 +230,7 @@ class MambaMixer(MambaBase, CustomOp):
self
.
prefix
,
self
.
prefix
,
)
)
def
forward_native
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
):
def
forward_impl
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
):
pass
def
forward_cuda
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
):
"""
"""
Run the Mamba-1 SSM pipeline.
Run the Mamba-1 SSM pipeline.
...
@@ -528,7 +525,7 @@ def mamba_mixer(
...
@@ -528,7 +525,7 @@ def mamba_mixer(
)
->
None
:
)
->
None
:
forward_context
:
ForwardContext
=
get_forward_context
()
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
.
forward_
cuda
(
hidden_states
=
hidden_states
,
output
=
output
)
self
.
forward_
impl
(
hidden_states
=
hidden_states
,
output
=
output
)
def
mamba_mixer_fake
(
def
mamba_mixer_fake
(
...
...
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
ce9b3cd3
...
@@ -14,7 +14,7 @@ from vllm.distributed import (
...
@@ -14,7 +14,7 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce
,
tensor_model_parallel_all_reduce
,
)
)
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
,
PluggableLayer
from
vllm.model_executor.layers.linear
import
(
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
ColumnParallelLinear
,
RowParallelLinear
,
RowParallelLinear
,
...
@@ -219,8 +219,8 @@ def mamba_v2_sharded_weight_loader(
...
@@ -219,8 +219,8 @@ def mamba_v2_sharded_weight_loader(
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
# --8<-- [start:mamba_mixer2]
# --8<-- [start:mamba_mixer2]
@
CustomOp
.
register
(
"mamba_mixer2"
)
@
PluggableLayer
.
register
(
"mamba_mixer2"
)
class
MambaMixer2
(
MambaBase
,
CustomOp
):
class
MambaMixer2
(
MambaBase
,
PluggableLayer
):
"""
"""
Compute ∆, A, B, C, and D the state space parameters and compute
Compute ∆, A, B, C, and D the state space parameters and compute
the `contextualized_states`. A, D are input independent
the `contextualized_states`. A, D are input independent
...
@@ -472,13 +472,6 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -472,13 +472,6 @@ class MambaMixer2(MambaBase, CustomOp):
# Check if running on Blackwell (SM100+) for kernel tuning
# Check if running on Blackwell (SM100+) for kernel tuning
self
.
is_blackwell
=
current_platform
.
is_device_capability_family
(
100
)
self
.
is_blackwell
=
current_platform
.
is_device_capability_family
(
100
)
def
forward_native
(
self
,
hidden_states
:
torch
.
Tensor
,
mup_vector
:
torch
.
Tensor
|
None
=
None
,
):
pass
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/plamo2.py
View file @
ce9b3cd3
...
@@ -14,7 +14,7 @@ from vllm.config import VllmConfig, get_current_vllm_config
...
@@ -14,7 +14,7 @@ from vllm.config import VllmConfig, get_current_vllm_config
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
PluggableLayer
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
...
@@ -107,8 +107,8 @@ def is_mamba(config: Plamo2Config, i: int) -> bool:
...
@@ -107,8 +107,8 @@ def is_mamba(config: Plamo2Config, i: int) -> bool:
# vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2
# vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2
# transformers.models.mamba.modeling_mamba.MambaMixer
# transformers.models.mamba.modeling_mamba.MambaMixer
# --8<-- [start:plamo2_mamba_mixer]
# --8<-- [start:plamo2_mamba_mixer]
@
CustomOp
.
register
(
"plamo2_mamba_mixer"
)
@
PluggableLayer
.
register
(
"plamo2_mamba_mixer"
)
class
Plamo2MambaMixer
(
MambaBase
,
CustomOp
):
class
Plamo2MambaMixer
(
MambaBase
,
PluggableLayer
):
# --8<-- [end:plamo2_mamba_mixer]
# --8<-- [end:plamo2_mamba_mixer]
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
*
,
prefix
:
str
=
""
,
**
kwargs
)
->
None
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
*
,
prefix
:
str
=
""
,
**
kwargs
)
->
None
:
...
@@ -233,14 +233,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
...
@@ -233,14 +233,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
dt
=
self
.
dt_proj
(
time_step
)
dt
=
self
.
dt_proj
(
time_step
)
return
B
,
C
,
dt
return
B
,
C
,
dt
def
forward_native
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
**
kwargs
,
):
pass
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
@@ -253,7 +245,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
...
@@ -253,7 +245,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
self
.
prefix
,
self
.
prefix
,
)
)
def
forward_
cuda
(
def
forward_
impl
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
...
@@ -494,7 +486,7 @@ def plamo2_mamba_mixer(
...
@@ -494,7 +486,7 @@ def plamo2_mamba_mixer(
)
->
None
:
)
->
None
:
forward_context
:
ForwardContext
=
get_forward_context
()
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
.
forward_
cuda
(
hidden_states
=
hidden_states
,
output
=
output
)
self
.
forward_
impl
(
hidden_states
=
hidden_states
,
output
=
output
)
def
plamo2_mamba_mixer_fake
(
def
plamo2_mamba_mixer_fake
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment