Unverified Commit 66e86f1d authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[Kernel] Mamba support different layout for Conv state (#37416)

parent bb39382b
......@@ -60,6 +60,14 @@ MAX_NUM_SEQS = 4
ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "auto"
def _set_conv_state_layout(monkeypatch, layout: str) -> None:
    """Select the conv state layout for the current test.

    Patches ``VLLM_SSM_CONV_STATE_LAYOUT`` to ``layout`` through
    ``monkeypatch`` and then drops the cached result of
    ``mamba_utils.get_conv_state_layout`` so the next lookup re-reads the
    environment instead of returning a previously cached layout.
    """
    from vllm.model_executor.layers.mamba import mamba_utils

    layout_getter = mamba_utils.get_conv_state_layout
    monkeypatch.setenv("VLLM_SSM_CONV_STATE_LAYOUT", layout)
    layout_getter.cache_clear()
@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
......@@ -102,12 +110,15 @@ def test_models(
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("conv_state_layout", ["SD", "DS"])
def test_batching(
vllm_runner,
example_prompts,
monkeypatch,
model: str,
max_tokens: int,
num_logprobs: int,
conv_state_layout: str,
) -> None:
try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
......@@ -116,6 +127,8 @@ def test_batching(
except ValueError:
pass
_set_conv_state_layout(monkeypatch, conv_state_layout)
for_loop_outputs = []
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
for prompt in example_prompts:
......@@ -138,11 +151,14 @@ def test_batching(
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [10])
@pytest.mark.parametrize("conv_state_layout", ["SD", "DS"])
def test_chunked_prefill_with_parallel_sampling(
vllm_runner,
example_prompts,
monkeypatch,
model: str,
max_tokens: int,
conv_state_layout: str,
) -> None:
"""
Tests chunked prefill in conjunction with n > 1.
......@@ -154,6 +170,8 @@ def test_chunked_prefill_with_parallel_sampling(
decoding steps inside a chunked prefill forward pass
(where we have both prefill and decode together)
"""
_set_conv_state_layout(monkeypatch, conv_state_layout)
sampling_params = SamplingParams(n=3, temperature=1, seed=0, max_tokens=max_tokens)
with vllm_runner(
model,
......@@ -168,17 +186,22 @@ def test_chunked_prefill_with_parallel_sampling(
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [20])
@pytest.mark.parametrize("conv_state_layout", ["SD", "DS"])
def test_mamba_cache_cg_padding(
vllm_runner,
example_prompts,
monkeypatch,
model: str,
max_tokens: int,
conv_state_layout: str,
) -> None:
"""
This test is for verifying that mamba cache is padded to CG captured
batch size. If it's not, a torch RuntimeError will be raised because
tensor dimensions aren't compatible.
"""
_set_conv_state_layout(monkeypatch, conv_state_layout)
vllm_config = EngineArgs(model=model, trust_remote_code=True).create_engine_config()
cudagraph_dispatcher = CudagraphDispatcher(vllm_config)
cudagraph_dispatcher.initialize_cudagraph_keys(
......
......@@ -191,6 +191,7 @@ if TYPE_CHECKING:
VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300
VLLM_KV_CACHE_LAYOUT: Literal["NHD", "HND"] | None = None
VLLM_SSM_CONV_STATE_LAYOUT: Literal["SD", "DS"] | None = None
VLLM_COMPUTE_NANS_IN_LOGITS: bool = False
VLLM_USE_NVFP4_CT_EMULATIONS: bool = False
VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: Literal[
......@@ -1409,6 +1410,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_KV_CACHE_LAYOUT": env_with_choices(
"VLLM_KV_CACHE_LAYOUT", None, ["NHD", "HND"]
),
# SSM conv state layout used for Mamba models.
# - SD: (state_len, dim) — dim contiguous (default)
# - DS: (dim, state_len) — TP-sharded dim on dim1,
# consistent with SSM temporal state and HND KV cache layout.
"VLLM_SSM_CONV_STATE_LAYOUT": env_with_choices(
"VLLM_SSM_CONV_STATE_LAYOUT", None, ["SD", "DS"]
),
# Enable checking whether the generated logits contain NaNs,
# indicating corrupted output. Useful for debugging low level bugs
# or bad hardware but it may add compute overhead.
......
......@@ -31,7 +31,11 @@ from .linear import (
RowParallelLinear,
)
from .mamba.abstract import MambaBase
from .mamba.mamba_utils import MambaStateDtypeCalculator, MambaStateShapeCalculator
from .mamba.mamba_utils import (
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
is_conv_state_dim_first,
)
from .mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from .quantization.base_config import QuantizationConfig
......@@ -315,10 +319,12 @@ class KimiDeltaAttention(nn.Module, MambaBase):
beta = beta[:num_actual_tokens]
(conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches
# deal with strides
conv_state_q = conv_state_q.transpose(-1, -2)
conv_state_k = conv_state_k.transpose(-1, -2)
conv_state_v = conv_state_v.transpose(-1, -2)
# conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a transpose.
if not is_conv_state_dim_first():
conv_state_q = conv_state_q.transpose(-1, -2)
conv_state_k = conv_state_k.transpose(-1, -2)
conv_state_v = conv_state_v.transpose(-1, -2)
q_conv_weights = self.q_conv1d.weight.view(
self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
......
......@@ -41,6 +41,7 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weigh
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
is_conv_state_dim_first,
)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn,
......@@ -699,7 +700,13 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
self_kv_cache = self.kv_cache
conv_state = self_kv_cache[0].transpose(-1, -2)
# conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a transpose.
conv_state = (
self_kv_cache[0]
if is_conv_state_dim_first()
else self_kv_cache[0].transpose(-1, -2)
)
ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens
num_accepted_tokens = attn_metadata.num_accepted_tokens
......@@ -914,7 +921,13 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
"""
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
self_kv_cache = self.kv_cache
conv_state = self_kv_cache[0].transpose(-1, -2)
# conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a transpose.
conv_state = (
self_kv_cache[0]
if is_conv_state_dim_first()
else self_kv_cache[0].transpose(-1, -2)
)
ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens
......
......@@ -24,6 +24,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
is_conv_state_dim_first,
)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn,
......@@ -267,9 +268,12 @@ class MambaMixer(MambaBase, PluggableLayer):
query_start_loc_p = attn_metadata.query_start_loc_p
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d
self_kv_cache = self.kv_cache
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
conv_state = (
self.kv_cache[0]
if is_conv_state_dim_first()
else self.kv_cache[0].transpose(-1, -2)
)
ssm_state = self.kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p
cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
last_chunk_indices_p = attn_metadata.last_chunk_indices_p
......
......@@ -24,6 +24,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
is_conv_state_dim_first,
)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn,
......@@ -575,10 +576,15 @@ class MambaMixer2(MambaBase, PluggableLayer):
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
# conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a
# transpose (which keeps dim contiguous via stride tricks).
conv_state = (
self.kv_cache[0]
if is_conv_state_dim_first()
else self.kv_cache[0].transpose(-1, -2)
)
ssm_state = self.kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p
prep_initial_states = attn_metadata.prep_initial_states
chunk_size = attn_metadata.chunk_size
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeAlias
from typing import Literal, TypeAlias
import torch
import vllm.envs as envs
from vllm.config.cache import MambaDType
from vllm.config.model import ModelDType
from vllm.distributed import divide
from vllm.logger import init_logger
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE,
get_kv_cache_torch_dtype,
)
# Module-level logger, named after this module.
logger = init_logger(__name__)
# The two accepted conv state layout names:
# "SD" = (state_len, dim), "DS" = (dim, state_len).
ConvStateLayoutType = Literal["SD", "DS"]
@functools.lru_cache
def get_conv_state_layout() -> ConvStateLayoutType:
    """Resolve the layout used to store the SSM conv state.

    Layouts:
      SD = (state_len, dim) — dim is the innermost contiguous dimension.
      DS = (dim, state_len) — TP-sharded dim is on dim-1 (like HND for KV
           cache), consistent with SSM temporal state layout.

    The value comes from ``envs.VLLM_SSM_CONV_STATE_LAYOUT``; when that is
    ``None`` the layout is "SD". The result is memoized by ``lru_cache``,
    so a later change of the env var is only picked up after
    ``get_conv_state_layout.cache_clear()``.
    """
    env_layout: ConvStateLayoutType | None = envs.VLLM_SSM_CONV_STATE_LAYOUT
    # Guard clause: nothing configured, use the default layout silently.
    if env_layout is None:
        return "SD"

    logger.info_once(
        "VLLM_SSM_CONV_STATE_LAYOUT env detected. "
        "Setting SSM conv state layout to %s.",
        env_layout,
    )
    return env_layout
def is_conv_state_dim_first() -> bool:
    """Report whether each block keeps its conv state as (dim, state_len).

    Returns:
        True when the resolved conv state layout is "DS", False otherwise.
    """
    active_layout = get_conv_state_layout()
    return active_layout == "DS"
class MambaStateDtypeCalculator:
@classmethod
......@@ -107,6 +139,13 @@ class MambaStateShapeCalculator:
state_shape = (num_heads // tp_size, head_dim, head_dim)
return (state_shape,)
@staticmethod
def _orient_conv_shape(dim: int, state_len: int) -> tuple[int, int]:
    """Order the two conv state extents according to the active layout.

    Args:
        dim: Size of the conv feature dimension.
        state_len: Number of positions kept in the conv state.

    Returns:
        ``(dim, state_len)`` when the layout is DS (dim first),
        ``(state_len, dim)`` when it is SD.
    """
    return (dim, state_len) if is_conv_state_dim_first() else (state_len, dim)
@classmethod
def mamba1_state_shape(
cls,
......@@ -115,12 +154,11 @@ class MambaStateShapeCalculator:
state_size: int,
conv_kernel: int,
) -> tuple[tuple[int, int], tuple[int, int]]:
conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1)
conv_dim = divide(intermediate_size, tp_world_size)
conv_state_shape = cls._orient_conv_shape(conv_dim, conv_kernel - 1)
temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size)
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
return conv_state_shape, temporal_state_shape
@classmethod
......@@ -141,8 +179,9 @@ class MambaStateShapeCalculator:
# heads and n_groups are TP-ed
conv_dim = intermediate_size + 2 * n_groups * state_size
# contiguous along 'dim' axis
conv_state_shape = (conv_kernel - 1 + num_spec, divide(conv_dim, tp_world_size))
conv_state_shape = cls._orient_conv_shape(
divide(conv_dim, tp_world_size), conv_kernel - 1 + num_spec
)
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
......@@ -158,7 +197,7 @@ class MambaStateShapeCalculator:
conv_kernel: int,
) -> tuple[tuple[int, int]]:
conv_dim = divide(intermediate_size, tp_world_size)
conv_state_shape = (conv_kernel - 1, conv_dim)
conv_state_shape = cls._orient_conv_shape(conv_dim, conv_kernel - 1)
return (conv_state_shape,)
@classmethod
......@@ -185,13 +224,11 @@ class MambaStateShapeCalculator:
num_spec: int = 0,
):
conv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads
conv_state_shape = (
conv_state_shape = cls._orient_conv_shape(
divide(conv_dim, tp_world_size),
conv_kernel_size - 1 + num_spec,
)
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
temporal_state_shape = (
divide(num_v_heads, tp_world_size),
head_v_dim,
......@@ -218,12 +255,13 @@ class MambaStateShapeCalculator:
proj_size = num_heads * head_dim
proj_k_size = num_k_heads * head_k_dim
conv_state_shape = (divide(proj_size, tp_world_size), conv_kernel_size - 1)
conv_state_k_shape = (divide(proj_k_size, tp_world_size), conv_kernel_size - 1)
conv_state_shape = cls._orient_conv_shape(
divide(proj_size, tp_world_size), conv_kernel_size - 1
)
conv_state_k_shape = cls._orient_conv_shape(
divide(proj_k_size, tp_world_size), conv_kernel_size - 1
)
recurrent_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim)
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
conv_state_k_shape = conv_state_k_shape[1], conv_state_k_shape[0]
return (
conv_state_shape,
conv_state_k_shape,
......@@ -267,9 +305,27 @@ def get_conv_copy_spec(
cur_block_idx: int,
num_accepted_tokens: int,
) -> MambaCopySpec:
"""Return a MambaCopySpec for copying a convolutional state slice."""
"""Return a MambaCopySpec for copying a convolutional state slice.
Works for both SD layout ``(num_blocks, state_len, dim)`` and
DS layout ``(num_blocks, dim, state_len)``.
"""
src_block_id = block_ids[cur_block_idx]
src_state = state[src_block_id, num_accepted_tokens - 1 :]
offset = num_accepted_tokens - 1
if is_conv_state_dim_first():
# DS layout: (num_blocks, dim, state_len) — state_len is last.
if offset > 0:
# Slicing along the last dim yields a non-contiguous view
# because features (dim) are strided by state_len.
raise NotImplementedError(
"DS conv state layout does not yet support speculative "
"decoding with mamba_cache_mode='align' "
"(num_accepted_tokens > 1)."
)
src_state = state[src_block_id]
else:
# SD layout: (num_blocks, state_len, dim) — dim contiguous.
src_state = state[src_block_id, offset:]
return MambaCopySpec(
start_addr=src_state.data_ptr(), num_elements=src_state.numel()
)
......
......@@ -592,7 +592,6 @@ def causal_conv1d_fn(
stride_istate_seq = conv_states.stride(0)
stride_istate_dim = conv_states.stride(1)
stride_istate_token = conv_states.stride(2)
assert stride_istate_dim == 1
if out.dim() == 2:
stride_o_dim = out.stride(0)
stride_o_token = out.stride(1)
......@@ -1149,9 +1148,6 @@ def causal_conv1d_update(
if validate_data:
assert dim == weight.size(0)
assert conv_state.stride(-2) == 1, (
f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})"
)
assert state_len >= width - 1
# when above happens, we don't shift-left to keep any records in conv_state
assert dim == conv_state.size(1)
......
......@@ -17,6 +17,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
is_conv_state_dim_first,
)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn,
......@@ -117,8 +118,11 @@ class ShortConv(MambaBase, CustomOp):
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, ShortConvAttentionMetadata)
self_kv_cache = self.kv_cache
conv_state = self_kv_cache[0].transpose(-1, -2)
conv_state = (
self.kv_cache[0]
if is_conv_state_dim_first()
else self.kv_cache[0].transpose(-1, -2)
)
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d
has_initial_states_p = attn_metadata.has_initial_states_p
......
......@@ -68,6 +68,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
is_conv_state_dim_first,
)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn,
......@@ -429,7 +430,13 @@ class OlmoHybridGatedDeltaNet(nn.Module, MambaBase):
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor
self_kv_cache = self.kv_cache
conv_state = self_kv_cache[0].transpose(-1, -2)
# conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a transpose.
conv_state = (
self_kv_cache[0]
if is_conv_state_dim_first()
else self_kv_cache[0].transpose(-1, -2)
)
ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens
num_accepted_tokens = attn_metadata.num_accepted_tokens
......
......@@ -32,6 +32,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
is_conv_state_dim_first,
)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn,
......@@ -266,7 +267,13 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache
# With the SD layout, the transposed view is (..., dim, width-1) yet stays contiguous along 'dim'.
conv_state = self_kv_cache[0].transpose(-1, -2)
# conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a transpose.
conv_state = (
self_kv_cache[0]
if is_conv_state_dim_first()
else self_kv_cache[0].transpose(-1, -2)
)
ssm_state = self_kv_cache[1]
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment