# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from enum import Enum, EnumMeta
from typing import Any

from pydantic import field_validator

from vllm.config.utils import config


class _MambaBackendEnumMeta(EnumMeta):
    """Metaclass for MambaBackendEnum to provide better error messages."""

    def __getitem__(cls, name: str):
        """Look up a member by name.

        Raises:
            ValueError: If ``name`` is not a member name. The message lists
                the valid member names, instead of the bare ``KeyError``
                that ``EnumMeta.__getitem__`` would raise.
        """
        try:
            return super().__getitem__(name)
        except KeyError:
            valid = ", ".join(cls.__members__.keys())
            raise ValueError(
                f"Unknown Mamba SSU backend: '{name}'. Valid options are: {valid}"
            ) from None


class MambaBackendEnum(Enum, metaclass=_MambaBackendEnumMeta):
    """Enumeration of supported Mamba SSU (selective state update) backends."""

    TRITON = "triton"
    FLASHINFER = "flashinfer"


@config
class MambaConfig:
    """Configuration for Mamba SSM backends."""

    backend: MambaBackendEnum = MambaBackendEnum.TRITON
    """Mamba SSU backend to use."""
    enable_stochastic_rounding: bool = False
    """Enable stochastic rounding when writing SSM state to fp16 cache.
    Uses random bits to unbias the rounding error, which can improve
    numerical stability for long sequences."""
    stochastic_rounding_philox_rounds: int = 0
    """Number of Philox PRNG rounds for stochastic rounding random number
    generation. 0 uses the Triton default. Higher values improve randomness
    quality at the cost of compute. Must be non-negative."""

    @field_validator("backend", mode="before")
    @classmethod
    def validate_backend_before(cls, value: Any) -> Any:
        """Enable parsing of the `backend` enum type from string.

        Strings are matched case-insensitively against member names, so both
        ``"triton"`` and ``"TRITON"`` resolve to ``MambaBackendEnum.TRITON``.
        Unknown names raise ``ValueError`` via ``_MambaBackendEnumMeta``.
        Non-string values are passed through unchanged for normal validation.
        """
        if isinstance(value, str):
            return MambaBackendEnum[value.upper()]
        return value

    def __post_init__(self):
        """Validate field values and platform support for stochastic rounding.

        Raises:
            ValueError: If ``stochastic_rounding_philox_rounds`` is negative.
                Also raised if stochastic rounding is enabled on a non-CUDA
                platform, or with the triton backend on a GPU outside
                device-capability family 100.
        """
        # A negative round count has no defined meaning: only 0 (Triton
        # default) and positive values are documented.
        if self.stochastic_rounding_philox_rounds < 0:
            raise ValueError(
                "`stochastic_rounding_philox_rounds` must be non-negative, "
                f"got {self.stochastic_rounding_philox_rounds}."
            )

        if not self.enable_stochastic_rounding:
            return

        # Imported lazily, only when the platform checks are needed.
        # NOTE(review): presumably this avoids an import cycle or early
        # platform detection at module import time; confirm.
        from vllm.platforms import current_platform

        if not current_platform.is_cuda():
            raise ValueError(
                "Stochastic rounding for Mamba cache is only supported "
                "on NVIDIA CUDA platforms. Please do not specify "
                "`--enable-mamba-cache-stochastic-rounding`."
            )

        # The triton path relies on the `cvt.rs` PTX instruction, which the
        # error below says requires compute capability 10.0.
        if (
            self.backend == MambaBackendEnum.TRITON
            and not current_platform.is_device_capability_family(100)
        ):
            raise ValueError(
                "Stochastic rounding for Mamba cache with triton backend requires "
                "compute capability 10.0 (data center Blackwell). The `cvt.rs` "
                "PTX instruction is not supported on your GPU. Please do not "
                "specify `--enable-mamba-cache-stochastic-rounding`, "
                "or set `--mamba-backend flashinfer`."
            )