Unverified Commit 66e86f1d authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[Kernel] Mamba support different layout for Conv state (#37416)

parent bb39382b
...@@ -60,6 +60,14 @@ MAX_NUM_SEQS = 4 ...@@ -60,6 +60,14 @@ MAX_NUM_SEQS = 4
ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "auto" ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "auto"
def _set_conv_state_layout(monkeypatch, layout: str) -> None:
"""Set conv state layout env var and clear cache to pick up new value."""
from vllm.model_executor.layers.mamba import mamba_utils
monkeypatch.setenv("VLLM_SSM_CONV_STATE_LAYOUT", layout)
mamba_utils.get_conv_state_layout.cache_clear()
@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
@pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
...@@ -102,12 +110,15 @@ def test_models( ...@@ -102,12 +110,15 @@ def test_models(
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("conv_state_layout", ["SD", "DS"])
def test_batching( def test_batching(
vllm_runner, vllm_runner,
example_prompts, example_prompts,
monkeypatch,
model: str, model: str,
max_tokens: int, max_tokens: int,
num_logprobs: int, num_logprobs: int,
conv_state_layout: str,
) -> None: ) -> None:
try: try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
...@@ -116,6 +127,8 @@ def test_batching( ...@@ -116,6 +127,8 @@ def test_batching(
except ValueError: except ValueError:
pass pass
_set_conv_state_layout(monkeypatch, conv_state_layout)
for_loop_outputs = [] for_loop_outputs = []
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
for prompt in example_prompts: for prompt in example_prompts:
...@@ -138,11 +151,14 @@ def test_batching( ...@@ -138,11 +151,14 @@ def test_batching(
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [10]) @pytest.mark.parametrize("max_tokens", [10])
@pytest.mark.parametrize("conv_state_layout", ["SD", "DS"])
def test_chunked_prefill_with_parallel_sampling( def test_chunked_prefill_with_parallel_sampling(
vllm_runner, vllm_runner,
example_prompts, example_prompts,
monkeypatch,
model: str, model: str,
max_tokens: int, max_tokens: int,
conv_state_layout: str,
) -> None: ) -> None:
""" """
Tests chunked prefill in conjunction with n > 1. Tests chunked prefill in conjunction with n > 1.
...@@ -154,6 +170,8 @@ def test_chunked_prefill_with_parallel_sampling( ...@@ -154,6 +170,8 @@ def test_chunked_prefill_with_parallel_sampling(
decoding steps inside a chunked prefill forward pass decoding steps inside a chunked prefill forward pass
(where we have both prefill and decode together) (where we have both prefill and decode together)
""" """
_set_conv_state_layout(monkeypatch, conv_state_layout)
sampling_params = SamplingParams(n=3, temperature=1, seed=0, max_tokens=max_tokens) sampling_params = SamplingParams(n=3, temperature=1, seed=0, max_tokens=max_tokens)
with vllm_runner( with vllm_runner(
model, model,
...@@ -168,17 +186,22 @@ def test_chunked_prefill_with_parallel_sampling( ...@@ -168,17 +186,22 @@ def test_chunked_prefill_with_parallel_sampling(
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [20]) @pytest.mark.parametrize("max_tokens", [20])
@pytest.mark.parametrize("conv_state_layout", ["SD", "DS"])
def test_mamba_cache_cg_padding( def test_mamba_cache_cg_padding(
vllm_runner, vllm_runner,
example_prompts, example_prompts,
monkeypatch,
model: str, model: str,
max_tokens: int, max_tokens: int,
conv_state_layout: str,
) -> None: ) -> None:
""" """
This test is for verifying that mamba cache is padded to CG captured This test is for verifying that mamba cache is padded to CG captured
batch size. If it's not, a torch RuntimeError will be raised because batch size. If it's not, a torch RuntimeError will be raised because
tensor dimensions aren't compatible. tensor dimensions aren't compatible.
""" """
_set_conv_state_layout(monkeypatch, conv_state_layout)
vllm_config = EngineArgs(model=model, trust_remote_code=True).create_engine_config() vllm_config = EngineArgs(model=model, trust_remote_code=True).create_engine_config()
cudagraph_dispatcher = CudagraphDispatcher(vllm_config) cudagraph_dispatcher = CudagraphDispatcher(vllm_config)
cudagraph_dispatcher.initialize_cudagraph_keys( cudagraph_dispatcher.initialize_cudagraph_keys(
......
...@@ -191,6 +191,7 @@ if TYPE_CHECKING: ...@@ -191,6 +191,7 @@ if TYPE_CHECKING:
VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300 VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300
VLLM_KV_CACHE_LAYOUT: Literal["NHD", "HND"] | None = None VLLM_KV_CACHE_LAYOUT: Literal["NHD", "HND"] | None = None
VLLM_SSM_CONV_STATE_LAYOUT: Literal["SD", "DS"] | None = None
VLLM_COMPUTE_NANS_IN_LOGITS: bool = False VLLM_COMPUTE_NANS_IN_LOGITS: bool = False
VLLM_USE_NVFP4_CT_EMULATIONS: bool = False VLLM_USE_NVFP4_CT_EMULATIONS: bool = False
VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: Literal[ VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: Literal[
...@@ -1409,6 +1410,13 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1409,6 +1410,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_KV_CACHE_LAYOUT": env_with_choices( "VLLM_KV_CACHE_LAYOUT": env_with_choices(
"VLLM_KV_CACHE_LAYOUT", None, ["NHD", "HND"] "VLLM_KV_CACHE_LAYOUT", None, ["NHD", "HND"]
), ),
# SSM conv state layout used for Mamba models.
# - SD: (state_len, dim) — dim contiguous (default)
# - DS: (dim, state_len) — the TP-sharded dim sits on axis 1 of the
# (num_blocks, dim, state_len) cache tensor, consistent with the SSM
# temporal state and the HND KV cache layout.
"VLLM_SSM_CONV_STATE_LAYOUT": env_with_choices(
"VLLM_SSM_CONV_STATE_LAYOUT", None, ["SD", "DS"]
),
# Enable checking whether the generated logits contain NaNs, # Enable checking whether the generated logits contain NaNs,
# indicating corrupted output. Useful for debugging low level bugs # indicating corrupted output. Useful for debugging low level bugs
# or bad hardware but it may add compute overhead. # or bad hardware but it may add compute overhead.
......
...@@ -31,7 +31,11 @@ from .linear import ( ...@@ -31,7 +31,11 @@ from .linear import (
RowParallelLinear, RowParallelLinear,
) )
from .mamba.abstract import MambaBase from .mamba.abstract import MambaBase
from .mamba.mamba_utils import MambaStateDtypeCalculator, MambaStateShapeCalculator from .mamba.mamba_utils import (
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
is_conv_state_dim_first,
)
from .mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update from .mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from .quantization.base_config import QuantizationConfig from .quantization.base_config import QuantizationConfig
...@@ -315,10 +319,12 @@ class KimiDeltaAttention(nn.Module, MambaBase): ...@@ -315,10 +319,12 @@ class KimiDeltaAttention(nn.Module, MambaBase):
beta = beta[:num_actual_tokens] beta = beta[:num_actual_tokens]
(conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches (conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches
# deal with strides # conv_state must be (..., dim, width-1) for the conv kernels.
conv_state_q = conv_state_q.transpose(-1, -2) # DS layout stores it that way directly; SD layout needs a transpose.
conv_state_k = conv_state_k.transpose(-1, -2) if not is_conv_state_dim_first():
conv_state_v = conv_state_v.transpose(-1, -2) conv_state_q = conv_state_q.transpose(-1, -2)
conv_state_k = conv_state_k.transpose(-1, -2)
conv_state_v = conv_state_v.transpose(-1, -2)
q_conv_weights = self.q_conv1d.weight.view( q_conv_weights = self.q_conv1d.weight.view(
self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2) self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
......
...@@ -41,6 +41,7 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weigh ...@@ -41,6 +41,7 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weigh
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateDtypeCalculator,
MambaStateShapeCalculator, MambaStateShapeCalculator,
is_conv_state_dim_first,
) )
from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_fn,
...@@ -699,7 +700,13 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): ...@@ -699,7 +700,13 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
self_kv_cache = self.kv_cache self_kv_cache = self.kv_cache
conv_state = self_kv_cache[0].transpose(-1, -2) # conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a transpose.
conv_state = (
self_kv_cache[0]
if is_conv_state_dim_first()
else self_kv_cache[0].transpose(-1, -2)
)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens num_actual_tokens = attn_metadata.num_actual_tokens
num_accepted_tokens = attn_metadata.num_accepted_tokens num_accepted_tokens = attn_metadata.num_accepted_tokens
...@@ -914,7 +921,13 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): ...@@ -914,7 +921,13 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
""" """
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
self_kv_cache = self.kv_cache self_kv_cache = self.kv_cache
conv_state = self_kv_cache[0].transpose(-1, -2) # conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a transpose.
conv_state = (
self_kv_cache[0]
if is_conv_state_dim_first()
else self_kv_cache[0].transpose(-1, -2)
)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens num_actual_tokens = attn_metadata.num_actual_tokens
......
...@@ -24,6 +24,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase ...@@ -24,6 +24,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateDtypeCalculator,
MambaStateShapeCalculator, MambaStateShapeCalculator,
is_conv_state_dim_first,
) )
from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_fn,
...@@ -267,9 +268,12 @@ class MambaMixer(MambaBase, PluggableLayer): ...@@ -267,9 +268,12 @@ class MambaMixer(MambaBase, PluggableLayer):
query_start_loc_p = attn_metadata.query_start_loc_p query_start_loc_p = attn_metadata.query_start_loc_p
state_indices_tensor_p = attn_metadata.state_indices_tensor_p state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d state_indices_tensor_d = attn_metadata.state_indices_tensor_d
self_kv_cache = self.kv_cache conv_state = (
conv_state = self_kv_cache[0].transpose(-1, -2) self.kv_cache[0]
ssm_state = self_kv_cache[1] if is_conv_state_dim_first()
else self.kv_cache[0].transpose(-1, -2)
)
ssm_state = self.kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p has_initial_states_p = attn_metadata.has_initial_states_p
cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
last_chunk_indices_p = attn_metadata.last_chunk_indices_p last_chunk_indices_p = attn_metadata.last_chunk_indices_p
......
...@@ -24,6 +24,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase ...@@ -24,6 +24,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateDtypeCalculator,
MambaStateShapeCalculator, MambaStateShapeCalculator,
is_conv_state_dim_first,
) )
from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_fn,
...@@ -575,10 +576,15 @@ class MambaMixer2(MambaBase, PluggableLayer): ...@@ -575,10 +576,15 @@ class MambaMixer2(MambaBase, PluggableLayer):
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix] attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba2AttentionMetadata) assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache # conv_state must be (..., dim, width-1) for the conv kernels.
# conv_state = (..., dim, width-1) yet contiguous along 'dim' # DS layout stores it that way directly; SD layout needs a
conv_state = self_kv_cache[0].transpose(-1, -2) # transpose (which keeps dim contiguous via stride tricks).
ssm_state = self_kv_cache[1] conv_state = (
self.kv_cache[0]
if is_conv_state_dim_first()
else self.kv_cache[0].transpose(-1, -2)
)
ssm_state = self.kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p has_initial_states_p = attn_metadata.has_initial_states_p
prep_initial_states = attn_metadata.prep_initial_states prep_initial_states = attn_metadata.prep_initial_states
chunk_size = attn_metadata.chunk_size chunk_size = attn_metadata.chunk_size
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from typing import TypeAlias from typing import Literal, TypeAlias
import torch import torch
import vllm.envs as envs
from vllm.config.cache import MambaDType from vllm.config.cache import MambaDType
from vllm.config.model import ModelDType from vllm.config.model import ModelDType
from vllm.distributed import divide from vllm.distributed import divide
from vllm.logger import init_logger
from vllm.utils.torch_utils import ( from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE, STR_DTYPE_TO_TORCH_DTYPE,
get_kv_cache_torch_dtype, get_kv_cache_torch_dtype,
) )
logger = init_logger(__name__)
ConvStateLayoutType = Literal["SD", "DS"]
@functools.lru_cache
def get_conv_state_layout() -> ConvStateLayoutType:
    """Resolve the layout used to store the SSM conv state.

    The layout is read from ``envs.VLLM_SSM_CONV_STATE_LAYOUT`` and falls
    back to ``"SD"`` when that is unset. The result is memoized, so a later
    change to the environment variable is only observed after
    ``get_conv_state_layout.cache_clear()``.

    SD = (state_len, dim) — dim is the innermost contiguous dimension.
    DS = (dim, state_len) — TP-sharded dim is on dim-1 (like HND for KV
    cache), consistent with SSM temporal state layout.
    """
    configured: ConvStateLayoutType | None = envs.VLLM_SSM_CONV_STATE_LAYOUT
    if configured is None:
        # Nothing requested explicitly: keep the default layout.
        return "SD"

    logger.info_once(
        "VLLM_SSM_CONV_STATE_LAYOUT env detected. "
        "Setting SSM conv state layout to %s.",
        configured,
    )
    return configured
def is_conv_state_dim_first() -> bool:
    """Report whether the conv state uses the DS layout.

    Returns:
        True when ``get_conv_state_layout()`` is ``"DS"``, i.e. each block
        is stored as (dim, state_len); False otherwise.
    """
    active_layout = get_conv_state_layout()
    return active_layout == "DS"
class MambaStateDtypeCalculator: class MambaStateDtypeCalculator:
@classmethod @classmethod
...@@ -107,6 +139,13 @@ class MambaStateShapeCalculator: ...@@ -107,6 +139,13 @@ class MambaStateShapeCalculator:
state_shape = (num_heads // tp_size, head_dim, head_dim) state_shape = (num_heads // tp_size, head_dim, head_dim)
return (state_shape,) return (state_shape,)
@staticmethod
def _orient_conv_shape(dim: int, state_len: int) -> tuple[int, int]:
    """Order the two conv-state extents according to the active layout.

    Args:
        dim: Extent of the feature dimension.
        state_len: Extent of the state-length dimension.

    Returns:
        ``(dim, state_len)`` for the DS layout, ``(state_len, dim)`` for SD.
    """
    return (dim, state_len) if is_conv_state_dim_first() else (state_len, dim)
@classmethod @classmethod
def mamba1_state_shape( def mamba1_state_shape(
cls, cls,
...@@ -115,12 +154,11 @@ class MambaStateShapeCalculator: ...@@ -115,12 +154,11 @@ class MambaStateShapeCalculator:
state_size: int, state_size: int,
conv_kernel: int, conv_kernel: int,
) -> tuple[tuple[int, int], tuple[int, int]]: ) -> tuple[tuple[int, int], tuple[int, int]]:
conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1) conv_dim = divide(intermediate_size, tp_world_size)
conv_state_shape = cls._orient_conv_shape(conv_dim, conv_kernel - 1)
temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size) temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size)
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
return conv_state_shape, temporal_state_shape return conv_state_shape, temporal_state_shape
@classmethod @classmethod
...@@ -141,8 +179,9 @@ class MambaStateShapeCalculator: ...@@ -141,8 +179,9 @@ class MambaStateShapeCalculator:
# heads and n_groups are TP-ed # heads and n_groups are TP-ed
conv_dim = intermediate_size + 2 * n_groups * state_size conv_dim = intermediate_size + 2 * n_groups * state_size
# contiguous along 'dim' axis conv_state_shape = cls._orient_conv_shape(
conv_state_shape = (conv_kernel - 1 + num_spec, divide(conv_dim, tp_world_size)) divide(conv_dim, tp_world_size), conv_kernel - 1 + num_spec
)
# These are not TP-ed as they depend on A, dt_bias, D # These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small # - they are typically small
...@@ -158,7 +197,7 @@ class MambaStateShapeCalculator: ...@@ -158,7 +197,7 @@ class MambaStateShapeCalculator:
conv_kernel: int, conv_kernel: int,
) -> tuple[tuple[int, int]]: ) -> tuple[tuple[int, int]]:
conv_dim = divide(intermediate_size, tp_world_size) conv_dim = divide(intermediate_size, tp_world_size)
conv_state_shape = (conv_kernel - 1, conv_dim) conv_state_shape = cls._orient_conv_shape(conv_dim, conv_kernel - 1)
return (conv_state_shape,) return (conv_state_shape,)
@classmethod @classmethod
...@@ -185,13 +224,11 @@ class MambaStateShapeCalculator: ...@@ -185,13 +224,11 @@ class MambaStateShapeCalculator:
num_spec: int = 0, num_spec: int = 0,
): ):
conv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads conv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads
conv_state_shape = ( conv_state_shape = cls._orient_conv_shape(
divide(conv_dim, tp_world_size), divide(conv_dim, tp_world_size),
conv_kernel_size - 1 + num_spec, conv_kernel_size - 1 + num_spec,
) )
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
temporal_state_shape = ( temporal_state_shape = (
divide(num_v_heads, tp_world_size), divide(num_v_heads, tp_world_size),
head_v_dim, head_v_dim,
...@@ -218,12 +255,13 @@ class MambaStateShapeCalculator: ...@@ -218,12 +255,13 @@ class MambaStateShapeCalculator:
proj_size = num_heads * head_dim proj_size = num_heads * head_dim
proj_k_size = num_k_heads * head_k_dim proj_k_size = num_k_heads * head_k_dim
conv_state_shape = (divide(proj_size, tp_world_size), conv_kernel_size - 1) conv_state_shape = cls._orient_conv_shape(
conv_state_k_shape = (divide(proj_k_size, tp_world_size), conv_kernel_size - 1) divide(proj_size, tp_world_size), conv_kernel_size - 1
)
conv_state_k_shape = cls._orient_conv_shape(
divide(proj_k_size, tp_world_size), conv_kernel_size - 1
)
recurrent_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim) recurrent_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim)
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
conv_state_k_shape = conv_state_k_shape[1], conv_state_k_shape[0]
return ( return (
conv_state_shape, conv_state_shape,
conv_state_k_shape, conv_state_k_shape,
...@@ -267,9 +305,27 @@ def get_conv_copy_spec( ...@@ -267,9 +305,27 @@ def get_conv_copy_spec(
cur_block_idx: int, cur_block_idx: int,
num_accepted_tokens: int, num_accepted_tokens: int,
) -> MambaCopySpec: ) -> MambaCopySpec:
"""Return a MambaCopySpec for copying a convolutional state slice.""" """Return a MambaCopySpec for copying a convolutional state slice.
Works for both SD layout ``(num_blocks, state_len, dim)`` and
DS layout ``(num_blocks, dim, state_len)``.
"""
src_block_id = block_ids[cur_block_idx] src_block_id = block_ids[cur_block_idx]
src_state = state[src_block_id, num_accepted_tokens - 1 :] offset = num_accepted_tokens - 1
if is_conv_state_dim_first():
# DS layout: (num_blocks, dim, state_len) — state_len is last.
if offset > 0:
# Slicing along the last dim yields a non-contiguous view
# because features (dim) are strided by state_len.
raise NotImplementedError(
"DS conv state layout does not yet support speculative "
"decoding with mamba_cache_mode='align' "
"(num_accepted_tokens > 1)."
)
src_state = state[src_block_id]
else:
# SD layout: (num_blocks, state_len, dim) — dim contiguous.
src_state = state[src_block_id, offset:]
return MambaCopySpec( return MambaCopySpec(
start_addr=src_state.data_ptr(), num_elements=src_state.numel() start_addr=src_state.data_ptr(), num_elements=src_state.numel()
) )
......
...@@ -592,7 +592,6 @@ def causal_conv1d_fn( ...@@ -592,7 +592,6 @@ def causal_conv1d_fn(
stride_istate_seq = conv_states.stride(0) stride_istate_seq = conv_states.stride(0)
stride_istate_dim = conv_states.stride(1) stride_istate_dim = conv_states.stride(1)
stride_istate_token = conv_states.stride(2) stride_istate_token = conv_states.stride(2)
assert stride_istate_dim == 1
if out.dim() == 2: if out.dim() == 2:
stride_o_dim = out.stride(0) stride_o_dim = out.stride(0)
stride_o_token = out.stride(1) stride_o_token = out.stride(1)
...@@ -1149,9 +1148,6 @@ def causal_conv1d_update( ...@@ -1149,9 +1148,6 @@ def causal_conv1d_update(
if validate_data: if validate_data:
assert dim == weight.size(0) assert dim == weight.size(0)
assert conv_state.stride(-2) == 1, (
f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})"
)
assert state_len >= width - 1 assert state_len >= width - 1
# when above happens, we don't shift-left to keep any records in conv_state # when above happens, we don't shift-left to keep any records in conv_state
assert dim == conv_state.size(1) assert dim == conv_state.size(1)
......
...@@ -17,6 +17,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase ...@@ -17,6 +17,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateDtypeCalculator,
MambaStateShapeCalculator, MambaStateShapeCalculator,
is_conv_state_dim_first,
) )
from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_fn,
...@@ -117,8 +118,11 @@ class ShortConv(MambaBase, CustomOp): ...@@ -117,8 +118,11 @@ class ShortConv(MambaBase, CustomOp):
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix] attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, ShortConvAttentionMetadata) assert isinstance(attn_metadata, ShortConvAttentionMetadata)
self_kv_cache = self.kv_cache conv_state = (
conv_state = self_kv_cache[0].transpose(-1, -2) self.kv_cache[0]
if is_conv_state_dim_first()
else self.kv_cache[0].transpose(-1, -2)
)
state_indices_tensor_p = attn_metadata.state_indices_tensor_p state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d state_indices_tensor_d = attn_metadata.state_indices_tensor_d
has_initial_states_p = attn_metadata.has_initial_states_p has_initial_states_p = attn_metadata.has_initial_states_p
......
...@@ -68,6 +68,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( ...@@ -68,6 +68,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFuncCalculator, MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator, MambaStateDtypeCalculator,
MambaStateShapeCalculator, MambaStateShapeCalculator,
is_conv_state_dim_first,
) )
from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_fn,
...@@ -429,7 +430,13 @@ class OlmoHybridGatedDeltaNet(nn.Module, MambaBase): ...@@ -429,7 +430,13 @@ class OlmoHybridGatedDeltaNet(nn.Module, MambaBase):
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor
self_kv_cache = self.kv_cache self_kv_cache = self.kv_cache
conv_state = self_kv_cache[0].transpose(-1, -2) # conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a transpose.
conv_state = (
self_kv_cache[0]
if is_conv_state_dim_first()
else self_kv_cache[0].transpose(-1, -2)
)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens num_actual_tokens = attn_metadata.num_actual_tokens
num_accepted_tokens = attn_metadata.num_accepted_tokens num_accepted_tokens = attn_metadata.num_accepted_tokens
......
...@@ -32,6 +32,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( ...@@ -32,6 +32,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFuncCalculator, MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator, MambaStateDtypeCalculator,
MambaStateShapeCalculator, MambaStateShapeCalculator,
is_conv_state_dim_first,
) )
from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_fn,
...@@ -266,7 +267,13 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer): ...@@ -266,7 +267,13 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
assert isinstance(attn_metadata, Mamba2AttentionMetadata) assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache self_kv_cache = self.kv_cache
# conv_state = (..., dim, width-1) yet contiguous along 'dim' # conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2) # conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a transpose.
conv_state = (
self_kv_cache[0]
if is_conv_state_dim_first()
else self_kv_cache[0].transpose(-1, -2)
)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
state_indices_tensor_p = attn_metadata.state_indices_tensor_p state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d state_indices_tensor_d = attn_metadata.state_indices_tensor_d
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment