Commit ecf5ff7c (unverified), authored by roikoren755, committed by GitHub
Browse files

[Mamba] Flashinfer selective_state_update (#36162)


Signed-off-by: Roi Koren <roik@nvidia.com>
parent 30679319
...@@ -8,4 +8,5 @@ server_args: >- ...@@ -8,4 +8,5 @@ server_args: >-
--max-model-len 4096 --max-model-len 4096
--tensor-parallel-size 8 --tensor-parallel-size 8
--enable-expert-parallel --enable-expert-parallel
--mamba-backend flashinfer
--speculative-config '{"method":"mtp","num_speculative_tokens":5}' --speculative-config '{"method":"mtp","num_speculative_tokens":5}'
...@@ -8,4 +8,5 @@ server_args: >- ...@@ -8,4 +8,5 @@ server_args: >-
--max-model-len 4096 --max-model-len 4096
--tensor-parallel-size 2 --tensor-parallel-size 2
--enable-expert-parallel --enable-expert-parallel
--mamba-backend flashinfer
--speculative-config '{"method":"mtp","num_speculative_tokens":5}' --speculative-config '{"method":"mtp","num_speculative_tokens":5}'
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.config.mamba import MambaBackendEnum, MambaConfig
from vllm.model_executor.layers.mamba.ops.ssu_dispatch import (
FlashInferSSUBackend,
TritonSSUBackend,
get_mamba_ssu_backend,
initialize_mamba_ssu_backend,
selective_state_update,
)
from vllm.utils.torch_utils import set_random_seed
# Probe for the optional FlashInfer Mamba kernels. The flag drives the
# ``skipif`` markers on the FlashInfer-specific tests in this module.
HAS_FLASHINFER = True
try:
    import flashinfer.mamba  # noqa: F401
except ImportError:
    HAS_FLASHINFER = False
def test_default_backend_is_triton():
    """A default-constructed ``MambaConfig`` selects the Triton SSU backend."""
    default_config = MambaConfig()
    initialize_mamba_ssu_backend(default_config)

    active = get_mamba_ssu_backend()
    assert isinstance(active, TritonSSUBackend)
    assert active.name == "triton"
def test_explicit_triton_backend():
    """Requesting ``MambaBackendEnum.TRITON`` yields a ``TritonSSUBackend``."""
    triton_config = MambaConfig(backend=MambaBackendEnum.TRITON)
    initialize_mamba_ssu_backend(triton_config)

    active = get_mamba_ssu_backend()
    assert isinstance(active, TritonSSUBackend)
@pytest.mark.skipif(not HAS_FLASHINFER, reason="flashinfer not installed")
def test_flashinfer_backend_init():
    """Requesting ``MambaBackendEnum.FLASHINFER`` yields a ``FlashInferSSUBackend``.

    Skipped when the ``flashinfer.mamba`` import probe at module top failed.
    """
    flashinfer_config = MambaConfig(backend=MambaBackendEnum.FLASHINFER)
    initialize_mamba_ssu_backend(flashinfer_config)

    active = get_mamba_ssu_backend()
    assert isinstance(active, FlashInferSSUBackend)
    assert active.name == "flashinfer"
def test_uninitialized_backend_raises():
    """``get_mamba_ssu_backend`` raises ``RuntimeError`` when no backend is set.

    The module-level ``_mamba_ssu_backend`` is temporarily cleared to simulate
    the uninitialized state and is always restored afterwards.
    """
    import vllm.model_executor.layers.mamba.ops.ssu_dispatch as mod

    old = mod._mamba_ssu_backend
    mod._mamba_ssu_backend = None
    try:
        with pytest.raises(RuntimeError, match="not been initialized"):
            get_mamba_ssu_backend()
    finally:
        # Restore even when the expectation above fails, so the cleared
        # global cannot leak into tests that run after this one.
        mod._mamba_ssu_backend = old
@pytest.mark.skipif(HAS_FLASHINFER, reason="flashinfer is installed")
def test_flashinfer_import_error():
    """Building ``FlashInferSSUBackend`` without FlashInfer raises ``ImportError``.

    Only meaningful (and only run) when ``flashinfer.mamba`` is not importable.
    """
    expect_import_error = pytest.raises(ImportError, match="FlashInfer is required")
    with expect_import_error:
        FlashInferSSUBackend(MambaConfig())
def test_triton_basic_call():
    """Smoke-test the dispatcher with the Triton backend on small CUDA tensors.

    Only checks that the call completes and that ``out`` contains no NaNs.
    """
    set_random_seed(0)
    initialize_mamba_ssu_backend(MambaConfig(backend=MambaBackendEnum.TRITON))

    batch_size, dim, dstate = 2, 64, 16
    on_gpu = {"device": "cuda"}

    # Tensors are created in a fixed order so the seeded RNG stream is stable.
    state = torch.randn(batch_size, dim, dstate, **on_gpu)
    x = torch.randn(batch_size, dim, **on_gpu)
    out = torch.empty_like(x)
    dt = torch.randn(batch_size, dim, **on_gpu)
    dt_bias = torch.rand(dim, **on_gpu) - 4.0
    A = -torch.rand(dim, dstate, **on_gpu)
    B = torch.randn(batch_size, dstate, **on_gpu)
    C = torch.randn(batch_size, dstate, **on_gpu)
    D = torch.randn(dim, **on_gpu)

    selective_state_update(
        state, x, dt, A, B, C, D=D, dt_bias=dt_bias, dt_softplus=True, out=out
    )

    assert not torch.isnan(out).any()
...@@ -16,6 +16,7 @@ from vllm.config.kv_events import KVEventsConfig ...@@ -16,6 +16,7 @@ from vllm.config.kv_events import KVEventsConfig
from vllm.config.kv_transfer import KVTransferConfig from vllm.config.kv_transfer import KVTransferConfig
from vllm.config.load import LoadConfig from vllm.config.load import LoadConfig
from vllm.config.lora import LoRAConfig from vllm.config.lora import LoRAConfig
from vllm.config.mamba import MambaConfig
from vllm.config.model import ( from vllm.config.model import (
ModelConfig, ModelConfig,
iter_architecture_defaults, iter_architecture_defaults,
...@@ -83,6 +84,8 @@ __all__ = [ ...@@ -83,6 +84,8 @@ __all__ = [
"LoadConfig", "LoadConfig",
# From vllm.config.lora # From vllm.config.lora
"LoRAConfig", "LoRAConfig",
# From vllm.config.mamba
"MambaConfig",
# From vllm.config.model # From vllm.config.model
"ModelConfig", "ModelConfig",
"iter_architecture_defaults", "iter_architecture_defaults",
......
...@@ -123,14 +123,6 @@ class CacheConfig: ...@@ -123,14 +123,6 @@ class CacheConfig:
- "align": only cache the mamba state of the last token of each scheduler step and - "align": only cache the mamba state of the last token of each scheduler step and
when the token is at position i * block_size. when the token is at position i * block_size.
""" """
enable_mamba_cache_stochastic_rounding: bool = False
"""Enable stochastic rounding when writing SSM state to fp16 cache.
Uses random bits to unbias the rounding error, which can improve
numerical stability for long sequences."""
mamba_cache_philox_rounds: int = 0
"""Number of Philox PRNG rounds for stochastic rounding random number
generation. 0 uses the Triton default. Higher values improve randomness
quality at the cost of compute."""
# Will be set after profiling. # Will be set after profiling.
num_gpu_blocks: int | None = field(default=None, init=False) num_gpu_blocks: int | None = field(default=None, init=False)
...@@ -258,29 +250,3 @@ class CacheConfig: ...@@ -258,29 +250,3 @@ class CacheConfig:
str(cache_dtype), str(cache_dtype),
) )
return cache_dtype return cache_dtype
def __post_init__(self):
if self.enable_mamba_cache_stochastic_rounding:
from vllm.platforms import current_platform
if not current_platform.is_cuda():
raise ValueError(
"Stochastic rounding for Mamba cache is only supported "
"on NVIDIA CUDA platforms. Please do not specify "
"`--enable-mamba-cache-stochastic-rounding`."
)
if not current_platform.is_device_capability_family(100):
raise ValueError(
"Stochastic rounding for Mamba cache requires compute "
"capability 10.0 (data center Blackwell). The `cvt.rs` PTX "
"instruction is not supported on your GPU. Please do not specify "
"`--enable-mamba-cache-stochastic-rounding`."
)
if self.mamba_ssm_cache_dtype != "float16":
raise ValueError(
"Stochastic rounding for Mamba cache requires "
"the SSM cache to be float16. Please set it explicitly, "
"by specifying `--mamba-ssm-cache-dtype float16`, or disable "
"stochastic rounding by not specifying "
"`--enable-mamba-cache-stochastic-rounding`."
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum, EnumMeta
from typing import Any
from pydantic import field_validator
from vllm.config.utils import config
class _MambaBackendEnumMeta(EnumMeta):
    """Enum metaclass whose by-name lookup fails with a descriptive error.

    ``SomeEnum["NAME"]`` works as usual for known member names. An unknown
    name raises ``ValueError`` (rather than ``KeyError``) listing every valid
    member name.
    """

    def __getitem__(cls, name: str):
        try:
            member = super().__getitem__(name)
        except KeyError:
            options = ", ".join(cls.__members__)
            # ``from None`` hides the internal KeyError from the traceback.
            raise ValueError(
                f"Unknown Mamba SSU backend: '{name}'. Valid options are: {options}"
            ) from None
        return member
class MambaBackendEnum(Enum, metaclass=_MambaBackendEnumMeta):
    """Enumeration of supported Mamba SSU (selective state update) backends."""

    # By-name lookup (``MambaBackendEnum["TRITON"]``) goes through the
    # metaclass, which reports unknown names with a ``ValueError``.
    TRITON = "triton"
    FLASHINFER = "flashinfer"
@config
class MambaConfig:
    """Configuration for Mamba SSM backends."""

    backend: MambaBackendEnum = MambaBackendEnum.TRITON
    """Mamba SSU backend to use."""
    enable_stochastic_rounding: bool = False
    """Enable stochastic rounding when writing SSM state to fp16 cache.
    Uses random bits to unbias the rounding error, which can improve
    numerical stability for long sequences."""
    stochastic_rounding_philox_rounds: int = 0
    """Number of Philox PRNG rounds for stochastic rounding random number
    generation. 0 uses the Triton default. Higher values improve randomness
    quality at the cost of compute."""

    @field_validator("backend", mode="before")
    @classmethod
    def validate_backend_before(cls, value: Any) -> Any:
        """Enable parsing of the `backend` enum type from string."""
        # Strings are upper-cased and resolved by member name; anything else
        # (e.g. an existing enum member) is passed through unchanged.
        return MambaBackendEnum[value.upper()] if isinstance(value, str) else value

    def __post_init__(self):
        # Platform checks apply only when stochastic rounding is requested.
        if not self.enable_stochastic_rounding:
            return

        # Imported lazily so the platform module is only touched when needed.
        from vllm.platforms import current_platform

        if not current_platform.is_cuda():
            raise ValueError(
                "Stochastic rounding for Mamba cache is only supported "
                "on NVIDIA CUDA platforms. Please do not specify "
                "`--enable-mamba-cache-stochastic-rounding`."
            )

        # The device-capability requirement is enforced for the Triton backend
        # only; the capability query is skipped for other backends.
        uses_triton = self.backend == MambaBackendEnum.TRITON
        if uses_triton and not current_platform.is_device_capability_family(100):
            raise ValueError(
                "Stochastic rounding for Mamba cache with triton backend requires "
                "compute capability 10.0 (data center Blackwell). The `cvt.rs` "
                "PTX instruction is not supported on your GPU. Please do not "
                "specify `--enable-mamba-cache-stochastic-rounding`, "
                "or set `--mamba-backend flashinfer`."
            )
...@@ -37,6 +37,7 @@ from .kv_events import KVEventsConfig ...@@ -37,6 +37,7 @@ from .kv_events import KVEventsConfig
from .kv_transfer import KVTransferConfig from .kv_transfer import KVTransferConfig
from .load import LoadConfig from .load import LoadConfig
from .lora import LoRAConfig from .lora import LoRAConfig
from .mamba import MambaConfig
from .model import ModelConfig from .model import ModelConfig
from .observability import ObservabilityConfig from .observability import ObservabilityConfig
from .offload import OffloadConfig from .offload import OffloadConfig
...@@ -275,6 +276,8 @@ class VllmConfig: ...@@ -275,6 +276,8 @@ class VllmConfig:
"""Model weight offloading configuration.""" """Model weight offloading configuration."""
attention_config: AttentionConfig = Field(default_factory=AttentionConfig) attention_config: AttentionConfig = Field(default_factory=AttentionConfig)
"""Attention configuration.""" """Attention configuration."""
mamba_config: MambaConfig = Field(default_factory=MambaConfig)
"""Mamba configuration."""
kernel_config: KernelConfig = Field(default_factory=KernelConfig) kernel_config: KernelConfig = Field(default_factory=KernelConfig)
"""Kernel configuration.""" """Kernel configuration."""
lora_config: LoRAConfig | None = None lora_config: LoRAConfig | None = None
...@@ -717,6 +720,18 @@ class VllmConfig: ...@@ -717,6 +720,18 @@ class VllmConfig:
if self.lora_config is not None: if self.lora_config is not None:
self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_model_config(self.model_config)
if (
self.mamba_config.enable_stochastic_rounding
and self.cache_config.mamba_ssm_cache_dtype != "float16"
):
raise ValueError(
"Stochastic rounding for Mamba cache requires "
"the SSM cache to be float16. Please set it explicitly, "
"by specifying `--mamba-ssm-cache-dtype float16`, or disable "
"stochastic rounding by not specifying "
"`--enable-mamba-cache-stochastic-rounding`."
)
if self.quant_config is None and self.model_config is not None: if self.quant_config is None and self.model_config is not None:
self.quant_config = VllmConfig._get_quantization_config( self.quant_config = VllmConfig._get_quantization_config(
self.model_config, self.load_config self.model_config, self.load_config
......
...@@ -45,6 +45,7 @@ from vllm.config import ( ...@@ -45,6 +45,7 @@ from vllm.config import (
KVTransferConfig, KVTransferConfig,
LoadConfig, LoadConfig,
LoRAConfig, LoRAConfig,
MambaConfig,
ModelConfig, ModelConfig,
MultiModalConfig, MultiModalConfig,
ObservabilityConfig, ObservabilityConfig,
...@@ -72,6 +73,7 @@ from vllm.config.cache import ( ...@@ -72,6 +73,7 @@ from vllm.config.cache import (
from vllm.config.device import Device from vllm.config.device import Device
from vllm.config.kernel import IrOpPriorityConfig, MoEBackend from vllm.config.kernel import IrOpPriorityConfig, MoEBackend
from vllm.config.lora import MaxLoRARanks from vllm.config.lora import MaxLoRARanks
from vllm.config.mamba import MambaBackendEnum
from vllm.config.model import ( from vllm.config.model import (
ConvertOption, ConvertOption,
HfOverrides, HfOverrides,
...@@ -578,6 +580,7 @@ class EngineArgs: ...@@ -578,6 +580,7 @@ class EngineArgs:
pooler_config: PoolerConfig | None = ModelConfig.pooler_config pooler_config: PoolerConfig | None = ModelConfig.pooler_config
compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config") compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
attention_config: AttentionConfig = get_field(VllmConfig, "attention_config") attention_config: AttentionConfig = get_field(VllmConfig, "attention_config")
mamba_config: MambaConfig = get_field(VllmConfig, "mamba_config")
kernel_config: KernelConfig = get_field(VllmConfig, "kernel_config") kernel_config: KernelConfig = get_field(VllmConfig, "kernel_config")
enable_flashinfer_autotune: bool = get_field( enable_flashinfer_autotune: bool = get_field(
KernelConfig, "enable_flashinfer_autotune" KernelConfig, "enable_flashinfer_autotune"
...@@ -610,10 +613,12 @@ class EngineArgs: ...@@ -610,10 +613,12 @@ class EngineArgs:
mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size") mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
mamba_cache_mode: MambaCacheMode = CacheConfig.mamba_cache_mode mamba_cache_mode: MambaCacheMode = CacheConfig.mamba_cache_mode
mamba_backend: MambaBackendEnum = MambaBackendEnum.TRITON
enable_mamba_cache_stochastic_rounding: bool = ( enable_mamba_cache_stochastic_rounding: bool = (
CacheConfig.enable_mamba_cache_stochastic_rounding MambaConfig.enable_stochastic_rounding
) )
mamba_cache_philox_rounds: int = CacheConfig.mamba_cache_philox_rounds mamba_cache_philox_rounds: int = MambaConfig.stochastic_rounding_philox_rounds
additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config") additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
...@@ -655,6 +660,8 @@ class EngineArgs: ...@@ -655,6 +660,8 @@ class EngineArgs:
self.compilation_config = CompilationConfig(**self.compilation_config) self.compilation_config = CompilationConfig(**self.compilation_config)
if isinstance(self.attention_config, dict): if isinstance(self.attention_config, dict):
self.attention_config = AttentionConfig(**self.attention_config) self.attention_config = AttentionConfig(**self.attention_config)
if isinstance(self.mamba_config, dict):
self.mamba_config = MambaConfig(**self.mamba_config)
if isinstance(self.kernel_config, dict): if isinstance(self.kernel_config, dict):
self.kernel_config = KernelConfig(**self.kernel_config) self.kernel_config = KernelConfig(**self.kernel_config)
if isinstance(self.eplb_config, dict): if isinstance(self.eplb_config, dict):
...@@ -825,6 +832,22 @@ class EngineArgs: ...@@ -825,6 +832,22 @@ class EngineArgs:
"--attention-backend", **attention_kwargs["backend"] "--attention-backend", **attention_kwargs["backend"]
) )
# Mamba arguments
mamba_kwargs = get_kwargs(MambaConfig)
mamba_group = parser.add_argument_group(
title="MambaConfig",
description=MambaConfig.__doc__,
)
mamba_group.add_argument("--mamba-backend", **mamba_kwargs["backend"])
mamba_group.add_argument(
"--enable-mamba-cache-stochastic-rounding",
**mamba_kwargs["enable_stochastic_rounding"],
)
mamba_group.add_argument(
"--mamba-cache-philox-rounds",
**mamba_kwargs["stochastic_rounding_philox_rounds"],
)
# Structured outputs arguments # Structured outputs arguments
structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig) structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
structured_outputs_group = parser.add_argument_group( structured_outputs_group = parser.add_argument_group(
...@@ -1050,13 +1073,6 @@ class EngineArgs: ...@@ -1050,13 +1073,6 @@ class EngineArgs:
cache_group.add_argument( cache_group.add_argument(
"--mamba-cache-mode", **cache_kwargs["mamba_cache_mode"] "--mamba-cache-mode", **cache_kwargs["mamba_cache_mode"]
) )
cache_group.add_argument(
"--enable-mamba-cache-stochastic-rounding",
**cache_kwargs["enable_mamba_cache_stochastic_rounding"],
)
cache_group.add_argument(
"--mamba-cache-philox-rounds", **cache_kwargs["mamba_cache_philox_rounds"]
)
cache_group.add_argument( cache_group.add_argument(
"--kv-offloading-size", **cache_kwargs["kv_offloading_size"] "--kv-offloading-size", **cache_kwargs["kv_offloading_size"]
) )
...@@ -1622,8 +1638,6 @@ class EngineArgs: ...@@ -1622,8 +1638,6 @@ class EngineArgs:
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype, mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
mamba_block_size=self.mamba_block_size, mamba_block_size=self.mamba_block_size,
mamba_cache_mode=self.mamba_cache_mode, mamba_cache_mode=self.mamba_cache_mode,
enable_mamba_cache_stochastic_rounding=self.enable_mamba_cache_stochastic_rounding,
mamba_cache_philox_rounds=self.mamba_cache_philox_rounds,
kv_offloading_size=self.kv_offloading_size, kv_offloading_size=self.kv_offloading_size,
kv_offloading_backend=self.kv_offloading_backend, kv_offloading_backend=self.kv_offloading_backend,
) )
...@@ -1934,6 +1948,22 @@ class EngineArgs: ...@@ -1934,6 +1948,22 @@ class EngineArgs:
self.attention_backend self.attention_backend
) )
# Mamba config overrides
mamba_config = copy.deepcopy(self.mamba_config)
# Convert string to enum if needed (CLI parsing returns a string)
if isinstance(self.mamba_backend, str):
mamba_config.backend = MambaBackendEnum[self.mamba_backend.upper()]
else:
mamba_config.backend = self.mamba_backend
if self.enable_mamba_cache_stochastic_rounding:
mamba_config.enable_stochastic_rounding = (
self.enable_mamba_cache_stochastic_rounding
)
if self.mamba_cache_philox_rounds:
mamba_config.stochastic_rounding_philox_rounds = (
self.mamba_cache_philox_rounds
)
# Kernel config overrides # Kernel config overrides
kernel_config = copy.deepcopy(self.kernel_config) kernel_config = copy.deepcopy(self.kernel_config)
if self.enable_flashinfer_autotune is not None: if self.enable_flashinfer_autotune is not None:
...@@ -2032,6 +2062,7 @@ class EngineArgs: ...@@ -2032,6 +2062,7 @@ class EngineArgs:
load_config=load_config, load_config=load_config,
offload_config=offload_config, offload_config=offload_config,
attention_config=attention_config, attention_config=attention_config,
mamba_config=mamba_config,
kernel_config=kernel_config, kernel_config=kernel_config,
lora_config=lora_config, lora_config=lora_config,
speculative_config=speculative_config, speculative_config=speculative_config,
......
...@@ -30,10 +30,8 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( ...@@ -30,10 +30,8 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_fn,
causal_conv1d_update, causal_conv1d_update,
) )
from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_scan_fn
selective_scan_fn, from vllm.model_executor.layers.mamba.ops.ssu_dispatch import selective_state_update
selective_state_update,
)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import ( from vllm.utils.torch_utils import (
...@@ -431,14 +429,12 @@ class MambaMixer(MambaBase, PluggableLayer): ...@@ -431,14 +429,12 @@ class MambaMixer(MambaBase, PluggableLayer):
B_d, B_d,
C_d, C_d,
self.D, self.D,
gate_d.transpose(0, 1),
time_proj_bias, time_proj_bias,
z=gate_d.transpose(0, 1),
dt_softplus=True, dt_softplus=True,
state_batch_indices=state_indices_tensor_d_input, state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output, dst_state_batch_indices=state_indices_tensor_d_output,
out=scan_outputs_d, out=scan_outputs_d,
enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding,
cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds,
) )
scan_outputs_d = scan_outputs_d.transpose(0, 1) scan_outputs_d = scan_outputs_d.transpose(0, 1)
......
...@@ -31,10 +31,10 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( ...@@ -31,10 +31,10 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_update, causal_conv1d_update,
) )
from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_state_update
from vllm.model_executor.layers.mamba.ops.ssd_combined import ( from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined_varlen, mamba_chunk_scan_combined_varlen,
) )
from vllm.model_executor.layers.mamba.ops.ssu_dispatch import selective_state_update
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction, LoaderFunction,
...@@ -890,8 +890,7 @@ class MambaMixer2(MambaBase, PluggableLayer): ...@@ -890,8 +890,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
B_d, B_d,
C_d, C_d,
D_d, D_d,
z=None, dt_bias,
dt_bias=dt_bias,
dt_softplus=True, dt_softplus=True,
state_batch_indices=state_indices_tensor_d_input, state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output, dst_state_batch_indices=state_indices_tensor_d_output,
...@@ -899,8 +898,6 @@ class MambaMixer2(MambaBase, PluggableLayer): ...@@ -899,8 +898,6 @@ class MambaMixer2(MambaBase, PluggableLayer):
num_accepted_tokens=num_accepted_tokens, num_accepted_tokens=num_accepted_tokens,
cu_seqlens=query_start_loc_d, cu_seqlens=query_start_loc_d,
is_blackwell=self.is_blackwell, is_blackwell=self.is_blackwell,
enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding,
cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds,
) )
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
......
...@@ -323,9 +323,9 @@ def selective_state_update( ...@@ -323,9 +323,9 @@ def selective_state_update(
A, A,
B, B,
C, C,
D=None, D,
dt_bias,
z=None, z=None,
dt_bias=None,
dt_softplus=False, dt_softplus=False,
state_batch_indices=None, state_batch_indices=None,
dst_state_batch_indices=None, dst_state_batch_indices=None,
...@@ -374,11 +374,11 @@ def selective_state_update( ...@@ -374,11 +374,11 @@ def selective_state_update(
B = B.unsqueeze(1) B = B.unsqueeze(1)
if C.dim() == 2: if C.dim() == 2:
C = C.unsqueeze(1) C = C.unsqueeze(1)
if D is not None and D.dim() == 1: if D.dim() == 1:
D = D.unsqueeze(0) D = D.unsqueeze(0)
if z is not None and z.dim() == 2: if z is not None and z.dim() == 2:
z = z.unsqueeze(1) z = z.unsqueeze(1)
if dt_bias is not None and dt_bias.dim() == 1: if dt_bias.dim() == 1:
dt_bias = dt_bias.unsqueeze(0) dt_bias = dt_bias.unsqueeze(0)
if out.dim() == 2: if out.dim() == 2:
out = out.unsqueeze(1) out = out.unsqueeze(1)
...@@ -410,11 +410,9 @@ def selective_state_update( ...@@ -410,11 +410,9 @@ def selective_state_update(
assert nheads % ngroups == 0, "nheads must be divisible by ngroups" assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
assert B.shape == (batch, ngroups, dstate) assert B.shape == (batch, ngroups, dstate)
assert C.shape == B.shape assert C.shape == B.shape
if D is not None:
assert D.shape == (nheads, dim) assert D.shape == (nheads, dim)
if z is not None: if z is not None:
assert z.shape == x.shape assert z.shape == x.shape
if dt_bias is not None:
assert dt_bias.shape == (nheads, dim) assert dt_bias.shape == (nheads, dim)
if state_batch_indices is not None: if state_batch_indices is not None:
assert state_batch_indices.shape[0] >= N assert state_batch_indices.shape[0] >= N
...@@ -506,7 +504,8 @@ def selective_state_update( ...@@ -506,7 +504,8 @@ def selective_state_update(
dt.stride(0), dt.stride(0),
dt.stride(1), dt.stride(1),
dt.stride(2), dt.stride(2),
*(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0, dt_bias.stride(0),
dt_bias.stride(1),
A.stride(0), A.stride(0),
A.stride(1), A.stride(1),
A.stride(2), A.stride(2),
...@@ -516,7 +515,8 @@ def selective_state_update( ...@@ -516,7 +515,8 @@ def selective_state_update(
C.stride(0), C.stride(0),
C.stride(1), C.stride(1),
C.stride(2), C.stride(2),
*(D.stride(0), D.stride(1)) if D is not None else 0, D.stride(0),
D.stride(1),
z_strides[0], z_strides[0],
z_strides[1], z_strides[1],
z_strides[2], z_strides[2],
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Dispatch module for Mamba selective state update (SSU) backends.
Provides a unified `selective_state_update` function that dispatches to
either the Triton or FlashInfer backend based on the configured
`MambaBackendEnum`. Follows SGLang's dispatch pattern adapted for vLLM.
"""
from abc import ABC, abstractmethod
import torch
from vllm.config.mamba import MambaBackendEnum, MambaConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import NULL_BLOCK_ID
logger = init_logger(__name__)
class MambaSSUBackend(ABC):
    """Common interface for Mamba selective state update (SSU) backends.

    A concrete backend exposes a ``name`` and is invoked like a function with
    the SSU kernel arguments. The ``MambaConfig`` passed at construction is
    kept on ``_mamba_config`` for subclasses to consult.
    """

    def __init__(self, mamba_config: MambaConfig):
        self._mamba_config = mamba_config

    @property
    @abstractmethod
    def name(self) -> str:
        """Short identifier of the backend."""

    @abstractmethod
    def __call__(
        self,
        state: torch.Tensor,
        x: torch.Tensor,
        dt: torch.Tensor,
        A: torch.Tensor,
        B: torch.Tensor,
        C: torch.Tensor,
        D: torch.Tensor,
        dt_bias: torch.Tensor,
        z: torch.Tensor | None = None,
        dt_softplus: bool = False,
        state_batch_indices: torch.Tensor | None = None,
        dst_state_batch_indices: torch.Tensor | None = None,
        null_block_id: int = NULL_BLOCK_ID,
        out: torch.Tensor | None = None,
        num_accepted_tokens: torch.Tensor | None = None,
        cu_seqlens: torch.Tensor | None = None,
        is_blackwell: bool = False,
    ) -> None:
        """Run the selective state update; implemented by each backend."""
class TritonSSUBackend(MambaSSUBackend):
    """Triton-based SSU backend (vLLM's default)."""

    def __init__(self, mamba_config: MambaConfig):
        super().__init__(mamba_config)
        # Deferred import: the Triton kernel module is loaded only when this
        # backend is actually constructed.
        from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
            selective_state_update as triton_ssu_kernel,
        )

        self._kernel = triton_ssu_kernel

    @property
    def name(self) -> str:
        return "triton"

    def __call__(
        self,
        state: torch.Tensor,
        x: torch.Tensor,
        dt: torch.Tensor,
        A: torch.Tensor,
        B: torch.Tensor,
        C: torch.Tensor,
        D: torch.Tensor,
        dt_bias: torch.Tensor,
        z: torch.Tensor | None = None,
        dt_softplus: bool = False,
        state_batch_indices: torch.Tensor | None = None,
        dst_state_batch_indices: torch.Tensor | None = None,
        null_block_id: int = NULL_BLOCK_ID,
        out: torch.Tensor | None = None,
        num_accepted_tokens: torch.Tensor | None = None,
        cu_seqlens: torch.Tensor | None = None,
        is_blackwell: bool = False,
    ) -> None:
        """Forward every argument to the Triton kernel.

        The stochastic-rounding settings are read from the stored
        ``MambaConfig`` rather than taken from the caller.
        """
        cfg = self._mamba_config
        self._kernel(
            state,
            x,
            dt,
            A,
            B,
            C,
            D=D,
            z=z,
            dt_bias=dt_bias,
            dt_softplus=dt_softplus,
            state_batch_indices=state_batch_indices,
            dst_state_batch_indices=dst_state_batch_indices,
            null_block_id=null_block_id,
            out=out,
            num_accepted_tokens=num_accepted_tokens,
            cu_seqlens=cu_seqlens,
            is_blackwell=is_blackwell,
            enable_stochastic_rounding=cfg.enable_stochastic_rounding,
            cache_philox_rounds=cfg.stochastic_rounding_philox_rounds,
        )
class FlashInferSSUBackend(MambaSSUBackend):
    """FlashInfer-based SSU backend."""

    def __init__(self, mamba_config: MambaConfig):
        super().__init__(mamba_config)
        # FlashInfer is optional, so it is imported here and a missing
        # install is reported with installation guidance.
        try:
            from flashinfer.mamba import selective_state_update as flashinfer_ssu
        except ImportError as e:
            raise ImportError(
                "FlashInfer is required for the flashinfer Mamba SSU backend. "
                "Please install flashinfer (>= 0.6.4): "
                "pip install flashinfer-python"
            ) from e
        self._kernel = flashinfer_ssu

    @property
    def name(self) -> str:
        return "flashinfer"

    def __call__(
        self,
        state: torch.Tensor,
        x: torch.Tensor,
        dt: torch.Tensor,
        A: torch.Tensor,
        B: torch.Tensor,
        C: torch.Tensor,
        D: torch.Tensor,
        dt_bias: torch.Tensor,
        z: torch.Tensor | None = None,
        dt_softplus: bool = False,
        state_batch_indices: torch.Tensor | None = None,
        dst_state_batch_indices: torch.Tensor | None = None,
        null_block_id: int = NULL_BLOCK_ID,
        out: torch.Tensor | None = None,
        num_accepted_tokens: torch.Tensor | None = None,
        cu_seqlens: torch.Tensor | None = None,
        is_blackwell: bool = False,
    ) -> None:
        """Translate the common SSU arguments to the FlashInfer kernel call.

        ``null_block_id`` is passed as ``pad_slot_id``. ``is_blackwell`` is
        accepted for interface parity and is not forwarded.
        """
        cfg = self._mamba_config

        # A fresh one-element random seed tensor is drawn on the state's
        # device only when stochastic rounding is enabled.
        rand_seed = None
        if cfg.enable_stochastic_rounding:
            rand_seed = torch.randint(0, 2**32, (1,), device=state.device)

        # ``cache_steps`` is the last dimension of ``state_batch_indices`` when
        # both it and ``cu_seqlens`` are supplied, otherwise 0.
        if cu_seqlens is not None and state_batch_indices is not None:
            cache_steps = state_batch_indices.size(-1)
        else:
            cache_steps = 0

        # A configured value of 0 (falsy) falls back to 10 rounds.
        philox_rounds = cfg.stochastic_rounding_philox_rounds or 10

        self._kernel(
            state,
            x,
            dt,
            A,
            B,
            C,
            D=D,
            z=z,
            dt_bias=dt_bias,
            dt_softplus=dt_softplus,
            state_batch_indices=state_batch_indices,
            dst_state_batch_indices=dst_state_batch_indices,
            cu_seqlens=cu_seqlens,
            num_accepted_tokens=num_accepted_tokens,
            cache_steps=cache_steps,
            pad_slot_id=null_block_id,
            out=out,
            rand_seed=rand_seed,
            philox_rounds=philox_rounds,
        )
# Maps each configurable backend enum member to the class implementing it.
_BACKEND_REGISTRY: dict[MambaBackendEnum, type[MambaSSUBackend]] = {
    MambaBackendEnum.TRITON: TritonSSUBackend,
    MambaBackendEnum.FLASHINFER: FlashInferSSUBackend,
}
# Process-wide active backend; stays None until initialize_mamba_ssu_backend()
# has been called.
_mamba_ssu_backend: MambaSSUBackend | None = None
def initialize_mamba_ssu_backend(mamba_config: MambaConfig) -> None:
    """Build the backend selected by ``mamba_config`` and install it globally.

    Args:
        mamba_config: Mamba configuration; its ``backend`` picks the
            implementation and the config itself is handed to the backend.

    Raises:
        ValueError: If ``mamba_config.backend`` has no registered backend.
    """
    global _mamba_ssu_backend

    requested = mamba_config.backend
    backend_cls = _BACKEND_REGISTRY.get(requested)
    if backend_cls is None:
        raise ValueError(
            f"Unknown Mamba SSU backend: {requested}. "
            f"Valid options: {list(_BACKEND_REGISTRY)}"
        )

    # The global is replaced only after the backend was built successfully.
    _mamba_ssu_backend = backend_cls(mamba_config)
    logger.info("Using %s Mamba SSU backend.", _mamba_ssu_backend.name)
def get_mamba_ssu_backend() -> MambaSSUBackend:
    """Return the active Mamba SSU backend.

    Raises:
        RuntimeError: If no backend has been initialized yet.
    """
    active = _mamba_ssu_backend
    if active is not None:
        return active
    raise RuntimeError(
        "Mamba SSU backend has not been initialized. "
        "Call initialize_mamba_ssu_backend() first."
    )
def selective_state_update(
    state: torch.Tensor,
    x: torch.Tensor,
    dt: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    D: torch.Tensor,
    dt_bias: torch.Tensor,
    z: torch.Tensor | None = None,
    dt_softplus: bool = False,
    state_batch_indices: torch.Tensor | None = None,
    dst_state_batch_indices: torch.Tensor | None = None,
    null_block_id: int = NULL_BLOCK_ID,
    out: torch.Tensor | None = None,
    num_accepted_tokens: torch.Tensor | None = None,
    cu_seqlens: torch.Tensor | None = None,
    is_blackwell: bool = False,
) -> None:
    """Run the Mamba selective state update on the active backend.

    Every argument is forwarded unchanged to whichever backend was installed
    by ``initialize_mamba_ssu_backend`` (Triton or FlashInfer).

    Raises:
        RuntimeError: If no backend has been initialized.
    """
    backend = get_mamba_ssu_backend()
    backend(
        state,
        x,
        dt,
        A,
        B,
        C,
        D,
        dt_bias,
        z=z,
        dt_softplus=dt_softplus,
        state_batch_indices=state_batch_indices,
        dst_state_batch_indices=dst_state_batch_indices,
        null_block_id=null_block_id,
        out=out,
        num_accepted_tokens=num_accepted_tokens,
        cu_seqlens=cu_seqlens,
        is_blackwell=is_blackwell,
    )
...@@ -38,10 +38,10 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( ...@@ -38,10 +38,10 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_fn,
causal_conv1d_update, causal_conv1d_update,
) )
from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_state_update
from vllm.model_executor.layers.mamba.ops.ssd_combined import ( from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined_varlen, mamba_chunk_scan_combined_varlen,
) )
from vllm.model_executor.layers.mamba.ops.ssu_dispatch import selective_state_update
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -447,13 +447,11 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer): ...@@ -447,13 +447,11 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
B, B,
C, C,
D, D,
dt_bias,
z=gate_d.reshape(num_decodes, -1, self.head_dim), z=gate_d.reshape(num_decodes, -1, self.head_dim),
dt_bias=dt_bias,
dt_softplus=True, dt_softplus=True,
state_batch_indices=state_indices_tensor_d, state_batch_indices=state_indices_tensor_d,
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim), out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding,
cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds,
) )
# 4. Final linear projection # 4. Final linear projection
......
...@@ -36,6 +36,9 @@ from vllm.distributed.parallel_state import ( ...@@ -36,6 +36,9 @@ from vllm.distributed.parallel_state import (
) )
from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.ops.ssu_dispatch import (
initialize_mamba_ssu_backend,
)
from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader import get_model_loader
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -360,6 +363,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -360,6 +363,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.attn_backends, self.attn_groups, attn_cg_support = init_attn_backend( self.attn_backends, self.attn_groups, attn_cg_support = init_attn_backend(
self.kv_cache_config, self.vllm_config, self.device self.kv_cache_config, self.vllm_config, self.device
) )
initialize_mamba_ssu_backend(self.vllm_config.mamba_config)
cudagraph_mode = self.compilation_config.resolve_cudagraph_mode_and_sizes( cudagraph_mode = self.compilation_config.resolve_cudagraph_mode_and_sizes(
attn_cg_support.min_cg_support, attn_cg_support.min_cg_support,
attn_cg_support.min_cg_attn_backend, attn_cg_support.min_cg_attn_backend,
......
...@@ -56,6 +56,9 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase ...@@ -56,6 +56,9 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
RoutedExpertsCapturer, RoutedExpertsCapturer,
) )
from vllm.model_executor.layers.mamba.ops.ssu_dispatch import (
initialize_mamba_ssu_backend,
)
from vllm.model_executor.layers.rotary_embedding import ( from vllm.model_executor.layers.rotary_embedding import (
MRotaryEmbedding, MRotaryEmbedding,
XDRotaryEmbedding, XDRotaryEmbedding,
...@@ -6750,6 +6753,7 @@ class GPUModelRunner( ...@@ -6750,6 +6753,7 @@ class GPUModelRunner(
self.may_add_encoder_only_layers_to_kv_cache_config() self.may_add_encoder_only_layers_to_kv_cache_config()
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
self.initialize_attn_backend(kv_cache_config, is_profiling=is_profiling) self.initialize_attn_backend(kv_cache_config, is_profiling=is_profiling)
initialize_mamba_ssu_backend(self.vllm_config.mamba_config)
# The kernel block size for all KV cache groups. For example, if # The kernel block size for all KV cache groups. For example, if
# kv_cache_manager uses block_size 256 for a given group, but the attention # kv_cache_manager uses block_size 256 for a given group, but the attention
# backends for that group only supports block_size 64, we will return # backends for that group only supports block_size 64, we will return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment