abstract.py 2.38 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import abstractmethod
4
5
6
7
from collections.abc import Iterable

import torch

8
from vllm.attention.backends.abstract import AttentionBackend
9
from vllm.attention.selector import get_mamba_attn_backend
10
from vllm.config import VllmConfig
11
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
12
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
13

14
15

class MambaBase(AttentionLayerBase):
    """
    Base class for Mamba-like layers which support the v1 engine.
    Inherit from this class if you implement a custom layer.
    """

    # Contains the KV cache (mamba state) for the layer
    # in the shape specified by `self.get_state_shape`.
    kv_cache: tuple[torch.Tensor, ...]

    @abstractmethod
    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
        """
        Defines the shape of the state.
        For mamba layers this is usually a (conv_state, ssm_state) tuple.
        In this case, returns (conv_state_shape, ssm_state_shape).
        """
        pass

    @property
    @abstractmethod
    def mamba_type(self) -> str:
        """Identifier of the Mamba variant implemented by this layer.

        It is passed to `get_mamba_attn_backend` to select the attention
        backend, and it is recorded as `mamba_type` in the layer's `MambaSpec`.
        """
        pass

    @abstractmethod
    def get_state_dtype(self) -> tuple[torch.dtype, ...]:
        """Return the dtype(s) of the layer's state tensors.

        These are passed as `dtypes` to `MambaSpec` alongside the shapes from
        `get_state_shape`. Presumably one dtype per state shape, in the same
        order; confirm against `MambaSpec`.
        """
        pass

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
        """Build the `MambaSpec` describing this layer's state cache.

        Args:
            vllm_config: Engine configuration. Its cache config supplies the
                mamba block size and padded page size. Its speculative config,
                if any, supplies the number of speculative blocks.

        Returns:
            A `MambaSpec` for this layer. `None` is allowed by the return
            annotation but is never produced here.

        Raises:
            NotImplementedError: If speculative decoding is configured for a
                model type other than the ones explicitly supported.
        """
        spec_config = vllm_config.speculative_config
        # Speculative decoding with Mamba layers is only enabled for an
        # explicit allow-list of model types.
        if (
            spec_config is not None
            and vllm_config.model_config.hf_config.model_type not in ("qwen3_next",)
        ):
            raise NotImplementedError(
                "Mamba with speculative decoding is not supported yet."
            )
        cache_config = vllm_config.cache_config
        return MambaSpec(
            # `get_state_shape` is only typed as an Iterable. Materialize both
            # sequences so the spec holds reusable, hashable tuples rather
            # than a possibly one-shot iterator or a list.
            shapes=tuple(self.get_state_shape()),
            dtypes=tuple(self.get_state_dtype()),
            block_size=cache_config.mamba_block_size,
            page_size_padded=cache_config.mamba_page_size_padded,
            mamba_type=self.mamba_type,
            num_speculative_blocks=(
                spec_config.num_speculative_tokens if spec_config is not None else 0
            ),
        )

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Get the attention backend class for this Mamba layer."""
        return get_mamba_attn_backend(self.mamba_type)