Unverified Commit b897f00c authored by roikoren755's avatar roikoren755 Committed by GitHub
Browse files

Gate SSU dispatch setup (#40039)


Signed-off-by: default avatarRoi Koren <roik@nvidia.com>
parent adf9bb3c
......@@ -13,6 +13,11 @@ from vllm.model_executor.layers.mamba.ops.ssu_dispatch import (
selective_state_update,
)
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.kv_cache_interface import (
KVCacheConfig,
KVCacheGroupSpec,
MambaSpec,
)
try:
import flashinfer.mamba # noqa: F401
......@@ -22,22 +27,40 @@ except ImportError:
HAS_FLASHINFER = False
def _kv_cache_config_with_ssu(mamba_type: str = "mamba2") -> KVCacheConfig:
spec = MambaSpec(
block_size=16,
shapes=((16, 64),),
dtypes=(torch.float16,),
mamba_type=mamba_type,
)
return KVCacheConfig(
num_blocks=1,
kv_cache_tensors=[],
kv_cache_groups=[KVCacheGroupSpec(layer_names=["l0"], kv_cache_spec=spec)],
)
def test_default_backend_is_triton():
initialize_mamba_ssu_backend(MambaConfig())
initialize_mamba_ssu_backend(MambaConfig(), _kv_cache_config_with_ssu())
backend = get_mamba_ssu_backend()
assert isinstance(backend, TritonSSUBackend)
assert backend.name == "triton"
def test_explicit_triton_backend():
initialize_mamba_ssu_backend(MambaConfig(backend=MambaBackendEnum.TRITON))
initialize_mamba_ssu_backend(
MambaConfig(backend=MambaBackendEnum.TRITON), _kv_cache_config_with_ssu()
)
backend = get_mamba_ssu_backend()
assert isinstance(backend, TritonSSUBackend)
@pytest.mark.skipif(not HAS_FLASHINFER, reason="flashinfer not installed")
def test_flashinfer_backend_init():
initialize_mamba_ssu_backend(MambaConfig(backend=MambaBackendEnum.FLASHINFER))
initialize_mamba_ssu_backend(
MambaConfig(backend=MambaBackendEnum.FLASHINFER), _kv_cache_config_with_ssu()
)
backend = get_mamba_ssu_backend()
assert isinstance(backend, FlashInferSSUBackend)
assert backend.name == "flashinfer"
......@@ -53,6 +76,25 @@ def test_uninitialized_backend_raises():
mod._mamba_ssu_backend = old
@pytest.mark.parametrize(
"mamba_type", ["linear_attention", "gdn_attention", "short_conv"]
)
def test_init_is_noop_for_non_ssu_mamba_type(mamba_type):
import vllm.model_executor.layers.mamba.ops.ssu_dispatch as mod
old = mod._mamba_ssu_backend
mod._mamba_ssu_backend = None
try:
initialize_mamba_ssu_backend(
MambaConfig(), _kv_cache_config_with_ssu(mamba_type)
)
assert mod._mamba_ssu_backend is None
with pytest.raises(RuntimeError, match="not been initialized"):
get_mamba_ssu_backend()
finally:
mod._mamba_ssu_backend = old
@pytest.mark.skipif(HAS_FLASHINFER, reason="flashinfer is installed")
def test_flashinfer_import_error():
with pytest.raises(ImportError, match="FlashInfer is required"):
......@@ -61,7 +103,9 @@ def test_flashinfer_import_error():
def test_triton_basic_call():
set_random_seed(0)
initialize_mamba_ssu_backend(MambaConfig(backend=MambaBackendEnum.TRITON))
initialize_mamba_ssu_backend(
MambaConfig(backend=MambaBackendEnum.TRITON), _kv_cache_config_with_ssu()
)
device = "cuda"
batch_size = 2
dim = 64
......
......@@ -15,6 +15,7 @@ import torch
from vllm.config.mamba import MambaBackendEnum, MambaConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import NULL_BLOCK_ID
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
logger = init_logger(__name__)
......@@ -188,12 +189,22 @@ _BACKEND_REGISTRY: dict[MambaBackendEnum, type[MambaSSUBackend]] = {
_mamba_ssu_backend: MambaSSUBackend | None = None
def initialize_mamba_ssu_backend(mamba_config: MambaConfig) -> None:
def initialize_mamba_ssu_backend(
mamba_config: MambaConfig,
kv_cache_config: KVCacheConfig,
) -> None:
"""Initialize the global Mamba SSU backend.
Args:
mamba_config: Mamba configuration.
No-op if `kv_cache_config` contains no specs that call
selective_state_update.
"""
if not any(
isinstance(g.kv_cache_spec, MambaSpec)
and g.kv_cache_spec.mamba_type in ("mamba1", "mamba2")
for g in kv_cache_config.kv_cache_groups
):
return
global _mamba_ssu_backend
backend = mamba_config.backend
......@@ -203,7 +214,11 @@ def initialize_mamba_ssu_backend(mamba_config: MambaConfig) -> None:
f"Valid options: {list(_BACKEND_REGISTRY.keys())}"
)
_mamba_ssu_backend = _BACKEND_REGISTRY[backend](mamba_config)
backend_cls = _BACKEND_REGISTRY[backend]
if isinstance(_mamba_ssu_backend, backend_cls):
return
_mamba_ssu_backend = backend_cls(mamba_config)
logger.info("Using %s Mamba SSU backend.", _mamba_ssu_backend.name)
......
......@@ -363,7 +363,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.attn_backends, self.attn_groups, attn_cg_support = init_attn_backend(
self.kv_cache_config, self.vllm_config, self.device
)
initialize_mamba_ssu_backend(self.vllm_config.mamba_config)
initialize_mamba_ssu_backend(
self.vllm_config.mamba_config, self.kv_cache_config
)
cudagraph_mode = self.compilation_config.resolve_cudagraph_mode_and_sizes(
attn_cg_support.min_cg_support,
attn_cg_support.min_cg_attn_backend,
......
......@@ -6738,7 +6738,9 @@ class GPUModelRunner(
self.may_add_encoder_only_layers_to_kv_cache_config()
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
self.initialize_attn_backend(kv_cache_config, is_profiling=is_profiling)
initialize_mamba_ssu_backend(self.vllm_config.mamba_config)
initialize_mamba_ssu_backend(
self.vllm_config.mamba_config, self.kv_cache_config
)
# The kernel block size for all KV cache groups. For example, if
# kv_cache_manager uses block_size 256 for a given group, but the attention
# backends for that group only supports block_size 64, we will return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment