abstract.py 1.23 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import abstractmethod
4
from collections.abc import Iterable
5
from typing import TYPE_CHECKING
6
7
8

import torch

9
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
10

11
12
13
14
15
if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend


class MambaBase(AttentionLayerBase):
    """
    Base class for Mamba-like layers which support the v1 engine.
    Inherit from this class if you implement a custom layer.

    Subclasses are expected to provide `get_state_shape`, `mamba_type`
    and `get_attn_backend`, all of which are declared abstract here.
    """

    # Contains the KV cache (mamba state) for the layer
    # in the shape specified by `self.get_state_shape`.
    # Declared as an annotation only: no value is assigned in this class.
    kv_cache: tuple[torch.Tensor, ...]

    @abstractmethod
    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
        """
        Defines the shape of the state.
        For mamba layers this is usually a (conv_state, ssm_state) tuple.
        In this case, returns (conv_state_shape, ssm_state_shape).

        Returns:
            An iterable of shapes, each shape being a tuple of ints.
        """
        pass

    @property
    @abstractmethod
    def mamba_type(self) -> str:
        """
        String describing the type of this Mamba-like layer.

        NOTE(review): the set of accepted values is not defined in this
        class; confirm against the code that reads this property.
        """
        pass

    @abstractmethod
    def get_attn_backend(self) -> type["AttentionBackend"]:
        """
        Get the attention backend class for this Mamba layer.

        Returns:
            The backend class itself (not an instance).
        """
        pass