Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b897f00c
Unverified
Commit
b897f00c
authored
Apr 16, 2026
by
roikoren755
Committed by
GitHub
Apr 16, 2026
Browse files
Gate SSU dispatch setup (#40039)
Signed-off-by:
Roi Koren
<
roik@nvidia.com
>
parent
adf9bb3c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
73 additions
and
10 deletions
+73
-10
tests/kernels/mamba/test_ssu_dispatch.py
tests/kernels/mamba/test_ssu_dispatch.py
+48
-4
vllm/model_executor/layers/mamba/ops/ssu_dispatch.py
vllm/model_executor/layers/mamba/ops/ssu_dispatch.py
+19
-4
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+3
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+3
-1
No files found.
tests/kernels/mamba/test_ssu_dispatch.py
View file @
b897f00c
...
...
@@ -13,6 +13,11 @@ from vllm.model_executor.layers.mamba.ops.ssu_dispatch import (
selective_state_update
,
)
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.v1.kv_cache_interface
import
(
KVCacheConfig
,
KVCacheGroupSpec
,
MambaSpec
,
)
try
:
import
flashinfer.mamba
# noqa: F401
...
...
@@ -22,22 +27,40 @@ except ImportError:
HAS_FLASHINFER
=
False
def
_kv_cache_config_with_ssu
(
mamba_type
:
str
=
"mamba2"
)
->
KVCacheConfig
:
spec
=
MambaSpec
(
block_size
=
16
,
shapes
=
((
16
,
64
),),
dtypes
=
(
torch
.
float16
,),
mamba_type
=
mamba_type
,
)
return
KVCacheConfig
(
num_blocks
=
1
,
kv_cache_tensors
=
[],
kv_cache_groups
=
[
KVCacheGroupSpec
(
layer_names
=
[
"l0"
],
kv_cache_spec
=
spec
)],
)
def
test_default_backend_is_triton
():
initialize_mamba_ssu_backend
(
MambaConfig
())
initialize_mamba_ssu_backend
(
MambaConfig
()
,
_kv_cache_config_with_ssu
()
)
backend
=
get_mamba_ssu_backend
()
assert
isinstance
(
backend
,
TritonSSUBackend
)
assert
backend
.
name
==
"triton"
def
test_explicit_triton_backend
():
initialize_mamba_ssu_backend
(
MambaConfig
(
backend
=
MambaBackendEnum
.
TRITON
))
initialize_mamba_ssu_backend
(
MambaConfig
(
backend
=
MambaBackendEnum
.
TRITON
),
_kv_cache_config_with_ssu
()
)
backend
=
get_mamba_ssu_backend
()
assert
isinstance
(
backend
,
TritonSSUBackend
)
@
pytest
.
mark
.
skipif
(
not
HAS_FLASHINFER
,
reason
=
"flashinfer not installed"
)
def
test_flashinfer_backend_init
():
initialize_mamba_ssu_backend
(
MambaConfig
(
backend
=
MambaBackendEnum
.
FLASHINFER
))
initialize_mamba_ssu_backend
(
MambaConfig
(
backend
=
MambaBackendEnum
.
FLASHINFER
),
_kv_cache_config_with_ssu
()
)
backend
=
get_mamba_ssu_backend
()
assert
isinstance
(
backend
,
FlashInferSSUBackend
)
assert
backend
.
name
==
"flashinfer"
...
...
@@ -53,6 +76,25 @@ def test_uninitialized_backend_raises():
mod
.
_mamba_ssu_backend
=
old
@
pytest
.
mark
.
parametrize
(
"mamba_type"
,
[
"linear_attention"
,
"gdn_attention"
,
"short_conv"
]
)
def
test_init_is_noop_for_non_ssu_mamba_type
(
mamba_type
):
import
vllm.model_executor.layers.mamba.ops.ssu_dispatch
as
mod
old
=
mod
.
_mamba_ssu_backend
mod
.
_mamba_ssu_backend
=
None
try
:
initialize_mamba_ssu_backend
(
MambaConfig
(),
_kv_cache_config_with_ssu
(
mamba_type
)
)
assert
mod
.
_mamba_ssu_backend
is
None
with
pytest
.
raises
(
RuntimeError
,
match
=
"not been initialized"
):
get_mamba_ssu_backend
()
finally
:
mod
.
_mamba_ssu_backend
=
old
@
pytest
.
mark
.
skipif
(
HAS_FLASHINFER
,
reason
=
"flashinfer is installed"
)
def
test_flashinfer_import_error
():
with
pytest
.
raises
(
ImportError
,
match
=
"FlashInfer is required"
):
...
...
@@ -61,7 +103,9 @@ def test_flashinfer_import_error():
def
test_triton_basic_call
():
set_random_seed
(
0
)
initialize_mamba_ssu_backend
(
MambaConfig
(
backend
=
MambaBackendEnum
.
TRITON
))
initialize_mamba_ssu_backend
(
MambaConfig
(
backend
=
MambaBackendEnum
.
TRITON
),
_kv_cache_config_with_ssu
()
)
device
=
"cuda"
batch_size
=
2
dim
=
64
...
...
vllm/model_executor/layers/mamba/ops/ssu_dispatch.py
View file @
b897f00c
...
...
@@ -15,6 +15,7 @@ import torch
from
vllm.config.mamba
import
MambaBackendEnum
,
MambaConfig
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.utils
import
NULL_BLOCK_ID
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
MambaSpec
logger
=
init_logger
(
__name__
)
...
...
@@ -188,12 +189,22 @@ _BACKEND_REGISTRY: dict[MambaBackendEnum, type[MambaSSUBackend]] = {
_mamba_ssu_backend
:
MambaSSUBackend
|
None
=
None
def
initialize_mamba_ssu_backend
(
mamba_config
:
MambaConfig
)
->
None
:
def
initialize_mamba_ssu_backend
(
mamba_config
:
MambaConfig
,
kv_cache_config
:
KVCacheConfig
,
)
->
None
:
"""Initialize the global Mamba SSU backend.
Args:
mamba_config: Mamba configuration
.
No-op if `kv_cache_config` contains no specs that call
selective_state_update
.
"""
if
not
any
(
isinstance
(
g
.
kv_cache_spec
,
MambaSpec
)
and
g
.
kv_cache_spec
.
mamba_type
in
(
"mamba1"
,
"mamba2"
)
for
g
in
kv_cache_config
.
kv_cache_groups
):
return
global
_mamba_ssu_backend
backend
=
mamba_config
.
backend
...
...
@@ -203,7 +214,11 @@ def initialize_mamba_ssu_backend(mamba_config: MambaConfig) -> None:
f
"Valid options:
{
list
(
_BACKEND_REGISTRY
.
keys
())
}
"
)
_mamba_ssu_backend
=
_BACKEND_REGISTRY
[
backend
](
mamba_config
)
backend_cls
=
_BACKEND_REGISTRY
[
backend
]
if
isinstance
(
_mamba_ssu_backend
,
backend_cls
):
return
_mamba_ssu_backend
=
backend_cls
(
mamba_config
)
logger
.
info
(
"Using %s Mamba SSU backend."
,
_mamba_ssu_backend
.
name
)
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
b897f00c
...
...
@@ -363,7 +363,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
attn_backends
,
self
.
attn_groups
,
attn_cg_support
=
init_attn_backend
(
self
.
kv_cache_config
,
self
.
vllm_config
,
self
.
device
)
initialize_mamba_ssu_backend
(
self
.
vllm_config
.
mamba_config
)
initialize_mamba_ssu_backend
(
self
.
vllm_config
.
mamba_config
,
self
.
kv_cache_config
)
cudagraph_mode
=
self
.
compilation_config
.
resolve_cudagraph_mode_and_sizes
(
attn_cg_support
.
min_cg_support
,
attn_cg_support
.
min_cg_attn_backend
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
b897f00c
...
...
@@ -6738,7 +6738,9 @@ class GPUModelRunner(
self
.
may_add_encoder_only_layers_to_kv_cache_config
()
self
.
maybe_add_kv_sharing_layers_to_kv_cache_groups
(
kv_cache_config
)
self
.
initialize_attn_backend
(
kv_cache_config
,
is_profiling
=
is_profiling
)
initialize_mamba_ssu_backend
(
self
.
vllm_config
.
mamba_config
)
initialize_mamba_ssu_backend
(
self
.
vllm_config
.
mamba_config
,
self
.
kv_cache_config
)
# The kernel block size for all KV cache groups. For example, if
# kv_cache_manager uses block_size 256 for a given group, but the attention
# backends for that group only supports block_size 64, we will return
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment