Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
66e86f1d
"deploy/snapshot/pkg/common/overlay_test.go" did not exist on "bb8fc8a4a969357000caf57c79af47df6b2e2113"
Unverified
Commit
66e86f1d
authored
Apr 03, 2026
by
Nicolò Lucchesi
Committed by
GitHub
Apr 03, 2026
Browse files
[Kernel] Mamba support different layout for Conv state (#37416)
parent
bb39382b
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
169 additions
and
39 deletions
+169
-39
tests/models/language/generation/test_hybrid.py
tests/models/language/generation/test_hybrid.py
+23
-0
vllm/envs.py
vllm/envs.py
+8
-0
vllm/model_executor/layers/kda.py
vllm/model_executor/layers/kda.py
+11
-5
vllm/model_executor/layers/mamba/gdn_linear_attn.py
vllm/model_executor/layers/mamba/gdn_linear_attn.py
+15
-2
vllm/model_executor/layers/mamba/mamba_mixer.py
vllm/model_executor/layers/mamba/mamba_mixer.py
+7
-3
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+10
-4
vllm/model_executor/layers/mamba/mamba_utils.py
vllm/model_executor/layers/mamba/mamba_utils.py
+73
-17
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+0
-4
vllm/model_executor/layers/mamba/short_conv.py
vllm/model_executor/layers/mamba/short_conv.py
+6
-2
vllm/model_executor/models/olmo_hybrid.py
vllm/model_executor/models/olmo_hybrid.py
+8
-1
vllm/model_executor/models/plamo2.py
vllm/model_executor/models/plamo2.py
+8
-1
No files found.
tests/models/language/generation/test_hybrid.py
View file @
66e86f1d
...
...
@@ -60,6 +60,14 @@ MAX_NUM_SEQS = 4
ATTN_BACKEND
=
"TRITON_ATTN"
if
current_platform
.
is_rocm
()
else
"auto"
def
_set_conv_state_layout
(
monkeypatch
,
layout
:
str
)
->
None
:
"""Set conv state layout env var and clear cache to pick up new value."""
from
vllm.model_executor.layers.mamba
import
mamba_utils
monkeypatch
.
setenv
(
"VLLM_SSM_CONV_STATE_LAYOUT"
,
layout
)
mamba_utils
.
get_conv_state_layout
.
cache_clear
()
@
pytest
.
mark
.
parametrize
(
"model"
,
SSM_MODELS
+
HYBRID_MODELS
)
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
...
...
@@ -102,12 +110,15 @@ def test_models(
@
pytest
.
mark
.
parametrize
(
"model"
,
[
SSM_MODELS
[
0
],
HYBRID_MODELS
[
0
]])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"conv_state_layout"
,
[
"SD"
,
"DS"
])
def
test_batching
(
vllm_runner
,
example_prompts
,
monkeypatch
,
model
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
,
conv_state_layout
:
str
,
)
->
None
:
try
:
model_info
=
HF_EXAMPLE_MODELS
.
find_hf_info
(
model
)
...
...
@@ -116,6 +127,8 @@ def test_batching(
except
ValueError
:
pass
_set_conv_state_layout
(
monkeypatch
,
conv_state_layout
)
for_loop_outputs
=
[]
with
vllm_runner
(
model
,
max_num_seqs
=
MAX_NUM_SEQS
)
as
vllm_model
:
for
prompt
in
example_prompts
:
...
...
@@ -138,11 +151,14 @@ def test_batching(
@
pytest
.
mark
.
parametrize
(
"model"
,
[
SSM_MODELS
[
0
],
HYBRID_MODELS
[
0
]])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
10
])
@
pytest
.
mark
.
parametrize
(
"conv_state_layout"
,
[
"SD"
,
"DS"
])
def
test_chunked_prefill_with_parallel_sampling
(
vllm_runner
,
example_prompts
,
monkeypatch
,
model
:
str
,
max_tokens
:
int
,
conv_state_layout
:
str
,
)
->
None
:
"""
Tests chunked prefill in conjunction with n > 1.
...
...
@@ -154,6 +170,8 @@ def test_chunked_prefill_with_parallel_sampling(
decoding steps inside a chunked prefill forward pass
(where we have both prefill and decode together)
"""
_set_conv_state_layout
(
monkeypatch
,
conv_state_layout
)
sampling_params
=
SamplingParams
(
n
=
3
,
temperature
=
1
,
seed
=
0
,
max_tokens
=
max_tokens
)
with
vllm_runner
(
model
,
...
...
@@ -168,17 +186,22 @@ def test_chunked_prefill_with_parallel_sampling(
@
pytest
.
mark
.
parametrize
(
"model"
,
[
SSM_MODELS
[
0
],
HYBRID_MODELS
[
0
]])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
20
])
@
pytest
.
mark
.
parametrize
(
"conv_state_layout"
,
[
"SD"
,
"DS"
])
def
test_mamba_cache_cg_padding
(
vllm_runner
,
example_prompts
,
monkeypatch
,
model
:
str
,
max_tokens
:
int
,
conv_state_layout
:
str
,
)
->
None
:
"""
This test is for verifying that mamba cache is padded to CG captured
batch size. If it's not, a torch RuntimeError will be raised because
tensor dimensions aren't compatible.
"""
_set_conv_state_layout
(
monkeypatch
,
conv_state_layout
)
vllm_config
=
EngineArgs
(
model
=
model
,
trust_remote_code
=
True
).
create_engine_config
()
cudagraph_dispatcher
=
CudagraphDispatcher
(
vllm_config
)
cudagraph_dispatcher
.
initialize_cudagraph_keys
(
...
...
vllm/envs.py
View file @
66e86f1d
...
...
@@ -191,6 +191,7 @@ if TYPE_CHECKING:
VLLM_MQ_MAX_CHUNK_BYTES_MB
:
int
=
16
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS
:
int
=
300
VLLM_KV_CACHE_LAYOUT
:
Literal
[
"NHD"
,
"HND"
]
|
None
=
None
VLLM_SSM_CONV_STATE_LAYOUT
:
Literal
[
"SD"
,
"DS"
]
|
None
=
None
VLLM_COMPUTE_NANS_IN_LOGITS
:
bool
=
False
VLLM_USE_NVFP4_CT_EMULATIONS
:
bool
=
False
VLLM_ROCM_QUICK_REDUCE_QUANTIZATION
:
Literal
[
...
...
@@ -1409,6 +1410,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_KV_CACHE_LAYOUT"
:
env_with_choices
(
"VLLM_KV_CACHE_LAYOUT"
,
None
,
[
"NHD"
,
"HND"
]
),
# SSM conv state layout used for Mamba models.
# - SD: (state_len, dim) — dim contiguous (default)
# - DS: (dim, state_len) — TP-sharded dim on dim1,
# consistent with SSM temporal state and HND KV cache layout.
"VLLM_SSM_CONV_STATE_LAYOUT"
:
env_with_choices
(
"VLLM_SSM_CONV_STATE_LAYOUT"
,
None
,
[
"SD"
,
"DS"
]
),
# Enable checking whether the generated logits contain NaNs,
# indicating corrupted output. Useful for debugging low level bugs
# or bad hardware but it may add compute overhead.
...
...
vllm/model_executor/layers/kda.py
View file @
66e86f1d
...
...
@@ -31,7 +31,11 @@ from .linear import (
RowParallelLinear
,
)
from
.mamba.abstract
import
MambaBase
from
.mamba.mamba_utils
import
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
from
.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
,
is_conv_state_dim_first
,
)
from
.mamba.ops.causal_conv1d
import
causal_conv1d_fn
,
causal_conv1d_update
from
.quantization.base_config
import
QuantizationConfig
...
...
@@ -315,7 +319,9 @@ class KimiDeltaAttention(nn.Module, MambaBase):
beta
=
beta
[:
num_actual_tokens
]
(
conv_state_q
,
conv_state_k
,
conv_state_v
,
recurrent_state
)
=
constant_caches
# deal with strides
# conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a transpose.
if
not
is_conv_state_dim_first
():
conv_state_q
=
conv_state_q
.
transpose
(
-
1
,
-
2
)
conv_state_k
=
conv_state_k
.
transpose
(
-
1
,
-
2
)
conv_state_v
=
conv_state_v
.
transpose
(
-
1
,
-
2
)
...
...
vllm/model_executor/layers/mamba/gdn_linear_attn.py
View file @
66e86f1d
...
...
@@ -41,6 +41,7 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weigh
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
,
is_conv_state_dim_first
,
)
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
...
...
@@ -699,7 +700,13 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
spec_state_indices_tensor
=
attn_metadata
.
spec_state_indices_tensor
# noqa: E501
non_spec_state_indices_tensor
=
attn_metadata
.
non_spec_state_indices_tensor
# noqa: E501
self_kv_cache
=
self
.
kv_cache
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
# conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a transpose.
conv_state
=
(
self_kv_cache
[
0
]
if
is_conv_state_dim_first
()
else
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
)
ssm_state
=
self_kv_cache
[
1
]
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_accepted_tokens
=
attn_metadata
.
num_accepted_tokens
...
...
@@ -914,7 +921,13 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
"""
non_spec_state_indices_tensor
=
attn_metadata
.
non_spec_state_indices_tensor
# noqa: E501
self_kv_cache
=
self
.
kv_cache
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
# conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a transpose.
conv_state
=
(
self_kv_cache
[
0
]
if
is_conv_state_dim_first
()
else
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
)
ssm_state
=
self_kv_cache
[
1
]
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
...
...
vllm/model_executor/layers/mamba/mamba_mixer.py
View file @
66e86f1d
...
...
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
,
is_conv_state_dim_first
,
)
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
...
...
@@ -267,9 +268,12 @@ class MambaMixer(MambaBase, PluggableLayer):
query_start_loc_p
=
attn_metadata
.
query_start_loc_p
state_indices_tensor_p
=
attn_metadata
.
state_indices_tensor_p
state_indices_tensor_d
=
attn_metadata
.
state_indices_tensor_d
self_kv_cache
=
self
.
kv_cache
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
ssm_state
=
self_kv_cache
[
1
]
conv_state
=
(
self
.
kv_cache
[
0
]
if
is_conv_state_dim_first
()
else
self
.
kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
)
ssm_state
=
self
.
kv_cache
[
1
]
has_initial_states_p
=
attn_metadata
.
has_initial_states_p
cu_chunk_seqlen_p
=
attn_metadata
.
cu_chunk_seqlen_p
last_chunk_indices_p
=
attn_metadata
.
last_chunk_indices_p
...
...
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
66e86f1d
...
...
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
,
is_conv_state_dim_first
,
)
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
...
...
@@ -575,10 +576,15 @@ class MambaMixer2(MambaBase, PluggableLayer):
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
assert
isinstance
(
attn_metadata
,
Mamba2AttentionMetadata
)
self_kv_cache
=
self
.
kv_cache
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
ssm_state
=
self_kv_cache
[
1
]
# conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a
# transpose (which keeps dim contiguous via stride tricks).
conv_state
=
(
self
.
kv_cache
[
0
]
if
is_conv_state_dim_first
()
else
self
.
kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
)
ssm_state
=
self
.
kv_cache
[
1
]
has_initial_states_p
=
attn_metadata
.
has_initial_states_p
prep_initial_states
=
attn_metadata
.
prep_initial_states
chunk_size
=
attn_metadata
.
chunk_size
...
...
vllm/model_executor/layers/mamba/mamba_utils.py
View file @
66e86f1d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
from
collections.abc
import
Callable
from
dataclasses
import
dataclass
from
typing
import
TypeAlias
from
typing
import
Literal
,
TypeAlias
import
torch
import
vllm.envs
as
envs
from
vllm.config.cache
import
MambaDType
from
vllm.config.model
import
ModelDType
from
vllm.distributed
import
divide
from
vllm.logger
import
init_logger
from
vllm.utils.torch_utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
get_kv_cache_torch_dtype
,
)
logger
=
init_logger
(
__name__
)
ConvStateLayoutType
=
Literal
[
"SD"
,
"DS"
]
@
functools
.
lru_cache
def
get_conv_state_layout
()
->
ConvStateLayoutType
:
"""Return the SSM conv state layout.
SD = (state_len, dim) — dim is the innermost contiguous dimension.
DS = (dim, state_len) — TP-sharded dim is on dim-1 (like HND for KV
cache), consistent with SSM temporal state layout.
"""
layout
:
ConvStateLayoutType
|
None
=
envs
.
VLLM_SSM_CONV_STATE_LAYOUT
if
layout
is
not
None
:
logger
.
info_once
(
"VLLM_SSM_CONV_STATE_LAYOUT env detected. "
"Setting SSM conv state layout to %s."
,
layout
,
)
return
layout
return
"SD"
def
is_conv_state_dim_first
()
->
bool
:
"""True when the conv state is stored as (dim, state_len) per block."""
return
get_conv_state_layout
()
==
"DS"
class
MambaStateDtypeCalculator
:
@
classmethod
...
...
@@ -107,6 +139,13 @@ class MambaStateShapeCalculator:
state_shape
=
(
num_heads
//
tp_size
,
head_dim
,
head_dim
)
return
(
state_shape
,)
@
staticmethod
def
_orient_conv_shape
(
dim
:
int
,
state_len
:
int
)
->
tuple
[
int
,
int
]:
"""Return (dim, state_len) for DS layout, (state_len, dim) for SD."""
if
is_conv_state_dim_first
():
return
(
dim
,
state_len
)
return
(
state_len
,
dim
)
@
classmethod
def
mamba1_state_shape
(
cls
,
...
...
@@ -115,12 +154,11 @@ class MambaStateShapeCalculator:
state_size
:
int
,
conv_kernel
:
int
,
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
]]:
conv_state_shape
=
(
divide
(
intermediate_size
,
tp_world_size
),
conv_kernel
-
1
)
conv_dim
=
divide
(
intermediate_size
,
tp_world_size
)
conv_state_shape
=
cls
.
_orient_conv_shape
(
conv_dim
,
conv_kernel
-
1
)
temporal_state_shape
=
(
divide
(
intermediate_size
,
tp_world_size
),
state_size
)
conv_state_shape
=
conv_state_shape
[
1
],
conv_state_shape
[
0
]
return
conv_state_shape
,
temporal_state_shape
@
classmethod
...
...
@@ -141,8 +179,9 @@ class MambaStateShapeCalculator:
# heads and n_groups are TP-ed
conv_dim
=
intermediate_size
+
2
*
n_groups
*
state_size
# contiguous along 'dim' axis
conv_state_shape
=
(
conv_kernel
-
1
+
num_spec
,
divide
(
conv_dim
,
tp_world_size
))
conv_state_shape
=
cls
.
_orient_conv_shape
(
divide
(
conv_dim
,
tp_world_size
),
conv_kernel
-
1
+
num_spec
)
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
...
...
@@ -158,7 +197,7 @@ class MambaStateShapeCalculator:
conv_kernel
:
int
,
)
->
tuple
[
tuple
[
int
,
int
]]:
conv_dim
=
divide
(
intermediate_size
,
tp_world_size
)
conv_state_shape
=
(
conv_kernel
-
1
,
conv_dim
)
conv_state_shape
=
cls
.
_orient_conv_shape
(
conv_dim
,
conv_kernel
-
1
)
return
(
conv_state_shape
,)
@
classmethod
...
...
@@ -185,13 +224,11 @@ class MambaStateShapeCalculator:
num_spec
:
int
=
0
,
):
conv_dim
=
head_k_dim
*
num_k_heads
*
2
+
head_v_dim
*
num_v_heads
conv_state_shape
=
(
conv_state_shape
=
cls
.
_orient_conv_shape
(
divide
(
conv_dim
,
tp_world_size
),
conv_kernel_size
-
1
+
num_spec
,
)
conv_state_shape
=
conv_state_shape
[
1
],
conv_state_shape
[
0
]
temporal_state_shape
=
(
divide
(
num_v_heads
,
tp_world_size
),
head_v_dim
,
...
...
@@ -218,12 +255,13 @@ class MambaStateShapeCalculator:
proj_size
=
num_heads
*
head_dim
proj_k_size
=
num_k_heads
*
head_k_dim
conv_state_shape
=
(
divide
(
proj_size
,
tp_world_size
),
conv_kernel_size
-
1
)
conv_state_k_shape
=
(
divide
(
proj_k_size
,
tp_world_size
),
conv_kernel_size
-
1
)
conv_state_shape
=
cls
.
_orient_conv_shape
(
divide
(
proj_size
,
tp_world_size
),
conv_kernel_size
-
1
)
conv_state_k_shape
=
cls
.
_orient_conv_shape
(
divide
(
proj_k_size
,
tp_world_size
),
conv_kernel_size
-
1
)
recurrent_state_shape
=
(
divide
(
num_heads
,
tp_world_size
),
head_dim
,
head_dim
)
conv_state_shape
=
conv_state_shape
[
1
],
conv_state_shape
[
0
]
conv_state_k_shape
=
conv_state_k_shape
[
1
],
conv_state_k_shape
[
0
]
return
(
conv_state_shape
,
conv_state_k_shape
,
...
...
@@ -267,9 +305,27 @@ def get_conv_copy_spec(
cur_block_idx
:
int
,
num_accepted_tokens
:
int
,
)
->
MambaCopySpec
:
"""Return a MambaCopySpec for copying a convolutional state slice."""
"""Return a MambaCopySpec for copying a convolutional state slice.
Works for both SD layout ``(num_blocks, state_len, dim)`` and
DS layout ``(num_blocks, dim, state_len)``.
"""
src_block_id
=
block_ids
[
cur_block_idx
]
src_state
=
state
[
src_block_id
,
num_accepted_tokens
-
1
:]
offset
=
num_accepted_tokens
-
1
if
is_conv_state_dim_first
():
# DS layout: (num_blocks, dim, state_len) — state_len is last.
if
offset
>
0
:
# Slicing along the last dim yields a non-contiguous view
# because features (dim) are strided by state_len.
raise
NotImplementedError
(
"DS conv state layout does not yet support speculative "
"decoding with mamba_cache_mode='align' "
"(num_accepted_tokens > 1)."
)
src_state
=
state
[
src_block_id
]
else
:
# SD layout: (num_blocks, state_len, dim) — dim contiguous.
src_state
=
state
[
src_block_id
,
offset
:]
return
MambaCopySpec
(
start_addr
=
src_state
.
data_ptr
(),
num_elements
=
src_state
.
numel
()
)
...
...
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
View file @
66e86f1d
...
...
@@ -592,7 +592,6 @@ def causal_conv1d_fn(
stride_istate_seq
=
conv_states
.
stride
(
0
)
stride_istate_dim
=
conv_states
.
stride
(
1
)
stride_istate_token
=
conv_states
.
stride
(
2
)
assert
stride_istate_dim
==
1
if
out
.
dim
()
==
2
:
stride_o_dim
=
out
.
stride
(
0
)
stride_o_token
=
out
.
stride
(
1
)
...
...
@@ -1149,9 +1148,6 @@ def causal_conv1d_update(
if
validate_data
:
assert
dim
==
weight
.
size
(
0
)
assert
conv_state
.
stride
(
-
2
)
==
1
,
(
f
"ERROR: expect contiguous along feat-dim of conv_state (currently stride=
{
conv_state
.
stride
()
}
)"
)
assert
state_len
>=
width
-
1
# when above happens, we don't shift-left to keep any records in conv_state
assert
dim
==
conv_state
.
size
(
1
)
...
...
vllm/model_executor/layers/mamba/short_conv.py
View file @
66e86f1d
...
...
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
,
is_conv_state_dim_first
,
)
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
...
...
@@ -117,8 +118,11 @@ class ShortConv(MambaBase, CustomOp):
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
assert
isinstance
(
attn_metadata
,
ShortConvAttentionMetadata
)
self_kv_cache
=
self
.
kv_cache
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
conv_state
=
(
self
.
kv_cache
[
0
]
if
is_conv_state_dim_first
()
else
self
.
kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
)
state_indices_tensor_p
=
attn_metadata
.
state_indices_tensor_p
state_indices_tensor_d
=
attn_metadata
.
state_indices_tensor_d
has_initial_states_p
=
attn_metadata
.
has_initial_states_p
...
...
vllm/model_executor/models/olmo_hybrid.py
View file @
66e86f1d
...
...
@@ -68,6 +68,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFuncCalculator
,
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
,
is_conv_state_dim_first
,
)
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
...
...
@@ -429,7 +430,13 @@ class OlmoHybridGatedDeltaNet(nn.Module, MambaBase):
spec_state_indices_tensor
=
attn_metadata
.
spec_state_indices_tensor
non_spec_state_indices_tensor
=
attn_metadata
.
non_spec_state_indices_tensor
self_kv_cache
=
self
.
kv_cache
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
# conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a transpose.
conv_state
=
(
self_kv_cache
[
0
]
if
is_conv_state_dim_first
()
else
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
)
ssm_state
=
self_kv_cache
[
1
]
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_accepted_tokens
=
attn_metadata
.
num_accepted_tokens
...
...
vllm/model_executor/models/plamo2.py
View file @
66e86f1d
...
...
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFuncCalculator
,
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
,
is_conv_state_dim_first
,
)
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
...
...
@@ -266,7 +267,13 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
assert
isinstance
(
attn_metadata
,
Mamba2AttentionMetadata
)
self_kv_cache
=
self
.
kv_cache
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state
=
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
# conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a transpose.
conv_state
=
(
self_kv_cache
[
0
]
if
is_conv_state_dim_first
()
else
self_kv_cache
[
0
].
transpose
(
-
1
,
-
2
)
)
ssm_state
=
self_kv_cache
[
1
]
state_indices_tensor_p
=
attn_metadata
.
state_indices_tensor_p
state_indices_tensor_d
=
attn_metadata
.
state_indices_tensor_d
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment