# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Iterable

import torch


class MambaBase(ABC):
    """
    Base class for Mamba-like layers which support the v1 engine.
    Inherit from this class if you implement a custom layer.

    Subclasses must implement `get_state_shape`; the class cannot be
    instantiated until that abstract method is overridden.
    """

    # Contains the KV cache (mamba state) for the layer
    # in the shape specified by `self.get_state_shape`.
    # The outer list is for v0 PP virtual engine. Though this code path
    # only runs for v1, we have to do this to unify with the interface
    # of Attention + v0 PP.
    # NOTE: this is an annotation only; no value is assigned here, so the
    # attribute does not exist until something sets it.
    kv_cache: list[Iterable[torch.Tensor]]

    @abstractmethod
    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
        """
        Defines the shape of the state.

        For mamba layers this is usually a (conv_state, ssm_state) tuple.
        In this case, returns (conv_state_shape, ssm_state_shape).

        Returns:
            An iterable of shapes, each shape being a tuple of ints.
            This base implementation has no body beyond its docstring
            and therefore returns None if invoked directly.
        """