# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass, field

import torch

from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache


@dataclass
class MambaCacheParams:
    """Mamba state tensors handed to the model for one forward pass.

    ``conv_state`` and ``ssm_state`` are indexed by layer on their leading
    axis (see ``at_layer_idx``); ``state_indices_tensor`` is passed through
    unchanged to every per-layer view.
    """

    # default_factory builds a fresh empty tensor per instance; a plain
    # ``= torch.Tensor()`` default would be one object created at import time
    # and shared by every instance constructed without arguments.
    conv_state: torch.Tensor = field(default_factory=torch.Tensor)
    ssm_state: torch.Tensor = field(default_factory=torch.Tensor)
    state_indices_tensor: torch.Tensor = field(default_factory=torch.Tensor)

    def at_layer_idx(self, layer_idx: int) -> "MambaCacheParams":
        """Return the params restricted to a single layer.

        Args:
            layer_idx: Index into the leading axis of ``conv_state`` and
                ``ssm_state``.

        Returns:
            A new ``MambaCacheParams`` holding ``conv_state[layer_idx]`` and
            ``ssm_state[layer_idx]`` (integer indexing, so views rather than
            copies) and the very same ``state_indices_tensor`` object, which
            is shared across all layers.
        """
        return MambaCacheParams(self.conv_state[layer_idx],
                                self.ssm_state[layer_idx],
                                self.state_indices_tensor)


class MambaCacheManager(ConstantSizeCache):
    """Owns the preallocated Mamba conv and temporal (SSM) state buffers.

    Both buffers carry the leading axes ``(num_mamba_layers, max_batch_size)``
    and live on the CUDA device. Slot bookkeeping is handled by the parent
    ``ConstantSizeCache`` (its behavior is not visible in this file).
    """

    def __init__(self, vllm_config: VllmConfig, num_mamba_layers: int,
                 conv_state_shape: tuple[int, int],
                 temporal_state_shape: tuple[int, int],
                 conv_state_dtype: torch.dtype,
                 temporal_state_dtype: torch.dtype):
        """Allocate the conv and temporal state buffers.

        Args:
            vllm_config: Supplies ``scheduler_config.max_num_seqs`` and
                ``model_config.enforce_eager``, which fix the slot count.
            num_mamba_layers: Length of the leading (layer) axis.
            conv_state_shape: Per-slot conv state as ``(dim, state_len)``;
                ``dim`` must be larger than ``state_len``.
            temporal_state_shape: Per-slot temporal state shape, appended
                after the two leading axes as-is.
            conv_state_dtype: dtype of the conv state buffer.
            temporal_state_dtype: dtype of the temporal state buffer.
        """
        self.conv_state_dtype = conv_state_dtype
        self.temporal_state_dtype = temporal_state_dtype

        # Slot count: max_num_seqs, passed through pad_for_cudagraph unless
        # eager mode is enforced (presumably rounding up to a captured
        # CUDA-graph batch size).
        requested_slots = vllm_config.scheduler_config.max_num_seqs
        if vllm_config.model_config.enforce_eager:
            max_batch_size = requested_slots
        else:
            max_batch_size = vllm_config.pad_for_cudagraph(requested_slots)

        super().__init__(max_batch_size)

        dim, state_len = conv_state_shape[0], conv_state_shape[1]
        # Sanity check that the shape is (dim, state_len) and not swapped.
        assert dim > state_len
        leading_dims = (num_mamba_layers, max_batch_size)

        # Allocate with `dim` as the last (contiguous) axis, then swap the
        # final two axes: the logical shape is (..., dim, state_len) while
        # `dim` keeps unit stride in memory.
        conv_state = torch.empty(size=leading_dims + (state_len, dim),
                                 dtype=self.conv_state_dtype,
                                 device="cuda").transpose(-1, -2)
        temporal_state = torch.empty(size=leading_dims + temporal_state_shape,
                                     dtype=self.temporal_state_dtype,
                                     device="cuda")

        self._mamba_cache = (conv_state, temporal_state)

    @property
    def cache(self):
        """The ``(conv_state, temporal_state)`` buffer pair."""
        return self._mamba_cache

    def _copy_cache(self, from_index: int, to_index: int):
        """Copy slot ``from_index`` into slot ``to_index`` in every buffer.

        Slots are axis 1 of each buffer. Copies are issued with
        ``non_blocking=True``.
        """
        for buffer in self.cache:
            destination = buffer[:, to_index]
            destination.copy_(buffer[:, from_index], non_blocking=True)

    def current_run_tensors(self, **kwargs) -> MambaCacheParams:
        """Package the current run's conv/SSM state and slot indices.

        Forwards ``kwargs`` to the parent's ``current_run_tensors`` and wraps
        its ``(cache_tensors, state_indices_tensor)`` result.
        """
        buffers, slot_indices = super().current_run_tensors(**kwargs)
        return MambaCacheParams(conv_state=buffers[0],
                                ssm_state=buffers[1],
                                state_indices_tensor=slot_indices)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """Return the inputs used by CUDA graph capture runs.

        Yields the full state buffers, which keep the Mamba cache across
        graph replays, plus an int32 CUDA tensor of length ``batch_size``
        filled with ``PAD_SLOT_ID``.
        """
        buffers = self._mamba_cache
        padded_indices = torch.as_tensor([PAD_SLOT_ID] * batch_size,
                                         dtype=torch.int32,
                                         device="cuda")
        return buffers, padded_indices