Unverified Commit ecf5ff7c authored by roikoren755's avatar roikoren755 Committed by GitHub
Browse files

[Mamba] Flashinfer selective_state_update (#36162)


Signed-off-by: Roi Koren <roik@nvidia.com>
parent 30679319
......@@ -8,4 +8,5 @@ server_args: >-
--max-model-len 4096
--tensor-parallel-size 8
--enable-expert-parallel
--mamba-backend flashinfer
--speculative-config '{"method":"mtp","num_speculative_tokens":5}'
......@@ -8,4 +8,5 @@ server_args: >-
--max-model-len 4096
--tensor-parallel-size 2
--enable-expert-parallel
--mamba-backend flashinfer
--speculative-config '{"method":"mtp","num_speculative_tokens":5}'
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.config.mamba import MambaBackendEnum, MambaConfig
from vllm.model_executor.layers.mamba.ops.ssu_dispatch import (
FlashInferSSUBackend,
TritonSSUBackend,
get_mamba_ssu_backend,
initialize_mamba_ssu_backend,
selective_state_update,
)
from vllm.utils.torch_utils import set_random_seed
try:
import flashinfer.mamba # noqa: F401
HAS_FLASHINFER = True
except ImportError:
HAS_FLASHINFER = False
def test_default_backend_is_triton():
    """A default-constructed MambaConfig selects the Triton SSU backend."""
    default_cfg = MambaConfig()
    initialize_mamba_ssu_backend(default_cfg)
    active = get_mamba_ssu_backend()
    assert isinstance(active, TritonSSUBackend)
    assert active.name == "triton"
def test_explicit_triton_backend():
    """Requesting MambaBackendEnum.TRITON explicitly yields a TritonSSUBackend."""
    triton_cfg = MambaConfig(backend=MambaBackendEnum.TRITON)
    initialize_mamba_ssu_backend(triton_cfg)
    assert isinstance(get_mamba_ssu_backend(), TritonSSUBackend)
@pytest.mark.skipif(not HAS_FLASHINFER, reason="flashinfer not installed")
def test_flashinfer_backend_init():
    """Selecting the FlashInfer backend installs a FlashInferSSUBackend."""
    flashinfer_cfg = MambaConfig(backend=MambaBackendEnum.FLASHINFER)
    initialize_mamba_ssu_backend(flashinfer_cfg)
    active = get_mamba_ssu_backend()
    assert isinstance(active, FlashInferSSUBackend)
    assert active.name == "flashinfer"
def test_uninitialized_backend_raises():
    """get_mamba_ssu_backend() raises RuntimeError while no backend is set."""
    import vllm.model_executor.layers.mamba.ops.ssu_dispatch as mod

    saved_backend = mod._mamba_ssu_backend
    mod._mamba_ssu_backend = None
    try:
        with pytest.raises(RuntimeError, match="not been initialized"):
            get_mamba_ssu_backend()
    finally:
        # Restore the module global even when the assertion above fails, so a
        # failure here cannot leave the backend unset for later tests.
        mod._mamba_ssu_backend = saved_backend
@pytest.mark.skipif(HAS_FLASHINFER, reason="flashinfer is installed")
def test_flashinfer_import_error():
    """Without flashinfer installed, building the FlashInfer backend fails."""
    expected_message = "FlashInfer is required"
    with pytest.raises(ImportError, match=expected_message):
        FlashInferSSUBackend(MambaConfig())
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
def test_triton_basic_call():
    """Smoke-test the dispatcher with the Triton backend on small CUDA tensors.

    The test only checks that the call completes and that `out` contains no
    NaN afterwards; it does not compare against a reference implementation.
    It is skipped when no CUDA device is available, because every tensor
    below is allocated on "cuda".
    """
    set_random_seed(0)
    initialize_mamba_ssu_backend(MambaConfig(backend=MambaBackendEnum.TRITON))
    device = "cuda"
    batch_size = 2
    dim = 64
    dstate = 16
    state = torch.randn(batch_size, dim, dstate, device=device)
    x = torch.randn(batch_size, dim, device=device)
    out = torch.empty_like(x)
    dt = torch.randn(batch_size, dim, device=device)
    # torch.rand samples from [0, 1), so the bias lies in [-4, -3).
    dt_bias = torch.rand(dim, device=device) - 4.0
    # Negated uniform samples: every entry of A is <= 0.
    A = -torch.rand(dim, dstate, device=device)
    B = torch.randn(batch_size, dstate, device=device)
    C = torch.randn(batch_size, dstate, device=device)
    D = torch.randn(dim, device=device)
    selective_state_update(
        state,
        x,
        dt,
        A,
        B,
        C,
        D=D,
        dt_bias=dt_bias,
        dt_softplus=True,
        out=out,
    )
    assert not torch.isnan(out).any()
......@@ -16,6 +16,7 @@ from vllm.config.kv_events import KVEventsConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.config.load import LoadConfig
from vllm.config.lora import LoRAConfig
from vllm.config.mamba import MambaConfig
from vllm.config.model import (
ModelConfig,
iter_architecture_defaults,
......@@ -83,6 +84,8 @@ __all__ = [
"LoadConfig",
# From vllm.config.lora
"LoRAConfig",
# From vllm.config.mamba
"MambaConfig",
# From vllm.config.model
"ModelConfig",
"iter_architecture_defaults",
......
......@@ -123,14 +123,6 @@ class CacheConfig:
- "align": only cache the mamba state of the last token of each scheduler step and
when the token is at position i * block_size.
"""
enable_mamba_cache_stochastic_rounding: bool = False
"""Enable stochastic rounding when writing SSM state to fp16 cache.
Uses random bits to unbias the rounding error, which can improve
numerical stability for long sequences."""
mamba_cache_philox_rounds: int = 0
"""Number of Philox PRNG rounds for stochastic rounding random number
generation. 0 uses the Triton default. Higher values improve randomness
quality at the cost of compute."""
# Will be set after profiling.
num_gpu_blocks: int | None = field(default=None, init=False)
......@@ -258,29 +250,3 @@ class CacheConfig:
str(cache_dtype),
)
return cache_dtype
def __post_init__(self):
if self.enable_mamba_cache_stochastic_rounding:
from vllm.platforms import current_platform
if not current_platform.is_cuda():
raise ValueError(
"Stochastic rounding for Mamba cache is only supported "
"on NVIDIA CUDA platforms. Please do not specify "
"`--enable-mamba-cache-stochastic-rounding`."
)
if not current_platform.is_device_capability_family(100):
raise ValueError(
"Stochastic rounding for Mamba cache requires compute "
"capability 10.0 (data center Blackwell). The `cvt.rs` PTX "
"instruction is not supported on your GPU. Please do not specify "
"`--enable-mamba-cache-stochastic-rounding`."
)
if self.mamba_ssm_cache_dtype != "float16":
raise ValueError(
"Stochastic rounding for Mamba cache requires "
"the SSM cache to be float16. Please set it explicitly, "
"by specifying `--mamba-ssm-cache-dtype float16`, or disable "
"stochastic rounding by not specifying "
"`--enable-mamba-cache-stochastic-rounding`."
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum, EnumMeta
from typing import Any
from pydantic import field_validator
from vllm.config.utils import config
class _MambaBackendEnumMeta(EnumMeta):
    """Enum metaclass that reports unknown member names with a readable error."""

    def __getitem__(cls, name: str):
        """Return the member called `name`.

        Raises:
            ValueError: If no member has that name; the message lists the
                valid member names.
        """
        try:
            member = super().__getitem__(name)
        except KeyError:
            valid = ", ".join(cls.__members__)
            # `from None` hides the internal KeyError from the traceback.
            raise ValueError(
                f"Unknown Mamba SSU backend: '{name}'. Valid options are: {valid}"
            ) from None
        else:
            return member


class MambaBackendEnum(Enum, metaclass=_MambaBackendEnumMeta):
    """Enumeration of supported Mamba SSU (selective state update) backends."""

    TRITON = "triton"
    FLASHINFER = "flashinfer"
@config
class MambaConfig:
    """Configuration for Mamba SSM backends."""

    backend: MambaBackendEnum = MambaBackendEnum.TRITON
    """Mamba SSU backend to use."""
    enable_stochastic_rounding: bool = False
    """Enable stochastic rounding when writing SSM state to fp16 cache.
    Uses random bits to unbias the rounding error, which can improve
    numerical stability for long sequences."""
    stochastic_rounding_philox_rounds: int = 0
    """Number of Philox PRNG rounds for stochastic rounding random number
    generation. 0 uses the Triton default. Higher values improve randomness
    quality at the cost of compute."""

    @field_validator("backend", mode="before")
    @classmethod
    def validate_backend_before(cls, value: Any) -> Any:
        """Turn a string `backend` into its enum member, ignoring letter case.

        Values that are not strings are returned unchanged.
        """
        if not isinstance(value, str):
            return value
        return MambaBackendEnum[value.upper()]

    def __post_init__(self):
        """Reject stochastic rounding where the platform cannot provide it.

        Raises:
            ValueError: If stochastic rounding is enabled and the platform is
                not CUDA, or the Triton backend is selected on a device that
                is not in capability family 100.
        """
        if not self.enable_stochastic_rounding:
            return

        # Function-scope import: only reached when the check is needed.
        from vllm.platforms import current_platform

        if not current_platform.is_cuda():
            raise ValueError(
                "Stochastic rounding for Mamba cache is only supported "
                "on NVIDIA CUDA platforms. Please do not specify "
                "`--enable-mamba-cache-stochastic-rounding`."
            )

        uses_triton = self.backend == MambaBackendEnum.TRITON
        # The capability query stays short-circuited behind the backend test.
        if uses_triton and not current_platform.is_device_capability_family(100):
            raise ValueError(
                "Stochastic rounding for Mamba cache with triton backend requires "
                "compute capability 10.0 (data center Blackwell). The `cvt.rs` "
                "PTX instruction is not supported on your GPU. Please do not "
                "specify `--enable-mamba-cache-stochastic-rounding`, "
                "or set `--mamba-backend flashinfer`."
            )
......@@ -37,6 +37,7 @@ from .kv_events import KVEventsConfig
from .kv_transfer import KVTransferConfig
from .load import LoadConfig
from .lora import LoRAConfig
from .mamba import MambaConfig
from .model import ModelConfig
from .observability import ObservabilityConfig
from .offload import OffloadConfig
......@@ -275,6 +276,8 @@ class VllmConfig:
"""Model weight offloading configuration."""
attention_config: AttentionConfig = Field(default_factory=AttentionConfig)
"""Attention configuration."""
mamba_config: MambaConfig = Field(default_factory=MambaConfig)
"""Mamba configuration."""
kernel_config: KernelConfig = Field(default_factory=KernelConfig)
"""Kernel configuration."""
lora_config: LoRAConfig | None = None
......@@ -717,6 +720,18 @@ class VllmConfig:
if self.lora_config is not None:
self.lora_config.verify_with_model_config(self.model_config)
if (
self.mamba_config.enable_stochastic_rounding
and self.cache_config.mamba_ssm_cache_dtype != "float16"
):
raise ValueError(
"Stochastic rounding for Mamba cache requires "
"the SSM cache to be float16. Please set it explicitly, "
"by specifying `--mamba-ssm-cache-dtype float16`, or disable "
"stochastic rounding by not specifying "
"`--enable-mamba-cache-stochastic-rounding`."
)
if self.quant_config is None and self.model_config is not None:
self.quant_config = VllmConfig._get_quantization_config(
self.model_config, self.load_config
......
......@@ -45,6 +45,7 @@ from vllm.config import (
KVTransferConfig,
LoadConfig,
LoRAConfig,
MambaConfig,
ModelConfig,
MultiModalConfig,
ObservabilityConfig,
......@@ -72,6 +73,7 @@ from vllm.config.cache import (
from vllm.config.device import Device
from vllm.config.kernel import IrOpPriorityConfig, MoEBackend
from vllm.config.lora import MaxLoRARanks
from vllm.config.mamba import MambaBackendEnum
from vllm.config.model import (
ConvertOption,
HfOverrides,
......@@ -578,6 +580,7 @@ class EngineArgs:
pooler_config: PoolerConfig | None = ModelConfig.pooler_config
compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
attention_config: AttentionConfig = get_field(VllmConfig, "attention_config")
mamba_config: MambaConfig = get_field(VllmConfig, "mamba_config")
kernel_config: KernelConfig = get_field(VllmConfig, "kernel_config")
enable_flashinfer_autotune: bool = get_field(
KernelConfig, "enable_flashinfer_autotune"
......@@ -610,10 +613,12 @@ class EngineArgs:
mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
mamba_cache_mode: MambaCacheMode = CacheConfig.mamba_cache_mode
mamba_backend: MambaBackendEnum = MambaBackendEnum.TRITON
enable_mamba_cache_stochastic_rounding: bool = (
CacheConfig.enable_mamba_cache_stochastic_rounding
MambaConfig.enable_stochastic_rounding
)
mamba_cache_philox_rounds: int = CacheConfig.mamba_cache_philox_rounds
mamba_cache_philox_rounds: int = MambaConfig.stochastic_rounding_philox_rounds
additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
......@@ -655,6 +660,8 @@ class EngineArgs:
self.compilation_config = CompilationConfig(**self.compilation_config)
if isinstance(self.attention_config, dict):
self.attention_config = AttentionConfig(**self.attention_config)
if isinstance(self.mamba_config, dict):
self.mamba_config = MambaConfig(**self.mamba_config)
if isinstance(self.kernel_config, dict):
self.kernel_config = KernelConfig(**self.kernel_config)
if isinstance(self.eplb_config, dict):
......@@ -825,6 +832,22 @@ class EngineArgs:
"--attention-backend", **attention_kwargs["backend"]
)
# Mamba arguments
mamba_kwargs = get_kwargs(MambaConfig)
mamba_group = parser.add_argument_group(
title="MambaConfig",
description=MambaConfig.__doc__,
)
mamba_group.add_argument("--mamba-backend", **mamba_kwargs["backend"])
mamba_group.add_argument(
"--enable-mamba-cache-stochastic-rounding",
**mamba_kwargs["enable_stochastic_rounding"],
)
mamba_group.add_argument(
"--mamba-cache-philox-rounds",
**mamba_kwargs["stochastic_rounding_philox_rounds"],
)
# Structured outputs arguments
structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
structured_outputs_group = parser.add_argument_group(
......@@ -1050,13 +1073,6 @@ class EngineArgs:
cache_group.add_argument(
"--mamba-cache-mode", **cache_kwargs["mamba_cache_mode"]
)
cache_group.add_argument(
"--enable-mamba-cache-stochastic-rounding",
**cache_kwargs["enable_mamba_cache_stochastic_rounding"],
)
cache_group.add_argument(
"--mamba-cache-philox-rounds", **cache_kwargs["mamba_cache_philox_rounds"]
)
cache_group.add_argument(
"--kv-offloading-size", **cache_kwargs["kv_offloading_size"]
)
......@@ -1622,8 +1638,6 @@ class EngineArgs:
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
mamba_block_size=self.mamba_block_size,
mamba_cache_mode=self.mamba_cache_mode,
enable_mamba_cache_stochastic_rounding=self.enable_mamba_cache_stochastic_rounding,
mamba_cache_philox_rounds=self.mamba_cache_philox_rounds,
kv_offloading_size=self.kv_offloading_size,
kv_offloading_backend=self.kv_offloading_backend,
)
......@@ -1934,6 +1948,22 @@ class EngineArgs:
self.attention_backend
)
# Mamba config overrides
mamba_config = copy.deepcopy(self.mamba_config)
# Convert string to enum if needed (CLI parsing returns a string)
if isinstance(self.mamba_backend, str):
mamba_config.backend = MambaBackendEnum[self.mamba_backend.upper()]
else:
mamba_config.backend = self.mamba_backend
if self.enable_mamba_cache_stochastic_rounding:
mamba_config.enable_stochastic_rounding = (
self.enable_mamba_cache_stochastic_rounding
)
if self.mamba_cache_philox_rounds:
mamba_config.stochastic_rounding_philox_rounds = (
self.mamba_cache_philox_rounds
)
# Kernel config overrides
kernel_config = copy.deepcopy(self.kernel_config)
if self.enable_flashinfer_autotune is not None:
......@@ -2032,6 +2062,7 @@ class EngineArgs:
load_config=load_config,
offload_config=offload_config,
attention_config=attention_config,
mamba_config=mamba_config,
kernel_config=kernel_config,
lora_config=lora_config,
speculative_config=speculative_config,
......
......@@ -30,10 +30,8 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn,
causal_conv1d_update,
)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn,
selective_state_update,
)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_scan_fn
from vllm.model_executor.layers.mamba.ops.ssu_dispatch import selective_state_update
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import (
......@@ -431,14 +429,12 @@ class MambaMixer(MambaBase, PluggableLayer):
B_d,
C_d,
self.D,
gate_d.transpose(0, 1),
time_proj_bias,
z=gate_d.transpose(0, 1),
dt_softplus=True,
state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output,
out=scan_outputs_d,
enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding,
cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds,
)
scan_outputs_d = scan_outputs_d.transpose(0, 1)
......
......@@ -31,10 +31,10 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_update,
)
from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_state_update
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined_varlen,
)
from vllm.model_executor.layers.mamba.ops.ssu_dispatch import selective_state_update
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction,
......@@ -890,8 +890,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
B_d,
C_d,
D_d,
z=None,
dt_bias=dt_bias,
dt_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output,
......@@ -899,8 +898,6 @@ class MambaMixer2(MambaBase, PluggableLayer):
num_accepted_tokens=num_accepted_tokens,
cu_seqlens=query_start_loc_d,
is_blackwell=self.is_blackwell,
enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding,
cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds,
)
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
......
......@@ -323,9 +323,9 @@ def selective_state_update(
A,
B,
C,
D=None,
D,
dt_bias,
z=None,
dt_bias=None,
dt_softplus=False,
state_batch_indices=None,
dst_state_batch_indices=None,
......@@ -374,11 +374,11 @@ def selective_state_update(
B = B.unsqueeze(1)
if C.dim() == 2:
C = C.unsqueeze(1)
if D is not None and D.dim() == 1:
if D.dim() == 1:
D = D.unsqueeze(0)
if z is not None and z.dim() == 2:
z = z.unsqueeze(1)
if dt_bias is not None and dt_bias.dim() == 1:
if dt_bias.dim() == 1:
dt_bias = dt_bias.unsqueeze(0)
if out.dim() == 2:
out = out.unsqueeze(1)
......@@ -410,11 +410,9 @@ def selective_state_update(
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
assert B.shape == (batch, ngroups, dstate)
assert C.shape == B.shape
if D is not None:
assert D.shape == (nheads, dim)
if z is not None:
assert z.shape == x.shape
if dt_bias is not None:
assert dt_bias.shape == (nheads, dim)
if state_batch_indices is not None:
assert state_batch_indices.shape[0] >= N
......@@ -506,7 +504,8 @@ def selective_state_update(
dt.stride(0),
dt.stride(1),
dt.stride(2),
*(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
dt_bias.stride(0),
dt_bias.stride(1),
A.stride(0),
A.stride(1),
A.stride(2),
......@@ -516,7 +515,8 @@ def selective_state_update(
C.stride(0),
C.stride(1),
C.stride(2),
*(D.stride(0), D.stride(1)) if D is not None else 0,
D.stride(0),
D.stride(1),
z_strides[0],
z_strides[1],
z_strides[2],
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Dispatch module for Mamba selective state update (SSU) backends.
Provides a unified `selective_state_update` function that dispatches to
either the Triton or FlashInfer backend based on the configured
`MambaBackendEnum`. Follows SGLang's dispatch pattern adapted for vLLM.
"""
from abc import ABC, abstractmethod
import torch
from vllm.config.mamba import MambaBackendEnum, MambaConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import NULL_BLOCK_ID
logger = init_logger(__name__)
class MambaSSUBackend(ABC):
    """Common interface implemented by every Mamba SSU backend."""

    def __init__(self, mamba_config: MambaConfig):
        """Keep the Mamba configuration for use by the concrete backend."""
        self._mamba_config = mamba_config

    @property
    @abstractmethod
    def name(self) -> str:
        """Short identifier of the backend."""

    @abstractmethod
    def __call__(
        self,
        state: torch.Tensor,
        x: torch.Tensor,
        dt: torch.Tensor,
        A: torch.Tensor,
        B: torch.Tensor,
        C: torch.Tensor,
        D: torch.Tensor,
        dt_bias: torch.Tensor,
        z: torch.Tensor | None = None,
        dt_softplus: bool = False,
        state_batch_indices: torch.Tensor | None = None,
        dst_state_batch_indices: torch.Tensor | None = None,
        null_block_id: int = NULL_BLOCK_ID,
        out: torch.Tensor | None = None,
        num_accepted_tokens: torch.Tensor | None = None,
        cu_seqlens: torch.Tensor | None = None,
        is_blackwell: bool = False,
    ) -> None:
        """Run one selective state update with this backend.

        Nothing is returned. Results are presumably written in place into
        `out` and `state` -- confirm against the concrete kernels.
        """
class TritonSSUBackend(MambaSSUBackend):
    """Triton-based SSU backend (vLLM's default)."""

    def __init__(self, mamba_config: MambaConfig):
        """Store the config and bind the Triton kernel entry point."""
        super().__init__(mamba_config)
        # Function-scope import: the kernel module is loaded on construction.
        from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
            selective_state_update as triton_ssu,
        )

        self._kernel = triton_ssu

    @property
    def name(self) -> str:
        """Identifier of this backend: "triton"."""
        return "triton"

    def __call__(
        self,
        state: torch.Tensor,
        x: torch.Tensor,
        dt: torch.Tensor,
        A: torch.Tensor,
        B: torch.Tensor,
        C: torch.Tensor,
        D: torch.Tensor,
        dt_bias: torch.Tensor,
        z: torch.Tensor | None = None,
        dt_softplus: bool = False,
        state_batch_indices: torch.Tensor | None = None,
        dst_state_batch_indices: torch.Tensor | None = None,
        null_block_id: int = NULL_BLOCK_ID,
        out: torch.Tensor | None = None,
        num_accepted_tokens: torch.Tensor | None = None,
        cu_seqlens: torch.Tensor | None = None,
        is_blackwell: bool = False,
    ) -> None:
        """Forward every argument to the Triton kernel.

        The two stochastic-rounding settings are not call arguments; they are
        read from the stored MambaConfig and passed along with the rest.
        """
        cfg = self._mamba_config
        self._kernel(
            state,
            x,
            dt,
            A,
            B,
            C,
            D=D,
            z=z,
            dt_bias=dt_bias,
            dt_softplus=dt_softplus,
            state_batch_indices=state_batch_indices,
            dst_state_batch_indices=dst_state_batch_indices,
            null_block_id=null_block_id,
            out=out,
            num_accepted_tokens=num_accepted_tokens,
            cu_seqlens=cu_seqlens,
            is_blackwell=is_blackwell,
            enable_stochastic_rounding=cfg.enable_stochastic_rounding,
            cache_philox_rounds=cfg.stochastic_rounding_philox_rounds,
        )
class FlashInferSSUBackend(MambaSSUBackend):
    """FlashInfer-based SSU backend."""

    def __init__(self, mamba_config: MambaConfig):
        """Store the config and bind FlashInfer's selective_state_update.

        Raises:
            ImportError: If `flashinfer.mamba` cannot be imported.
        """
        super().__init__(mamba_config)
        try:
            from flashinfer.mamba import selective_state_update as flashinfer_ssu
        except ImportError as e:
            raise ImportError(
                "FlashInfer is required for the flashinfer Mamba SSU backend. "
                "Please install flashinfer (>= 0.6.4): "
                "pip install flashinfer-python"
            ) from e
        self._kernel = flashinfer_ssu

    @property
    def name(self) -> str:
        """Identifier of this backend: "flashinfer"."""
        return "flashinfer"

    def __call__(
        self,
        state: torch.Tensor,
        x: torch.Tensor,
        dt: torch.Tensor,
        A: torch.Tensor,
        B: torch.Tensor,
        C: torch.Tensor,
        D: torch.Tensor,
        dt_bias: torch.Tensor,
        z: torch.Tensor | None = None,
        dt_softplus: bool = False,
        state_batch_indices: torch.Tensor | None = None,
        dst_state_batch_indices: torch.Tensor | None = None,
        null_block_id: int = NULL_BLOCK_ID,
        out: torch.Tensor | None = None,
        num_accepted_tokens: torch.Tensor | None = None,
        cu_seqlens: torch.Tensor | None = None,
        is_blackwell: bool = False,
    ) -> None:
        """Translate the shared call signature into a FlashInfer kernel call.

        `null_block_id` is passed to the kernel as `pad_slot_id`.
        `is_blackwell` is accepted for interface compatibility but is not
        forwarded to the kernel.
        """
        cfg = self._mamba_config

        # A fresh one-element seed tensor is drawn on every call, but only
        # when stochastic rounding is enabled.
        if cfg.enable_stochastic_rounding:
            rand_seed = torch.randint(0, 2**32, (1,), device=state.device)
        else:
            rand_seed = None

        # cache_steps is the last dimension of the state indices when both
        # cu_seqlens and state_batch_indices are supplied, otherwise 0.
        if cu_seqlens is not None and state_batch_indices is not None:
            cache_steps = state_batch_indices.size(-1)
        else:
            cache_steps = 0

        # A configured value of 0 falls back to 10 rounds for this backend.
        philox_rounds = cfg.stochastic_rounding_philox_rounds or 10

        self._kernel(
            state,
            x,
            dt,
            A,
            B,
            C,
            D=D,
            z=z,
            dt_bias=dt_bias,
            dt_softplus=dt_softplus,
            state_batch_indices=state_batch_indices,
            dst_state_batch_indices=dst_state_batch_indices,
            cu_seqlens=cu_seqlens,
            num_accepted_tokens=num_accepted_tokens,
            cache_steps=cache_steps,
            pad_slot_id=null_block_id,
            out=out,
            rand_seed=rand_seed,
            philox_rounds=philox_rounds,
        )
# Maps each MambaBackendEnum member to the backend class that implements it.
_BACKEND_REGISTRY: dict[MambaBackendEnum, type[MambaSSUBackend]] = {
    MambaBackendEnum.TRITON: TritonSSUBackend,
    MambaBackendEnum.FLASHINFER: FlashInferSSUBackend,
}
# Module-wide backend instance. It is None until initialize_mamba_ssu_backend()
# assigns it, and get_mamba_ssu_backend() raises while it is still None.
_mamba_ssu_backend: MambaSSUBackend | None = None
def initialize_mamba_ssu_backend(mamba_config: MambaConfig) -> None:
    """Create the module-wide Mamba SSU backend selected by `mamba_config`.

    Args:
        mamba_config: Mamba configuration. Its `backend` field picks the
            implementation, and the whole object is passed to that
            implementation's constructor.

    Raises:
        ValueError: If `mamba_config.backend` has no registered implementation.
    """
    global _mamba_ssu_backend

    backend = mamba_config.backend
    backend_cls = _BACKEND_REGISTRY.get(backend)
    if backend_cls is None:
        raise ValueError(
            f"Unknown Mamba SSU backend: {backend}. "
            f"Valid options: {list(_BACKEND_REGISTRY.keys())}"
        )

    # The global is replaced only after the constructor has succeeded.
    _mamba_ssu_backend = backend_cls(mamba_config)
    logger.info("Using %s Mamba SSU backend.", _mamba_ssu_backend.name)
def get_mamba_ssu_backend() -> MambaSSUBackend:
    """Return the backend installed by initialize_mamba_ssu_backend().

    Raises:
        RuntimeError: If no backend has been initialized yet.
    """
    current = _mamba_ssu_backend
    if current is not None:
        return current
    raise RuntimeError(
        "Mamba SSU backend has not been initialized. "
        "Call initialize_mamba_ssu_backend() first."
    )
def selective_state_update(
    state: torch.Tensor,
    x: torch.Tensor,
    dt: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    D: torch.Tensor,
    dt_bias: torch.Tensor,
    z: torch.Tensor | None = None,
    dt_softplus: bool = False,
    state_batch_indices: torch.Tensor | None = None,
    dst_state_batch_indices: torch.Tensor | None = None,
    null_block_id: int = NULL_BLOCK_ID,
    out: torch.Tensor | None = None,
    num_accepted_tokens: torch.Tensor | None = None,
    cu_seqlens: torch.Tensor | None = None,
    is_blackwell: bool = False,
) -> None:
    """Run a Mamba selective state update on the initialized backend.

    Every argument is handed to the active backend (Triton or FlashInfer)
    unchanged. Nothing is returned.

    Raises:
        RuntimeError: If no backend has been initialized yet.
    """
    backend = get_mamba_ssu_backend()
    backend(
        state,
        x,
        dt,
        A,
        B,
        C,
        D,
        dt_bias,
        z=z,
        dt_softplus=dt_softplus,
        state_batch_indices=state_batch_indices,
        dst_state_batch_indices=dst_state_batch_indices,
        null_block_id=null_block_id,
        out=out,
        num_accepted_tokens=num_accepted_tokens,
        cu_seqlens=cu_seqlens,
        is_blackwell=is_blackwell,
    )
......@@ -38,10 +38,10 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn,
causal_conv1d_update,
)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_state_update
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined_varlen,
)
from vllm.model_executor.layers.mamba.ops.ssu_dispatch import selective_state_update
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -447,13 +447,11 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
B,
C,
D,
dt_bias,
z=gate_d.reshape(num_decodes, -1, self.head_dim),
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor_d,
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding,
cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds,
)
# 4. Final linear projection
......
......@@ -36,6 +36,9 @@ from vllm.distributed.parallel_state import (
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.ops.ssu_dispatch import (
initialize_mamba_ssu_backend,
)
from vllm.model_executor.model_loader import get_model_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
......@@ -360,6 +363,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.attn_backends, self.attn_groups, attn_cg_support = init_attn_backend(
self.kv_cache_config, self.vllm_config, self.device
)
initialize_mamba_ssu_backend(self.vllm_config.mamba_config)
cudagraph_mode = self.compilation_config.resolve_cudagraph_mode_and_sizes(
attn_cg_support.min_cg_support,
attn_cg_support.min_cg_attn_backend,
......
......@@ -56,6 +56,9 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
RoutedExpertsCapturer,
)
from vllm.model_executor.layers.mamba.ops.ssu_dispatch import (
initialize_mamba_ssu_backend,
)
from vllm.model_executor.layers.rotary_embedding import (
MRotaryEmbedding,
XDRotaryEmbedding,
......@@ -6750,6 +6753,7 @@ class GPUModelRunner(
self.may_add_encoder_only_layers_to_kv_cache_config()
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
self.initialize_attn_backend(kv_cache_config, is_profiling=is_profiling)
initialize_mamba_ssu_backend(self.vllm_config.mamba_config)
# The kernel block size for all KV cache groups. For example, if
# kv_cache_manager uses block_size 256 for a given group, but the attention
# backends for that group only supports block_size 64, we will return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment