# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import abstractmethod
4
from collections.abc import Iterable
5
from typing import TYPE_CHECKING
6
7
8

import torch

9
from vllm.config import VllmConfig
10
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
11
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
12

13
14
15
16
17
if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend


class MambaBase(AttentionLayerBase):
    """
    Base class for Mamba-like layers which support the v1 engine.
    Inherit from this class if you implement a custom layer.

    Subclasses must implement `get_state_shape`, `get_state_dtype`,
    `get_attn_backend` and the `mamba_type` property. The shared
    `get_kv_cache_spec` combines the state shapes, dtypes and type into a
    `MambaSpec`.
    """

    # Contains the KV cache (mamba state) for the layer
    # in the shape specified by `self.get_state_shape`.
    kv_cache: tuple[torch.Tensor, ...]

    @abstractmethod
    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
        """
        Defines the shape of the state.
        For mamba layers this is usually a (conv_state, ssm_state) tuple.
        In this case, returns (conv_state_shape, ssm_state_shape).
        """
        pass

    @property
    @abstractmethod
    def mamba_type(self) -> str:
        """Identifier of this Mamba variant; forwarded as `MambaSpec.mamba_type`."""
        pass

    @abstractmethod
    def get_attn_backend(self) -> type["AttentionBackend"]:
        """Get the attention backend class for this Mamba layer."""
        pass

    @abstractmethod
    def get_state_dtype(self) -> tuple[torch.dtype, ...]:
        """
        Dtypes of the state tensors; forwarded as `MambaSpec.dtypes`.

        NOTE(review): presumably one dtype per shape returned by
        `get_state_shape`, in the same order. Confirm against `MambaSpec`.
        """
        pass

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
        """
        Build the KV cache spec describing this layer's mamba state.

        Args:
            vllm_config: Engine configuration. This method reads
                `speculative_config`, `model_config.hf_config.model_type`,
                `cache_config.mamba_block_size` and
                `cache_config.mamba_page_size_padded`.

        Returns:
            A `MambaSpec` built from this layer's state shapes, dtypes and
            mamba type. When speculative decoding is configured,
            `num_speculative_blocks` is set to `num_speculative_tokens`;
            otherwise it is 0.

        Raises:
            NotImplementedError: If speculative decoding is configured for a
                model type other than "qwen3_next".
        """
        # Read once so the guard and num_speculative_blocks use the same
        # `is not None` test. Previously the guard used `is not None` and the
        # ternary used truthiness.
        spec_config = vllm_config.speculative_config

        # Speculative decoding is only wired up for these model types so far.
        if (
            spec_config is not None
            and vllm_config.model_config.hf_config.model_type not in ("qwen3_next",)
        ):
            raise NotImplementedError(
                "Mamba with speculative decoding is not supported yet."
            )

        cache_config = vllm_config.cache_config
        num_speculative_blocks = (
            spec_config.num_speculative_tokens if spec_config is not None else 0
        )
        return MambaSpec(
            shapes=self.get_state_shape(),
            dtypes=self.get_state_dtype(),
            block_size=cache_config.mamba_block_size,
            page_size_padded=cache_config.mamba_page_size_padded,
            mamba_type=self.mamba_type,
            num_speculative_blocks=num_speculative_blocks,
        )