# SPDX-License-Identifier: Apache-2.0

3
4
from dataclasses import dataclass, field
from typing import Dict, List
5
6
7

import torch

8
9
10
11
12
13
14
15
16
17
18
19
20
from vllm.attention.backends.utils import PAD_SLOT_ID


@dataclass
class MambaCacheParams:
    """Bundle of the conv state, ssm state and state-index tensors of a run.

    Each field defaults to a fresh empty tensor created per instance, so
    default-constructed objects never share tensor objects with each other.
    """
    # Convolution state; `at_layer_idx` indexes its first dimension by layer.
    conv_state: torch.Tensor = field(default_factory=torch.Tensor)
    # SSM (temporal) state; `at_layer_idx` indexes its first dimension too.
    ssm_state: torch.Tensor = field(default_factory=torch.Tensor)
    # Cache indices for the current run; passed through unchanged per layer.
    state_indices_tensor: torch.Tensor = field(default_factory=torch.Tensor)

    def at_layer_idx(self, layer_idx: int) -> "MambaCacheParams":
        """Return the parameters narrowed down to a single layer.

        Args:
            layer_idx: index into the first dimension of ``conv_state`` and
                ``ssm_state``.

        Returns:
            A new ``MambaCacheParams`` holding ``conv_state[layer_idx]`` and
            ``ssm_state[layer_idx]`` together with the very same
            ``state_indices_tensor`` object as this instance.
        """
        return MambaCacheParams(
            conv_state=self.conv_state[layer_idx],
            ssm_state=self.ssm_state[layer_idx],
            state_indices_tensor=self.state_indices_tensor,
        )
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41


class MambaCacheManager:
    """Allocate the Mamba state buffers and map sequences onto cache slots.

    ``self.mamba_cache`` holds two buffers (conv state, temporal/ssm state),
    each shaped ``(num_mamba_layers, max_batch_size) + state_shape``. Every
    scheduled ``(request_id, seq_id)`` pair is given one index along the
    second dimension; the index goes back to the free list once the owning
    request id is reported as finished.
    """

    def __init__(self, dtype, num_mamba_layers, max_batch_size,
                 conv_state_shape, temporal_state_shape, device="cuda"):
        """Allocate both state buffers and mark every slot as free.

        Args:
            dtype: dtype of both state buffers.
            num_mamba_layers: size of the first (layer) dimension.
            max_batch_size: number of cache slots (second dimension).
            conv_state_shape: tuple appended to the two leading dimensions
                to form the conv state buffer shape.
            temporal_state_shape: tuple appended to the two leading
                dimensions to form the temporal state buffer shape.
            device: device the buffers are allocated on. Defaults to
                ``"cuda"``, which was previously hard-coded.
        """
        leading_dims = (num_mamba_layers, max_batch_size)
        # torch.empty leaves the buffers uninitialised.
        conv_state = torch.empty(size=leading_dims + conv_state_shape,
                                 dtype=dtype,
                                 device=device)
        temporal_state = torch.empty(size=leading_dims + temporal_state_shape,
                                     dtype=dtype,
                                     device=device)

        self.mamba_cache = (conv_state, temporal_state)

        # Maps between the request id and a dict that maps between the seq_id
        # and its index inside the self.mamba_cache
        self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
        self.free_cache_indices = list(range(max_batch_size))

    def current_run_tensors(self, **kwargs) -> "MambaCacheParams":
        """
        Return the tensors for the current run's conv and ssm state.

        Without a ``seqlen_agnostic_capture_inputs`` keyword, the keywords
        ``request_ids_to_seq_ids`` and ``finished_requests_ids`` are
        required: finished requests are released, every scheduled sequence
        is assigned a cache index and the indices are returned as an int32
        tensor on the same device as the cache. Otherwise the
        ``(mamba_cache, state_indices_tensor)`` pair found under
        ``seqlen_agnostic_capture_inputs`` is used as is.
        """
        if "seqlen_agnostic_capture_inputs" not in kwargs:
            # We get here only on Prefill/Eager mode runs
            request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
            finished_requests_ids = kwargs["finished_requests_ids"]

            # Free the slots first so they can be reused within this run.
            self._release_finished_requests(finished_requests_ids)
            state_indices = self._prepare_current_run_mamba_cache(
                request_ids_to_seq_ids, finished_requests_ids)

            state_indices_tensor = torch.as_tensor(
                state_indices,
                dtype=torch.int32,
                device=self.mamba_cache[0].device)
            mamba_cache_tensors = self.mamba_cache
        else:
            # CUDA graph capturing runs
            (mamba_cache_tensors,
             state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"]

        return MambaCacheParams(mamba_cache_tensors[0], mamba_cache_tensors[1],
                                state_indices_tensor)

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """
        Copy the relevant state_indices into the CUDA graph input buffer.

        Requires the ``request_ids_to_seq_ids`` and ``finished_requests_ids``
        keywords and a ``seqlen_agnostic_capture_inputs`` entry in
        ``input_buffers`` whose second element is the index buffer. The
        indices of the current run are padded with ``PAD_SLOT_ID`` up to
        the buffer length and written into the buffer in place.
        """
        assert all(
            key in kwargs
            for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
        finished_requests_ids = kwargs["finished_requests_ids"]
        request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
        assert "seqlen_agnostic_capture_inputs" in input_buffers
        _, input_state_indices_buffer = input_buffers[
            "seqlen_agnostic_capture_inputs"]

        self._release_finished_requests(finished_requests_ids)
        state_indices = self._prepare_current_run_mamba_cache(
            request_ids_to_seq_ids, finished_requests_ids)
        cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
            state_indices)
        state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)

        # Build the source tensor directly on the buffer's device.
        input_state_indices_buffer.copy_(
            torch.as_tensor(state_indices,
                            dtype=torch.int32,
                            device=input_state_indices_buffer.device))

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """
        Provide the CUDA graph capture runs with a buffer in adjusted size.
        The buffer is used to maintain the Mamba Cache during the CUDA graph
        replay runs.

        Returns the ``(self.mamba_cache, state_indices_tensor)`` pair, where
        the int32 index tensor has ``batch_size`` entries, all set to
        ``PAD_SLOT_ID``, and lives on the same device as the cache.
        """
        state_indices_tensor = torch.full((batch_size, ),
                                          PAD_SLOT_ID,
                                          dtype=torch.int32,
                                          device=self.mamba_cache[0].device)
        return (self.mamba_cache, state_indices_tensor)

    def _copy_mamba_cache(self, from_index: int, to_index: int):
        """Copy slot ``from_index`` into slot ``to_index`` for all layers."""
        assert len(self.mamba_cache) > 0
        for cache_t in self.mamba_cache:
            # Dim 0 is the layer dimension, dim 1 the cache slot.
            cache_t[:, to_index].copy_(cache_t[:, from_index],
                                       non_blocking=True)

    def _pop_free_cache_index(self) -> int:
        """Take one index off the free list, failing loudly when empty."""
        if not self.free_cache_indices:
            # Same exception type as ``list.pop()`` on an empty list, but
            # with a message that names the actual problem.
            raise IndexError(
                "Mamba cache is full: no free cache index is available")
        return self.free_cache_indices.pop()

    def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
                                      finished_requests_ids) -> int:
        """
        Return the cache index of the (req_id, seq_id) pair, allocating a
        free index when the pair has none yet.

        Args:
            cur_rid: request id of the sequence.
            seq_id: sequence id inside that request.
            finished_requests_ids: container of finished request ids; only
                membership tests are performed on it.

        Returns:
            ``PAD_SLOT_ID`` for a finished request, otherwise the cache
            index assigned to the pair.

        Raises:
            IndexError: a new index is needed but the free list is empty.
        """
        if cur_rid in finished_requests_ids:
            # set as pad, do not allocate destination index
            return PAD_SLOT_ID

        if cur_rid not in self.mamba_cache_indices_mapping:
            # First sequence seen for this request.
            destination_index = self._pop_free_cache_index()
            self.mamba_cache_indices_mapping[cur_rid] = {
                seq_id: destination_index
            }
            return destination_index

        seq_ids2indices = self.mamba_cache_indices_mapping[cur_rid]
        if seq_id not in seq_ids2indices:
            # parallel sampling , where n > 1, assume prefill have
            # already happened, so we copy the
            # existing cache into the siblings seq_ids caches
            index_exists = next(iter(seq_ids2indices.values()))
            # case of decoding n>1, copy prefill cache to decoding indices
            destination_index = self._pop_free_cache_index()
            self._copy_mamba_cache(from_index=index_exists,
                                   to_index=destination_index)
            seq_ids2indices[seq_id] = destination_index
            return destination_index

        # already exists
        return seq_ids2indices[seq_id]

    def _prepare_current_run_mamba_cache(
            self, request_ids_to_seq_ids: Dict[str, List[int]],
            finished_requests_ids: List[str]) -> List[int]:
        """Return one cache index per scheduled sequence, in dict order.

        Sequences of finished requests yield ``PAD_SLOT_ID``; sequences
        without an index are assigned one as a side effect.
        """
        # Build the set once so each membership test is O(1), not a list scan.
        finished = set(finished_requests_ids)
        return [
            self._assign_seq_id_to_cache_index(req_id, seq_id, finished)
            for req_id, seq_ids in request_ids_to_seq_ids.items()
            for seq_id in seq_ids
        ]

    def _release_finished_requests(self,
                                   finished_seq_groups_req_ids: List[str]):
        """Return the cache indices of the given requests to the free list.

        Request ids that are not tracked (unknown or already released) are
        ignored.
        """
        for req_id in finished_seq_groups_req_ids:
            seq_ids2indices = self.mamba_cache_indices_mapping.pop(
                req_id, None)
            if seq_ids2indices is not None:
                self.free_cache_indices.extend(seq_ids2indices.values())