mamba_cache.py 7.23 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from dataclasses import dataclass
4
from typing import Dict, List, Tuple
5
6
7

import torch

8
from vllm.attention.backends.utils import PAD_SLOT_ID
9
from vllm.config import VllmConfig
10
11
12
13
14
15
16
17
18
19
20
21


@dataclass
class MambaCacheParams:
    """Container for the tensors that describe the Mamba cache of a run.

    Each field defaults to an empty ``torch.Tensor()``.
    """
    # Convolution state tensor; indexed by layer in `at_layer_idx`.
    conv_state: torch.Tensor = torch.Tensor()
    # SSM state tensor; indexed by layer in `at_layer_idx`.
    ssm_state: torch.Tensor = torch.Tensor()
    # Cache slot indices; passed through untouched by `at_layer_idx`.
    state_indices_tensor: torch.Tensor = torch.Tensor()

    def at_layer_idx(self, layer_idx):
        """Return a new ``MambaCacheParams`` restricted to one layer.

        ``conv_state`` and ``ssm_state`` are indexed with ``layer_idx``
        along their first dimension, while ``state_indices_tensor`` is
        reused as-is.
        """
        layer_conv_state = self.conv_state[layer_idx]
        layer_ssm_state = self.ssm_state[layer_idx]
        return MambaCacheParams(
            conv_state=layer_conv_state,
            ssm_state=layer_ssm_state,
            state_indices_tensor=self.state_indices_tensor,
        )
22
23
24
25


class MambaCacheManager:

26
27
28
29
30
31
32
33
    def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype,
                 num_mamba_layers: int, conv_state_shape: Tuple[int, int],
                 temporal_state_shape: Tuple[int, int]):
        """Allocate the Mamba cache tensors and the slot bookkeeping.

        Args:
            vllm_config: Supplies ``scheduler_config.max_num_seqs``,
                ``model_config.enforce_eager`` and ``pad_for_cudagraph``.
            dtype: Element type of both cache tensors.
            num_mamba_layers: Size of the first dimension of both tensors.
            conv_state_shape: Trailing dimensions of the conv state tensor.
            temporal_state_shape: Trailing dimensions of the temporal
                state tensor.
        """
        # Determine max batch size to set size of MambaCache
        scheduler_limit = vllm_config.scheduler_config.max_num_seqs
        if vllm_config.model_config.enforce_eager:
            max_batch_size = scheduler_limit
        else:
            max_batch_size = vllm_config.pad_for_cudagraph(scheduler_limit)

        # Both tensors share the (layer, slot) leading dimensions.
        leading_dims = (num_mamba_layers, max_batch_size)
        conv_state = torch.empty(size=leading_dims + conv_state_shape,
                                 dtype=dtype,
                                 device="cuda")
        temporal_state = torch.empty(size=leading_dims + temporal_state_shape,
                                     dtype=dtype,
                                     device="cuda")
        self.mamba_cache = (conv_state, temporal_state)

        # Maps between the request id and a dict that maps between the seq_id
        # and its index inside the self.mamba_cache
        self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
        # Every slot starts out free.
        self.free_cache_indices = list(range(max_batch_size))
50

Yu Chin Fabian Lim's avatar
Yu Chin Fabian Lim committed
51
    def current_run_tensors(self, **kwargs) -> MambaCacheParams:
        """
        Return the tensors for the current run's conv and ssm state.

        When ``seqlen_agnostic_capture_inputs`` is present in ``kwargs`` its
        ``(cache_tensors, state_indices_tensor)`` pair is used directly.
        Otherwise ``request_ids_to_seq_ids`` and ``finished_requests_ids``
        are read from ``kwargs``, finished requests are released, and the
        state indices are computed for the listed sequences.
        """
        if "seqlen_agnostic_capture_inputs" in kwargs:
            # CUDA graph capturing runs
            capture_inputs = kwargs["seqlen_agnostic_capture_inputs"]
            (cache_tensors, indices_tensor) = capture_inputs
        else:
            # We get here only on Prefill/Eager mode runs
            request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
            finished_requests_ids = kwargs["finished_requests_ids"]

            # Free the slots of finished requests before assigning new ones.
            self._release_finished_requests(finished_requests_ids)
            slot_ids = self._prepare_current_run_mamba_cache(
                request_ids_to_seq_ids, finished_requests_ids)

            indices_tensor = torch.as_tensor(slot_ids,
                                             dtype=torch.int32,
                                             device="cuda")
            cache_tensors = self.mamba_cache

        conv_state = cache_tensors[0]
        ssm_state = cache_tensors[1]
        return MambaCacheParams(conv_state, ssm_state, indices_tensor)
76
77
78

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """
        Copy the relevant state_indices into the CUDA graph input buffer.

        ``kwargs`` must contain ``request_ids_to_seq_ids`` and
        ``finished_requests_ids``; ``input_buffers`` must contain
        ``seqlen_agnostic_capture_inputs``, whose second element is the
        state-indices buffer that gets overwritten in place.
        """
        required_keys = ("request_ids_to_seq_ids", "finished_requests_ids")
        assert all(key in kwargs for key in required_keys)
        request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
        finished_requests_ids = kwargs["finished_requests_ids"]

        assert "seqlen_agnostic_capture_inputs" in input_buffers
        capture_inputs = input_buffers["seqlen_agnostic_capture_inputs"]
        _, input_state_indices_buffer = capture_inputs

        # Free the slots of finished requests before assigning new ones.
        self._release_finished_requests(finished_requests_ids)
        slot_ids = self._prepare_current_run_mamba_cache(
            request_ids_to_seq_ids, finished_requests_ids)

        # Fill the remainder of the buffer with the pad slot id.
        pad_count = input_state_indices_buffer.shape[0] - len(slot_ids)
        padded_slot_ids = slot_ids + [PAD_SLOT_ID] * pad_count

        source_tensor = torch.as_tensor(padded_slot_ids,
                                        dtype=torch.int32,
                                        device="cuda")
        input_state_indices_buffer.copy_(source_tensor)
99
100
101
102
103
104
105

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """
        Provide the CUDA graph capture runs with a buffer in adjusted size.
        The buffer is used to maintain the Mamba Cache during the CUDA graph
        replay runs.

        Returns a ``(self.mamba_cache, state_indices_tensor)`` tuple where
        the indices tensor holds ``batch_size`` int32 entries, all set to
        ``PAD_SLOT_ID``.
        """
        pad_slots = [PAD_SLOT_ID] * batch_size
        indices_buffer = torch.as_tensor(pad_slots,
                                         dtype=torch.int32,
                                         device="cuda")
        return (self.mamba_cache, indices_buffer)
110
111
112
113
114
115
116

    def _copy_mamba_cache(self, from_index: int, to_index: int):
        assert len(self.mamba_cache) > 0
        for cache_t in self.mamba_cache:
            cache_t[:, to_index].copy_(cache_t[:, from_index],
                                       non_blocking=True)

117
118
    def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
                                      finished_requests_ids) -> int:
119
120
121
122
        """
        Assign (req_id,seq_id) pair to a `destination_index` index, if
        already occupied, move the occupying index to a free index.
        """
123
124
125
126
127
        if cur_rid in finished_requests_ids:
            # set as pad, do not allocate destination index
            return PAD_SLOT_ID
        elif cur_rid not in self.mamba_cache_indices_mapping:
            destination_index = self.free_cache_indices.pop()
128
129
130
            self.mamba_cache_indices_mapping[cur_rid] = {
                seq_id: destination_index
            }
131
            return destination_index
132
133
134
        elif seq_id not in (seq_ids2indices :=
                            self.mamba_cache_indices_mapping[cur_rid]):
            # parallel sampling , where n > 1, assume prefill have
135
            # already happened, so we copy the
136
            # existing cache into the siblings seq_ids caches
137
            index_exists = next(iter(seq_ids2indices.values()))
138
            # case of decoding n>1, copy prefill cache to decoding indices
139
            destination_index = self.free_cache_indices.pop()
140
141
142
143
            self._copy_mamba_cache(from_index=index_exists,
                                   to_index=destination_index)
            self.mamba_cache_indices_mapping[cur_rid][
                seq_id] = destination_index
144
            return destination_index
145
146
        else:
            # already exists
147
            return self.mamba_cache_indices_mapping[cur_rid][seq_id]
148
149
150

    def _prepare_current_run_mamba_cache(
            self, request_ids_to_seq_ids: Dict[str, list[int]],
151
152
153
154
            finished_requests_ids: List[str]) -> List[int]:
        return [
            self._assign_seq_id_to_cache_index(req_id, seq_id,
                                               finished_requests_ids)
155
156
157
158
159
160
161
162
            for req_id, seq_ids in request_ids_to_seq_ids.items()
            for seq_id in seq_ids
        ]

    def _release_finished_requests(self,
                                   finished_seq_groups_req_ids: List[str]):
        for req_id in finished_seq_groups_req_ids:
            if req_id in self.mamba_cache_indices_mapping:
163
164
165
                for seq_id in self.mamba_cache_indices_mapping[req_id]:
                    self.free_cache_indices.append(
                        self.mamba_cache_indices_mapping[req_id][seq_id])
166
                self.mamba_cache_indices_mapping.pop(req_id)