Unverified commit ce9b3cd3, authored by whx and committed by GitHub
Browse files

[PluggableLayer][3/N] Apply PluggableLayer to mamba layers. (#33660)


Signed-off-by: whx-sjtu <2952154980@qq.com>
parent db4ede97
...@@ -13,7 +13,7 @@ from vllm.distributed.parallel_state import ( ...@@ -13,7 +13,7 @@ from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
...@@ -41,8 +41,8 @@ from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata ...@@ -41,8 +41,8 @@ from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
# --8<-- [start:mamba_mixer] # --8<-- [start:mamba_mixer]
@CustomOp.register("mamba_mixer") @PluggableLayer.register("mamba_mixer")
class MambaMixer(MambaBase, CustomOp): class MambaMixer(MambaBase, PluggableLayer):
""" """
Compute ∆, A, B, C, and D the state space parameters and compute Compute ∆, A, B, C, and D the state space parameters and compute
the `contextualized_states`. A, D are input independent the `contextualized_states`. A, D are input independent
...@@ -230,10 +230,7 @@ class MambaMixer(MambaBase, CustomOp): ...@@ -230,10 +230,7 @@ class MambaMixer(MambaBase, CustomOp):
self.prefix, self.prefix,
) )
def forward_native(self, hidden_states: torch.Tensor, output: torch.Tensor): def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor):
pass
def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor):
""" """
Run the Mamba-1 SSM pipeline. Run the Mamba-1 SSM pipeline.
...@@ -528,7 +525,7 @@ def mamba_mixer( ...@@ -528,7 +525,7 @@ def mamba_mixer(
) -> None: ) -> None:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states, output=output) self.forward_impl(hidden_states=hidden_states, output=output)
def mamba_mixer_fake( def mamba_mixer_fake(
......
...@@ -14,7 +14,7 @@ from vllm.distributed import ( ...@@ -14,7 +14,7 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp, PluggableLayer
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
RowParallelLinear, RowParallelLinear,
...@@ -219,8 +219,8 @@ def mamba_v2_sharded_weight_loader( ...@@ -219,8 +219,8 @@ def mamba_v2_sharded_weight_loader(
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
# --8<-- [start:mamba_mixer2] # --8<-- [start:mamba_mixer2]
@CustomOp.register("mamba_mixer2") @PluggableLayer.register("mamba_mixer2")
class MambaMixer2(MambaBase, CustomOp): class MambaMixer2(MambaBase, PluggableLayer):
""" """
Compute ∆, A, B, C, and D the state space parameters and compute Compute ∆, A, B, C, and D the state space parameters and compute
the `contextualized_states`. A, D are input independent the `contextualized_states`. A, D are input independent
...@@ -472,13 +472,6 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -472,13 +472,6 @@ class MambaMixer2(MambaBase, CustomOp):
# Check if running on Blackwell (SM100+) for kernel tuning # Check if running on Blackwell (SM100+) for kernel tuning
self.is_blackwell = current_platform.is_device_capability_family(100) self.is_blackwell = current_platform.is_device_capability_family(100)
def forward_native(
self,
hidden_states: torch.Tensor,
mup_vector: torch.Tensor | None = None,
):
pass
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
...@@ -14,7 +14,7 @@ from vllm.config import VllmConfig, get_current_vllm_config ...@@ -14,7 +14,7 @@ from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -107,8 +107,8 @@ def is_mamba(config: Plamo2Config, i: int) -> bool: ...@@ -107,8 +107,8 @@ def is_mamba(config: Plamo2Config, i: int) -> bool:
# vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2 # vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2
# transformers.models.mamba.modeling_mamba.MambaMixer # transformers.models.mamba.modeling_mamba.MambaMixer
# --8<-- [start:plamo2_mamba_mixer] # --8<-- [start:plamo2_mamba_mixer]
@CustomOp.register("plamo2_mamba_mixer") @PluggableLayer.register("plamo2_mamba_mixer")
class Plamo2MambaMixer(MambaBase, CustomOp): class Plamo2MambaMixer(MambaBase, PluggableLayer):
# --8<-- [end:plamo2_mamba_mixer] # --8<-- [end:plamo2_mamba_mixer]
def __init__(self, vllm_config: VllmConfig, *, prefix: str = "", **kwargs) -> None: def __init__(self, vllm_config: VllmConfig, *, prefix: str = "", **kwargs) -> None:
...@@ -233,14 +233,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp): ...@@ -233,14 +233,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
dt = self.dt_proj(time_step) dt = self.dt_proj(time_step)
return B, C, dt return B, C, dt
def forward_native(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
**kwargs,
):
pass
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -253,7 +245,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp): ...@@ -253,7 +245,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
self.prefix, self.prefix,
) )
def forward_cuda( def forward_impl(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
...@@ -494,7 +486,7 @@ def plamo2_mamba_mixer( ...@@ -494,7 +486,7 @@ def plamo2_mamba_mixer(
) -> None: ) -> None:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states, output=output) self.forward_impl(hidden_states=hidden_states, output=output)
def plamo2_mamba_mixer_fake( def plamo2_mamba_mixer_fake(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment