Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
66e86f1d
Unverified
Commit
66e86f1d
authored
Apr 03, 2026
by
Nicolò Lucchesi
Committed by
GitHub
Apr 03, 2026
Browse files
[Kernel] Mamba support different layout for Conv state (#37416)
parent
bb39382b
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
169 additions
and
39 deletions
+169
-39
tests/models/language/generation/test_hybrid.py
tests/models/language/generation/test_hybrid.py
+23
-0
vllm/envs.py
vllm/envs.py
+8
-0
vllm/model_executor/layers/kda.py
vllm/model_executor/layers/kda.py
+11
-5
vllm/model_executor/layers/mamba/gdn_linear_attn.py
vllm/model_executor/layers/mamba/gdn_linear_attn.py
+15
-2
vllm/model_executor/layers/mamba/mamba_mixer.py
vllm/model_executor/layers/mamba/mamba_mixer.py
+7
-3
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+10
-4
vllm/model_executor/layers/mamba/mamba_utils.py
vllm/model_executor/layers/mamba/mamba_utils.py
+73
-17
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+0
-4
vllm/model_executor/layers/mamba/short_conv.py
vllm/model_executor/layers/mamba/short_conv.py
+6
-2
vllm/model_executor/models/olmo_hybrid.py
vllm/model_executor/models/olmo_hybrid.py
+8
-1
vllm/model_executor/models/plamo2.py
vllm/model_executor/models/plamo2.py
+8
-1
No files found.
tests/models/language/generation/test_hybrid.py
View file @
66e86f1d
...
@@ -60,6 +60,14 @@ MAX_NUM_SEQS = 4
...
@@ -60,6 +60,14 @@ MAX_NUM_SEQS = 4
ATTN_BACKEND
=
"TRITON_ATTN"
if
current_platform
.
is_rocm
()
else
"auto"
ATTN_BACKEND
=
"TRITON_ATTN"
if
current_platform
.
is_rocm
()
else
"auto"
def _set_conv_state_layout(monkeypatch, layout: str) -> None:
    """Select the SSM conv state layout for the current test.

    Patches the ``VLLM_SSM_CONV_STATE_LAYOUT`` environment variable to
    ``layout`` and then clears the cache of
    ``mamba_utils.get_conv_state_layout`` so the new value is picked up.

    Args:
        monkeypatch: The pytest ``monkeypatch`` fixture used to set the
            environment variable.
        layout: Layout name to export (the tests pass "SD" or "DS").
    """
    from vllm.model_executor.layers.mamba import mamba_utils as _mamba_utils

    monkeypatch.setenv("VLLM_SSM_CONV_STATE_LAYOUT", layout)
    # The layout getter is cached; without this reset a previously cached
    # value would be returned instead of the one just exported.
    layout_getter = _mamba_utils.get_conv_state_layout
    layout_getter.cache_clear()
@
pytest
.
mark
.
parametrize
(
"model"
,
SSM_MODELS
+
HYBRID_MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
SSM_MODELS
+
HYBRID_MODELS
)
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
...
@@ -102,12 +110,15 @@ def test_models(
...
@@ -102,12 +110,15 @@ def test_models(
@
pytest
.
mark
.
parametrize
(
"model"
,
[
SSM_MODELS
[
0
],
HYBRID_MODELS
[
0
]])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
SSM_MODELS
[
0
],
HYBRID_MODELS
[
0
]])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"conv_state_layout"
,
[
"SD"
,
"DS"
])
def
test_batching
(
def
test_batching
(
vllm_runner
,
vllm_runner
,
example_prompts
,
example_prompts
,
monkeypatch
,
model
:
str
,
model
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
int
,
conv_state_layout
:
str
,
)
->
None
:
)
->
None
:
try
:
try
:
model_info
=
HF_EXAMPLE_MODELS
.
find_hf_info
(
model
)
model_info
=
HF_EXAMPLE_MODELS
.
find_hf_info
(
model
)
...
@@ -116,6 +127,8 @@ def test_batching(
...
@@ -116,6 +127,8 @@ def test_batching(
except
ValueError
:
except
ValueError
:
pass
pass
_set_conv_state_layout
(
monkeypatch
,
conv_state_layout
)
for_loop_outputs
=
[]
for_loop_outputs
=
[]
with
vllm_runner
(
model
,
max_num_seqs
=
MAX_NUM_SEQS
)
as
vllm_model
:
with
vllm_runner
(
model
,
max_num_seqs
=
MAX_NUM_SEQS
)
as
vllm_model
:
for
prompt
in
example_prompts
:
for
prompt
in
example_prompts
:
...
@@ -138,11 +151,14 @@ def test_batching(
...
@@ -138,11 +151,14 @@ def test_batching(
@
pytest
.
mark
.
parametrize
(
"model"
,
[
SSM_MODELS
[
0
],
HYBRID_MODELS
[
0
]])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
SSM_MODELS
[
0
],
HYBRID_MODELS
[
0
]])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
10
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
10
])
@
pytest
.
mark
.
parametrize
(
"conv_state_layout"
,
[
"SD"
,
"DS"
])
def
test_chunked_prefill_with_parallel_sampling
(
def
test_chunked_prefill_with_parallel_sampling
(
vllm_runner
,
vllm_runner
,
example_prompts
,
example_prompts
,
monkeypatch
,
model
:
str
,
model
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
conv_state_layout
:
str
,
)
->
None
:
)
->
None
:
"""
"""
Tests chunked prefill in conjunction with n > 1.
Tests chunked prefill in conjunction with n > 1.
...
@@ -154,6 +170,8 @@ def test_chunked_prefill_with_parallel_sampling(
...
@@ -154,6 +170,8 @@ def test_chunked_prefill_with_parallel_sampling(
decoding steps inside a chunked prefill forward pass
decoding steps inside a chunked prefill forward pass
(where we have both prefill and decode together)
(where we have both prefill and decode together)
"""
"""
_set_conv_state_layout
(
monkeypatch
,
conv_state_layout
)
sampling_params
=
SamplingParams
(
n
=
3
,
temperature
=
1
,
seed
=
0
,
max_tokens
=
max_tokens
)
sampling_params
=
SamplingParams
(
n
=
3
,
temperature
=
1
,
seed
=
0
,
max_tokens
=
max_tokens
)
with
vllm_runner
(
with
vllm_runner
(
model
,
model
,
...
@@ -168,17 +186,22 @@ def test_chunked_prefill_with_parallel_sampling(
...
@@ -168,17 +186,22 @@ def test_chunked_prefill_with_parallel_sampling(
@
pytest
.
mark
.
parametrize
(
"model"
,
[
SSM_MODELS
[
0
],
HYBRID_MODELS
[
0
]])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
SSM_MODELS
[
0
],
HYBRID_MODELS
[
0
]])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
20
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
20
])
@
pytest
.
mark
.
parametrize
(
"conv_state_layout"
,
[
"SD"
,
"DS"
])
def
test_mamba_cache_cg_padding
(
def
test_mamba_cache_cg_padding
(
vllm_runner
,
vllm_runner
,
example_prompts
,
example_prompts
,
monkeypatch
,
model
:
str
,
model
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
conv_state_layout
:
str
,
)
->
None
:
)
->
None
:
"""
"""
This test is for verifying that mamba cache is padded to CG captured
This test is for verifying that mamba cache is padded to CG captured
batch size. If it's not, a torch RuntimeError will be raised because
batch size. If it's not, a torch RuntimeError will be raised because
tensor dimensions aren't compatible.
tensor dimensions aren't compatible.
"""
"""
_set_conv_state_layout
(
monkeypatch
,
conv_state_layout
)
vllm_config
=
EngineArgs
(
model
=
model
,
trust_remote_code
=
True
).
create_engine_config
()
vllm_config
=
EngineArgs
(
model
=
model
,
trust_remote_code
=
True
).
create_engine_config
()
cudagraph_dispatcher
=
CudagraphDispatcher
(
vllm_config
)
cudagraph_dispatcher
=
CudagraphDispatcher
(
vllm_config
)
cudagraph_dispatcher
.
initialize_cudagraph_keys
(
cudagraph_dispatcher
.
initialize_cudagraph_keys
(
...
...
vllm/envs.py
View file @
66e86f1d
...
@@ -191,6 +191,7 @@ if TYPE_CHECKING:
...
@@ -191,6 +191,7 @@ if TYPE_CHECKING:
VLLM_MQ_MAX_CHUNK_BYTES_MB
:
int
=
16
VLLM_MQ_MAX_CHUNK_BYTES_MB
:
int
=
16
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS
:
int
=
300
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS
:
int
=
300
VLLM_KV_CACHE_LAYOUT
:
Literal
[
"NHD"
,
"HND"
]
|
None
=
None
VLLM_KV_CACHE_LAYOUT
:
Literal
[
"NHD"
,
"HND"
]
|
None
=
None
VLLM_SSM_CONV_STATE_LAYOUT
:
Literal
[
"SD"
,
"DS"
]
|
None
=
None
VLLM_COMPUTE_NANS_IN_LOGITS
:
bool
=
False
VLLM_COMPUTE_NANS_IN_LOGITS
:
bool
=
False
VLLM_USE_NVFP4_CT_EMULATIONS
:
bool
=
False
VLLM_USE_NVFP4_CT_EMULATIONS
:
bool
=
False
VLLM_ROCM_QUICK_REDUCE_QUANTIZATION
:
Literal
[
VLLM_ROCM_QUICK_REDUCE_QUANTIZATION
:
Literal
[
...
@@ -1409,6 +1410,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1409,6 +1410,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_KV_CACHE_LAYOUT"
:
env_with_choices
(
"VLLM_KV_CACHE_LAYOUT"
:
env_with_choices
(
"VLLM_KV_CACHE_LAYOUT"
,
None
,
[
"NHD"
,
"HND"
]
"VLLM_KV_CACHE_LAYOUT"
,
None
,
[
"NHD"
,
"HND"
]
),
),
# SSM conv state layout used for Mamba models.
# - SD: (state_len, dim) — dim contiguous (default)
# - DS: (dim, state_len) — TP-sharded dim on dim1,
# consistent with SSM temporal state and HND KV cache layout.
"VLLM_SSM_CONV_STATE_LAYOUT"
:
env_with_choices
(
"VLLM_SSM_CONV_STATE_LAYOUT"
,
None
,
[
"SD"
,
"DS"
]
),
# Enable checking whether the generated logits contain NaNs,
# Enable checking whether the generated logits contain NaNs,
# indicating corrupted output. Useful for debugging low level bugs
# indicating corrupted output. Useful for debugging low level bugs
# or bad hardware but it may add compute overhead.
# or bad hardware but it may add compute overhead.
...
...
vllm/model_executor/layers/kda.py
View file @
66e86f1d
...
@@ -31,7 +31,11 @@ from .linear import (
...
@@ -31,7 +31,11 @@ from .linear import (
RowParallelLinear
,
RowParallelLinear
,
)
)
from
.mamba.abstract
import
MambaBase
from
.mamba.abstract
import
MambaBase
from
.mamba.mamba_utils
import
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
from
.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
,
is_conv_state_dim_first
,
)
from
.mamba.ops.causal_conv1d
import
causal_conv1d_fn
,
causal_conv1d_update
from
.mamba.ops.causal_conv1d
import
causal_conv1d_fn
,
causal_conv1d_update
from
.quantization.base_config
import
QuantizationConfig
from
.quantization.base_config
import
QuantizationConfig
...
@@ -315,10 +319,12 @@ class KimiDeltaAttention(nn.Module, MambaBase):
...
@@ -315,10 +319,12 @@ class KimiDeltaAttention(nn.Module, MambaBase):
beta
=
beta
[:
num_actual_tokens
]
beta
=
beta
[:
num_actual_tokens
]
(
conv_state_q
,
conv_state_k
,
conv_state_v
,
recurrent_state
)
=
constant_caches
(
conv_state_q
,
conv_state_k
,
conv_state_v
,
recurrent_state
)
=
constant_caches
# deal with strides
# conv_state must be (..., dim, width-1) for the conv kernels.
conv_state_q
=
conv_state_q
.
transpose
(
-
1
,
-
2
)
# DS layout stores it that way directly; SD layout needs a transpose.
conv_state_k
=
conv_state_k
.
transpose
(
-
1
,
-
2
)
if
not
is_conv_state_dim_first
():
conv_state_v
=
conv_state_v
.
transpose
(
-
1
,
-
2
)
conv_state_q
=
conv_state_q
.
transpose
(
-
1
,
-
2
)
conv_state_k
=
conv_state_k
.
transpose
(
-
1
,
-
2
)
conv_state_v
=
conv_state_v
.
transpose
(
-
1
,
-
2
)
q_conv_weights
=
self
.
q_conv1d
.
weight
.
view
(
q_conv_weights
=
self
.
q_conv1d
.
weight
.
view
(
self
.
q_conv1d
.
weight
.
size
(
0
),
self
.
q_conv1d
.
weight
.
size
(
2
)
self
.
q_conv1d
.
weight
.
size
(
0
),
self
.
q_conv1d
.
weight
.
size
(
2
)
...
...
vllm/model_executor/layers/mamba/gdn_linear_attn.py
View file @
66e86f1d
...
@@ -41,6 +41,7 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weigh
...
@@ -41,6 +41,7 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weigh
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
,
MambaStateShapeCalculator
,
is_conv_state_dim_first
,
)
)
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
causal_conv1d_fn
,
...
@@ -699,7 +700,13 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
...
@@ -699,7 +700,13 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
spec_state_indices_tensor
=
attn_metadata
.
spec_state_indices_tensor
# noqa: E501
spec_state_indices_tensor
=
attn_metadata
.
spec_state_indices_tensor
# noqa: E501
non_spec_state_indices_tensor
=
attn_metadata
.
non_spec_state_indices_tensor
# noqa: E501
non_spec_state_indices_tensor
=
attn_metadata
.
non_spec_state_indices_tensor
# noqa: E501
self_kv_cache
=
self
.
kv_cache
self_kv_cache
=
self
.
kv_cache
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
# conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a transpose.
conv_state
=
(
self_kv_cache
[
0
]
if
is_conv_state_dim_first
()
else
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
)
ssm_state
=
self_kv_cache
[
1
]
ssm_state
=
self_kv_cache
[
1
]
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_accepted_tokens
=
attn_metadata
.
num_accepted_tokens
num_accepted_tokens
=
attn_metadata
.
num_accepted_tokens
...
@@ -914,7 +921,13 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
...
@@ -914,7 +921,13 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
"""
"""
non_spec_state_indices_tensor
=
attn_metadata
.
non_spec_state_indices_tensor
# noqa: E501
non_spec_state_indices_tensor
=
attn_metadata
.
non_spec_state_indices_tensor
# noqa: E501
self_kv_cache
=
self
.
kv_cache
self_kv_cache
=
self
.
kv_cache
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
# conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a transpose.
conv_state
=
(
self_kv_cache
[
0
]
if
is_conv_state_dim_first
()
else
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
)
ssm_state
=
self_kv_cache
[
1
]
ssm_state
=
self_kv_cache
[
1
]
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
...
...
vllm/model_executor/layers/mamba/mamba_mixer.py
View file @
66e86f1d
...
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
...
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
,
MambaStateShapeCalculator
,
is_conv_state_dim_first
,
)
)
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
causal_conv1d_fn
,
...
@@ -267,9 +268,12 @@ class MambaMixer(MambaBase, PluggableLayer):
...
@@ -267,9 +268,12 @@ class MambaMixer(MambaBase, PluggableLayer):
query_start_loc_p
=
attn_metadata
.
query_start_loc_p
query_start_loc_p
=
attn_metadata
.
query_start_loc_p
state_indices_tensor_p
=
attn_metadata
.
state_indices_tensor_p
state_indices_tensor_p
=
attn_metadata
.
state_indices_tensor_p
state_indices_tensor_d
=
attn_metadata
.
state_indices_tensor_d
state_indices_tensor_d
=
attn_metadata
.
state_indices_tensor_d
self_kv_cache
=
self
.
kv_cache
conv_state
=
(
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
self
.
kv_cache
[
0
]
ssm_state
=
self_kv_cache
[
1
]
if
is_conv_state_dim_first
()
else
self
.
kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
)
ssm_state
=
self
.
kv_cache
[
1
]
has_initial_states_p
=
attn_metadata
.
has_initial_states_p
has_initial_states_p
=
attn_metadata
.
has_initial_states_p
cu_chunk_seqlen_p
=
attn_metadata
.
cu_chunk_seqlen_p
cu_chunk_seqlen_p
=
attn_metadata
.
cu_chunk_seqlen_p
last_chunk_indices_p
=
attn_metadata
.
last_chunk_indices_p
last_chunk_indices_p
=
attn_metadata
.
last_chunk_indices_p
...
...
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
66e86f1d
...
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
...
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
,
MambaStateShapeCalculator
,
is_conv_state_dim_first
,
)
)
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
causal_conv1d_fn
,
...
@@ -575,10 +576,15 @@ class MambaMixer2(MambaBase, PluggableLayer):
...
@@ -575,10 +576,15 @@ class MambaMixer2(MambaBase, PluggableLayer):
assert
isinstance
(
attn_metadata
,
dict
)
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
attn_metadata
=
attn_metadata
[
self
.
prefix
]
assert
isinstance
(
attn_metadata
,
Mamba2AttentionMetadata
)
assert
isinstance
(
attn_metadata
,
Mamba2AttentionMetadata
)
self_kv_cache
=
self
.
kv_cache
# conv_state must be (..., dim, width-1) for the conv kernels.
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
# DS layout stores it that way directly; SD layout needs a
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
# transpose (which keeps dim contiguous via stride tricks).
ssm_state
=
self_kv_cache
[
1
]
conv_state
=
(
self
.
kv_cache
[
0
]
if
is_conv_state_dim_first
()
else
self
.
kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
)
ssm_state
=
self
.
kv_cache
[
1
]
has_initial_states_p
=
attn_metadata
.
has_initial_states_p
has_initial_states_p
=
attn_metadata
.
has_initial_states_p
prep_initial_states
=
attn_metadata
.
prep_initial_states
prep_initial_states
=
attn_metadata
.
prep_initial_states
chunk_size
=
attn_metadata
.
chunk_size
chunk_size
=
attn_metadata
.
chunk_size
...
...
vllm/model_executor/layers/mamba/mamba_utils.py
View file @
66e86f1d
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TypeAlias
from
typing
import
Literal
,
TypeAlias
import
torch
import
torch
import
vllm.envs
as
envs
from
vllm.config.cache
import
MambaDType
from
vllm.config.cache
import
MambaDType
from
vllm.config.model
import
ModelDType
from
vllm.config.model
import
ModelDType
from
vllm.distributed
import
divide
from
vllm.distributed
import
divide
from
vllm.logger
import
init_logger
from
vllm.utils.torch_utils
import
(
from
vllm.utils.torch_utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
STR_DTYPE_TO_TORCH_DTYPE
,
get_kv_cache_torch_dtype
,
get_kv_cache_torch_dtype
,
)
)
logger
=
init_logger
(
__name__
)
ConvStateLayoutType
=
Literal
[
"SD"
,
"DS"
]
@functools.lru_cache
def get_conv_state_layout() -> ConvStateLayoutType:
    """Resolve the layout used for the SSM conv state.

    The value comes from ``envs.VLLM_SSM_CONV_STATE_LAYOUT``; when that is
    unset (``None``) the layout falls back to "SD". The result is cached
    by ``functools.lru_cache``.

    Layouts:
        SD = (state_len, dim) — dim is the innermost contiguous dimension.
        DS = (dim, state_len) — TP-sharded dim is on dim-1 (like HND for KV
             cache), consistent with SSM temporal state layout.

    Returns:
        The configured layout name, or "SD" when nothing is configured.
    """
    env_layout: ConvStateLayoutType | None = envs.VLLM_SSM_CONV_STATE_LAYOUT

    # Nothing configured: use the default layout without logging.
    if env_layout is None:
        return "SD"

    logger.info_once(
        "VLLM_SSM_CONV_STATE_LAYOUT env detected. "
        "Setting SSM conv state layout to %s.",
        env_layout,
    )
    return env_layout
def is_conv_state_dim_first() -> bool:
    """Report whether the conv state is stored as (dim, state_len) per block.

    Returns:
        True when the resolved conv state layout is "DS", False otherwise.
    """
    active_layout = get_conv_state_layout()
    return active_layout == "DS"
class
MambaStateDtypeCalculator
:
class
MambaStateDtypeCalculator
:
@
classmethod
@
classmethod
...
@@ -107,6 +139,13 @@ class MambaStateShapeCalculator:
...
@@ -107,6 +139,13 @@ class MambaStateShapeCalculator:
state_shape
=
(
num_heads
//
tp_size
,
head_dim
,
head_dim
)
state_shape
=
(
num_heads
//
tp_size
,
head_dim
,
head_dim
)
return
(
state_shape
,)
return
(
state_shape
,)
@staticmethod
def _orient_conv_shape(dim: int, state_len: int) -> tuple[int, int]:
    """Order the two conv state extents according to the active layout.

    Args:
        dim: Size of the feature (dim) axis.
        state_len: Size of the state-length axis.

    Returns:
        ``(dim, state_len)`` for the DS layout, ``(state_len, dim)`` for SD.
    """
    return (dim, state_len) if is_conv_state_dim_first() else (state_len, dim)
@
classmethod
@
classmethod
def
mamba1_state_shape
(
def
mamba1_state_shape
(
cls
,
cls
,
...
@@ -115,12 +154,11 @@ class MambaStateShapeCalculator:
...
@@ -115,12 +154,11 @@ class MambaStateShapeCalculator:
state_size
:
int
,
state_size
:
int
,
conv_kernel
:
int
,
conv_kernel
:
int
,
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
]]:
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
]]:
conv_state_shape
=
(
divide
(
intermediate_size
,
tp_world_size
),
conv_kernel
-
1
)
conv_dim
=
divide
(
intermediate_size
,
tp_world_size
)
conv_state_shape
=
cls
.
_orient_conv_shape
(
conv_dim
,
conv_kernel
-
1
)
temporal_state_shape
=
(
divide
(
intermediate_size
,
tp_world_size
),
state_size
)
temporal_state_shape
=
(
divide
(
intermediate_size
,
tp_world_size
),
state_size
)
conv_state_shape
=
conv_state_shape
[
1
],
conv_state_shape
[
0
]
return
conv_state_shape
,
temporal_state_shape
return
conv_state_shape
,
temporal_state_shape
@
classmethod
@
classmethod
...
@@ -141,8 +179,9 @@ class MambaStateShapeCalculator:
...
@@ -141,8 +179,9 @@ class MambaStateShapeCalculator:
# heads and n_groups are TP-ed
# heads and n_groups are TP-ed
conv_dim
=
intermediate_size
+
2
*
n_groups
*
state_size
conv_dim
=
intermediate_size
+
2
*
n_groups
*
state_size
# contiguous along 'dim' axis
conv_state_shape
=
cls
.
_orient_conv_shape
(
conv_state_shape
=
(
conv_kernel
-
1
+
num_spec
,
divide
(
conv_dim
,
tp_world_size
))
divide
(
conv_dim
,
tp_world_size
),
conv_kernel
-
1
+
num_spec
)
# These are not TP-ed as they depend on A, dt_bias, D
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# - they are typically small
...
@@ -158,7 +197,7 @@ class MambaStateShapeCalculator:
...
@@ -158,7 +197,7 @@ class MambaStateShapeCalculator:
conv_kernel
:
int
,
conv_kernel
:
int
,
)
->
tuple
[
tuple
[
int
,
int
]]:
)
->
tuple
[
tuple
[
int
,
int
]]:
conv_dim
=
divide
(
intermediate_size
,
tp_world_size
)
conv_dim
=
divide
(
intermediate_size
,
tp_world_size
)
conv_state_shape
=
(
conv_kernel
-
1
,
conv_dim
)
conv_state_shape
=
cls
.
_orient_conv_shape
(
conv_dim
,
conv_kernel
-
1
)
return
(
conv_state_shape
,)
return
(
conv_state_shape
,)
@
classmethod
@
classmethod
...
@@ -185,13 +224,11 @@ class MambaStateShapeCalculator:
...
@@ -185,13 +224,11 @@ class MambaStateShapeCalculator:
num_spec
:
int
=
0
,
num_spec
:
int
=
0
,
):
):
conv_dim
=
head_k_dim
*
num_k_heads
*
2
+
head_v_dim
*
num_v_heads
conv_dim
=
head_k_dim
*
num_k_heads
*
2
+
head_v_dim
*
num_v_heads
conv_state_shape
=
(
conv_state_shape
=
cls
.
_orient_conv_shape
(
divide
(
conv_dim
,
tp_world_size
),
divide
(
conv_dim
,
tp_world_size
),
conv_kernel_size
-
1
+
num_spec
,
conv_kernel_size
-
1
+
num_spec
,
)
)
conv_state_shape
=
conv_state_shape
[
1
],
conv_state_shape
[
0
]
temporal_state_shape
=
(
temporal_state_shape
=
(
divide
(
num_v_heads
,
tp_world_size
),
divide
(
num_v_heads
,
tp_world_size
),
head_v_dim
,
head_v_dim
,
...
@@ -218,12 +255,13 @@ class MambaStateShapeCalculator:
...
@@ -218,12 +255,13 @@ class MambaStateShapeCalculator:
proj_size
=
num_heads
*
head_dim
proj_size
=
num_heads
*
head_dim
proj_k_size
=
num_k_heads
*
head_k_dim
proj_k_size
=
num_k_heads
*
head_k_dim
conv_state_shape
=
(
divide
(
proj_size
,
tp_world_size
),
conv_kernel_size
-
1
)
conv_state_shape
=
cls
.
_orient_conv_shape
(
conv_state_k_shape
=
(
divide
(
proj_k_size
,
tp_world_size
),
conv_kernel_size
-
1
)
divide
(
proj_size
,
tp_world_size
),
conv_kernel_size
-
1
)
conv_state_k_shape
=
cls
.
_orient_conv_shape
(
divide
(
proj_k_size
,
tp_world_size
),
conv_kernel_size
-
1
)
recurrent_state_shape
=
(
divide
(
num_heads
,
tp_world_size
),
head_dim
,
head_dim
)
recurrent_state_shape
=
(
divide
(
num_heads
,
tp_world_size
),
head_dim
,
head_dim
)
conv_state_shape
=
conv_state_shape
[
1
],
conv_state_shape
[
0
]
conv_state_k_shape
=
conv_state_k_shape
[
1
],
conv_state_k_shape
[
0
]
return
(
return
(
conv_state_shape
,
conv_state_shape
,
conv_state_k_shape
,
conv_state_k_shape
,
...
@@ -267,9 +305,27 @@ def get_conv_copy_spec(
...
@@ -267,9 +305,27 @@ def get_conv_copy_spec(
cur_block_idx
:
int
,
cur_block_idx
:
int
,
num_accepted_tokens
:
int
,
num_accepted_tokens
:
int
,
)
->
MambaCopySpec
:
)
->
MambaCopySpec
:
"""Return a MambaCopySpec for copying a convolutional state slice."""
"""Return a MambaCopySpec for copying a convolutional state slice.
Works for both SD layout ``(num_blocks, state_len, dim)`` and
DS layout ``(num_blocks, dim, state_len)``.
"""
src_block_id
=
block_ids
[
cur_block_idx
]
src_block_id
=
block_ids
[
cur_block_idx
]
src_state
=
state
[
src_block_id
,
num_accepted_tokens
-
1
:]
offset
=
num_accepted_tokens
-
1
if
is_conv_state_dim_first
():
# DS layout: (num_blocks, dim, state_len) — state_len is last.
if
offset
>
0
:
# Slicing along the last dim yields a non-contiguous view
# because features (dim) are strided by state_len.
raise
NotImplementedError
(
"DS conv state layout does not yet support speculative "
"decoding with mamba_cache_mode='align' "
"(num_accepted_tokens > 1)."
)
src_state
=
state
[
src_block_id
]
else
:
# SD layout: (num_blocks, state_len, dim) — dim contiguous.
src_state
=
state
[
src_block_id
,
offset
:]
return
MambaCopySpec
(
return
MambaCopySpec
(
start_addr
=
src_state
.
data_ptr
(),
num_elements
=
src_state
.
numel
()
start_addr
=
src_state
.
data_ptr
(),
num_elements
=
src_state
.
numel
()
)
)
...
...
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
View file @
66e86f1d
...
@@ -592,7 +592,6 @@ def causal_conv1d_fn(
...
@@ -592,7 +592,6 @@ def causal_conv1d_fn(
stride_istate_seq
=
conv_states
.
stride
(
0
)
stride_istate_seq
=
conv_states
.
stride
(
0
)
stride_istate_dim
=
conv_states
.
stride
(
1
)
stride_istate_dim
=
conv_states
.
stride
(
1
)
stride_istate_token
=
conv_states
.
stride
(
2
)
stride_istate_token
=
conv_states
.
stride
(
2
)
assert
stride_istate_dim
==
1
if
out
.
dim
()
==
2
:
if
out
.
dim
()
==
2
:
stride_o_dim
=
out
.
stride
(
0
)
stride_o_dim
=
out
.
stride
(
0
)
stride_o_token
=
out
.
stride
(
1
)
stride_o_token
=
out
.
stride
(
1
)
...
@@ -1149,9 +1148,6 @@ def causal_conv1d_update(
...
@@ -1149,9 +1148,6 @@ def causal_conv1d_update(
if
validate_data
:
if
validate_data
:
assert
dim
==
weight
.
size
(
0
)
assert
dim
==
weight
.
size
(
0
)
assert
conv_state
.
stride
(
-
2
)
==
1
,
(
f
"ERROR: expect contiguous along feat-dim of conv_state (currently stride=
{
conv_state
.
stride
()
}
)"
)
assert
state_len
>=
width
-
1
assert
state_len
>=
width
-
1
# when above happens, we don't shift-left to keep any records in conv_state
# when above happens, we don't shift-left to keep any records in conv_state
assert
dim
==
conv_state
.
size
(
1
)
assert
dim
==
conv_state
.
size
(
1
)
...
...
vllm/model_executor/layers/mamba/short_conv.py
View file @
66e86f1d
...
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
...
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
,
MambaStateShapeCalculator
,
is_conv_state_dim_first
,
)
)
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
causal_conv1d_fn
,
...
@@ -117,8 +118,11 @@ class ShortConv(MambaBase, CustomOp):
...
@@ -117,8 +118,11 @@ class ShortConv(MambaBase, CustomOp):
assert
isinstance
(
attn_metadata
,
dict
)
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
attn_metadata
=
attn_metadata
[
self
.
prefix
]
assert
isinstance
(
attn_metadata
,
ShortConvAttentionMetadata
)
assert
isinstance
(
attn_metadata
,
ShortConvAttentionMetadata
)
self_kv_cache
=
self
.
kv_cache
conv_state
=
(
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
self
.
kv_cache
[
0
]
if
is_conv_state_dim_first
()
else
self
.
kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
)
state_indices_tensor_p
=
attn_metadata
.
state_indices_tensor_p
state_indices_tensor_p
=
attn_metadata
.
state_indices_tensor_p
state_indices_tensor_d
=
attn_metadata
.
state_indices_tensor_d
state_indices_tensor_d
=
attn_metadata
.
state_indices_tensor_d
has_initial_states_p
=
attn_metadata
.
has_initial_states_p
has_initial_states_p
=
attn_metadata
.
has_initial_states_p
...
...
vllm/model_executor/models/olmo_hybrid.py
View file @
66e86f1d
...
@@ -68,6 +68,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
...
@@ -68,6 +68,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFuncCalculator
,
MambaStateCopyFuncCalculator
,
MambaStateDtypeCalculator
,
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
,
MambaStateShapeCalculator
,
is_conv_state_dim_first
,
)
)
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
causal_conv1d_fn
,
...
@@ -429,7 +430,13 @@ class OlmoHybridGatedDeltaNet(nn.Module, MambaBase):
...
@@ -429,7 +430,13 @@ class OlmoHybridGatedDeltaNet(nn.Module, MambaBase):
spec_state_indices_tensor
=
attn_metadata
.
spec_state_indices_tensor
spec_state_indices_tensor
=
attn_metadata
.
spec_state_indices_tensor
non_spec_state_indices_tensor
=
attn_metadata
.
non_spec_state_indices_tensor
non_spec_state_indices_tensor
=
attn_metadata
.
non_spec_state_indices_tensor
self_kv_cache
=
self
.
kv_cache
self_kv_cache
=
self
.
kv_cache
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
# conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a transpose.
conv_state
=
(
self_kv_cache
[
0
]
if
is_conv_state_dim_first
()
else
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
)
ssm_state
=
self_kv_cache
[
1
]
ssm_state
=
self_kv_cache
[
1
]
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_accepted_tokens
=
attn_metadata
.
num_accepted_tokens
num_accepted_tokens
=
attn_metadata
.
num_accepted_tokens
...
...
vllm/model_executor/models/plamo2.py
View file @
66e86f1d
...
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
...
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFuncCalculator
,
MambaStateCopyFuncCalculator
,
MambaStateDtypeCalculator
,
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
,
MambaStateShapeCalculator
,
is_conv_state_dim_first
,
)
)
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
causal_conv1d_fn
,
...
@@ -266,7 +267,13 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
...
@@ -266,7 +267,13 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
assert
isinstance
(
attn_metadata
,
Mamba2AttentionMetadata
)
assert
isinstance
(
attn_metadata
,
Mamba2AttentionMetadata
)
self_kv_cache
=
self
.
kv_cache
self_kv_cache
=
self
.
kv_cache
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
# conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a transpose.
conv_state
=
(
self_kv_cache
[
0
]
if
is_conv_state_dim_first
()
else
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
)
ssm_state
=
self_kv_cache
[
1
]
ssm_state
=
self_kv_cache
[
1
]
state_indices_tensor_p
=
attn_metadata
.
state_indices_tensor_p
state_indices_tensor_p
=
attn_metadata
.
state_indices_tensor_p
state_indices_tensor_d
=
attn_metadata
.
state_indices_tensor_d
state_indices_tensor_d
=
attn_metadata
.
state_indices_tensor_d
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment