# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the backend selection helpers in ``ssu_dispatch``."""

import pytest
import torch

from vllm.config.mamba import MambaBackendEnum, MambaConfig
from vllm.model_executor.layers.mamba.ops.ssu_dispatch import (
    FlashInferSSUBackend,
    TritonSSUBackend,
    get_mamba_ssu_backend,
    initialize_mamba_ssu_backend,
    selective_state_update,
)
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.kv_cache_interface import (
    KVCacheConfig,
    KVCacheGroupSpec,
    MambaSpec,
)

# FlashInfer is optional; tests that need it (or need its absence) are gated
# on this flag.
try:
    import flashinfer.mamba  # noqa: F401

    HAS_FLASHINFER = True
except ImportError:
    HAS_FLASHINFER = False


def _kv_cache_config_with_ssu(mamba_type: str = "mamba2") -> KVCacheConfig:
    """Build a minimal ``KVCacheConfig`` holding a single ``MambaSpec`` group.

    Args:
        mamba_type: Value forwarded to ``MambaSpec.mamba_type``.

    Returns:
        A config with one block, no KV cache tensors, and one cache group for
        layer ``"l0"`` whose spec is a float16 ``MambaSpec`` with
        ``block_size=16`` and ``shapes=((16, 64),)``.
    """
    spec = MambaSpec(
        block_size=16,
        shapes=((16, 64),),
        dtypes=(torch.float16,),
        mamba_type=mamba_type,
    )
    return KVCacheConfig(
        num_blocks=1,
        kv_cache_tensors=[],
        kv_cache_groups=[KVCacheGroupSpec(layer_names=["l0"], kv_cache_spec=spec)],
    )


def test_default_backend_is_triton():
    """A default ``MambaConfig`` selects ``TritonSSUBackend`` named "triton"."""
    initialize_mamba_ssu_backend(MambaConfig(), _kv_cache_config_with_ssu())
    backend = get_mamba_ssu_backend()
    assert isinstance(backend, TritonSSUBackend)
    assert backend.name == "triton"


def test_explicit_triton_backend():
    """Requesting ``MambaBackendEnum.TRITON`` yields a ``TritonSSUBackend``."""
    initialize_mamba_ssu_backend(
        MambaConfig(backend=MambaBackendEnum.TRITON), _kv_cache_config_with_ssu()
    )
    backend = get_mamba_ssu_backend()
    assert isinstance(backend, TritonSSUBackend)


@pytest.mark.skipif(not HAS_FLASHINFER, reason="flashinfer not installed")
def test_flashinfer_backend_init():
    """Requesting ``MambaBackendEnum.FLASHINFER`` yields ``FlashInferSSUBackend``."""
    initialize_mamba_ssu_backend(
        MambaConfig(backend=MambaBackendEnum.FLASHINFER), _kv_cache_config_with_ssu()
    )
    backend = get_mamba_ssu_backend()
    assert isinstance(backend, FlashInferSSUBackend)
    assert backend.name == "flashinfer"


def test_uninitialized_backend_raises():
    """``get_mamba_ssu_backend`` raises ``RuntimeError`` when no backend is set."""
    import vllm.model_executor.layers.mamba.ops.ssu_dispatch as mod

    old = mod._mamba_ssu_backend
    mod._mamba_ssu_backend = None
    try:
        with pytest.raises(RuntimeError, match="not been initialized"):
            get_mamba_ssu_backend()
    finally:
        # Restore even when the check above fails, so a cleared module global
        # does not leak into tests that run afterwards.
        mod._mamba_ssu_backend = old
@pytest.mark.parametrize(
    "mamba_type",
    ["linear_attention", "gdn_attention", "short_conv"],
)
def test_init_is_noop_for_non_ssu_mamba_type(mamba_type):
    """Initialization leaves the backend unset for these ``mamba_type`` values."""
    import vllm.model_executor.layers.mamba.ops.ssu_dispatch as ssu_dispatch

    # Start from a cleared module global so a backend left by another test
    # cannot hide the no-op behavior.
    saved_backend = ssu_dispatch._mamba_ssu_backend
    ssu_dispatch._mamba_ssu_backend = None
    try:
        default_config = MambaConfig()
        kv_cache_config = _kv_cache_config_with_ssu(mamba_type)
        initialize_mamba_ssu_backend(default_config, kv_cache_config)

        assert ssu_dispatch._mamba_ssu_backend is None
        with pytest.raises(RuntimeError, match="not been initialized"):
            get_mamba_ssu_backend()
    finally:
        ssu_dispatch._mamba_ssu_backend = saved_backend


@pytest.mark.skipif(HAS_FLASHINFER, reason="flashinfer is installed")
def test_flashinfer_import_error():
    """Constructing ``FlashInferSSUBackend`` without flashinfer raises ImportError."""
    default_config = MambaConfig()
    with pytest.raises(ImportError, match="FlashInfer is required"):
        FlashInferSSUBackend(default_config)


def test_triton_basic_call():
    """Smoke test: ``selective_state_update`` on the Triton backend gives no NaNs."""
    set_random_seed(0)
    triton_config = MambaConfig(backend=MambaBackendEnum.TRITON)
    initialize_mamba_ssu_backend(triton_config, _kv_cache_config_with_ssu())

    batch, dim, dstate = 2, 64, 16
    on_gpu = {"device": "cuda"}

    # Tensor creation order is kept fixed so the seeded random stream produces
    # the same inputs on every run.
    state = torch.randn(batch, dim, dstate, **on_gpu)
    x = torch.randn(batch, dim, **on_gpu)
    out = torch.empty_like(x)
    dt = torch.randn(batch, dim, **on_gpu)
    dt_bias = torch.rand(dim, **on_gpu) - 4.0  # uniform values in [-4, -3)
    A = -torch.rand(dim, dstate, **on_gpu)  # non-positive entries
    B = torch.randn(batch, dstate, **on_gpu)
    C = torch.randn(batch, dstate, **on_gpu)
    D = torch.randn(dim, **on_gpu)

    selective_state_update(
        state,
        x,
        dt,
        A,
        B,
        C,
        D=D,
        dt_bias=dt_bias,
        dt_softplus=True,
        out=out,
    )

    has_nan = torch.isnan(out).any()
    assert not has_nan