mamba2_attn.py 9.53 KB
Newer Older
Chen Zhang's avatar
Chen Zhang committed
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import itertools
4
from dataclasses import dataclass, replace
Chen Zhang's avatar
Chen Zhang committed
5
6
7

import torch

8
from vllm.config import VllmConfig
9
from vllm.utils.math_utils import cdiv
10
11
12
13
from vllm.v1.attention.backend import (
    AttentionBackend,
    CommonAttentionMetadata,
)
14
15
16
from vllm.v1.attention.backends.mamba_attn import (
    BaseMambaAttentionMetadata,
    BaseMambaAttentionMetadataBuilder,
17
)
18
from vllm.v1.kv_cache_interface import AttentionSpec
Chen Zhang's avatar
Chen Zhang committed
19
20


21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def compute_varlen_chunk_metadata(
    query_start_loc: torch.Tensor,
    chunk_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build chunk-aligned, variable-length metadata used by Mamba2 SSD kernels.

    Every sequence described by `query_start_loc` is split into "logical
    chunks". A logical chunk ends at the next multiple of `chunk_size`
    (counted in the packed token dimension) or at the end of its sequence,
    whichever comes first, so it never crosses a sequence boundary or a
    physical-chunk boundary.

    Args:
        query_start_loc: 1-D tensor of shape [B+1] holding the cumulative
            token start of each of the B sequences; element 0 must be 0.
        chunk_size: physical chunk size in tokens; must be positive.

    Returns:
        A tuple of three int32 tensors on the device of `query_start_loc`:
          - cu_chunk_seqlens:   (nchunks+1,) exclusive prefix-sum of the
            logical-chunk lengths.
          - last_chunk_indices: (B,) index of the last logical chunk of each
            sequence (-1 for empty sequences).
          - seq_idx_chunks:     (nchunks,) sequence index of each logical
            chunk, in order.

    Raises:
        ValueError: if `chunk_size` is not positive.
        AssertionError: if `query_start_loc` is not 1-D, does not start at 0,
            or the chunk lengths do not add up to the last offset.

    This is intentionally lightweight and CPU-side; it mirrors the metadata
    produced by the V1 Mamba2 meta-data builder and is exported so tests
    (and other callers) can avoid duplicating the logic.
    """
    # A zero chunk size would divide by zero below and a negative one would
    # make the splitting loop walk backwards forever, so reject both early.
    # This is a real exception (not an assert) so it survives `python -O`.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")
    assert query_start_loc.ndim == 1, "query_start_loc must be 1-D [B+1]"

    device = query_start_loc.device
    # Single transfer to host; all remaining bookkeeping is plain Python.
    offsets: list[int] = query_start_loc.to(torch.int64).tolist()
    assert offsets[0] == 0, "query_start_loc[0] must be 0"
    total = offsets[-1]
    num_seqs = len(offsets) - 1

    chunk_lens: list[int] = []
    seq_idx_chunks: list[int] = []
    last_chunk_indices: list[int] = [-1] * num_seqs

    for b, (start, end) in enumerate(zip(offsets[:-1], offsets[1:])):
        pos = start
        # Empty (or non-increasing) sequences never enter the loop and keep
        # their -1 marker in last_chunk_indices.
        while pos < end:
            # Split at both sequence boundaries and physical chunk boundaries.
            room = chunk_size - (pos % chunk_size)
            take = min(room, end - pos)
            chunk_lens.append(take)
            seq_idx_chunks.append(b)
            last_chunk_indices[b] = len(chunk_lens) - 1
            pos += take

    # Exclusive prefix sum over logical-chunk lengths; [0] when no chunks.
    boundaries = [0, *itertools.accumulate(chunk_lens)]
    if chunk_lens:
        # Final boundary must equal total tokens.
        assert boundaries[-1] == total

    cu_chunk_seqlens = torch.tensor(boundaries, device=device, dtype=torch.int32)
    last_chunk_indices_t = torch.tensor(
        last_chunk_indices, device=device, dtype=torch.int32
    )
    seq_idx_chunks_t = torch.tensor(seq_idx_chunks, device=device, dtype=torch.int32)
    return cu_chunk_seqlens, last_chunk_indices_t, seq_idx_chunks_t


Chen Zhang's avatar
Chen Zhang committed
90
class Mamba2AttentionBackend(AttentionBackend):
    """Attention backend descriptor for Mamba2 layers.

    Provides the backend's name string and the metadata-builder class that
    goes with it.
    """

    @staticmethod
    def get_name() -> str:
        """Return the name string identifying this backend."""
        return "MAMBA2_ATTN"

    @staticmethod
    def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]:
        """Return the metadata builder class used with this backend."""
        return Mamba2AttentionMetadataBuilder


@dataclass
class Mamba2AttentionMetadata(BaseMambaAttentionMetadata):
    """Mamba2-specific attention metadata.

    Adds the chunk size and the prefill-only chunk layout tensors on top of
    the fields inherited from `BaseMambaAttentionMetadata`.
    """

    # True when at least one prefill request in the batch has an initial
    # state; stays False for batches without prefills.
    prep_initial_states: bool = False
    # Mamba chunk size, in tokens.
    chunk_size: int = 0

    # Chunk-related metadata (only for prefill)

    # seq_idx_p is a tensor of shape (nchunks,) that contains, for each
    # chunk, the index of the prefill request the chunk belongs to.
    seq_idx_p: torch.Tensor | None = None
    # cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
    # each chunk, its offsets into the varlen sequence dimension. It is defined
    # such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
    # cu_chunk_seqlen_p[i+1].
    cu_chunk_seqlen_p: torch.Tensor | None = None
    # last_chunk_indices_p is a tensor of shape (batch,) that contains the
    # index of the last chunk for every sequence in the (prefill) batch.
    last_chunk_indices_p: torch.Tensor | None = None
Chen Zhang's avatar
Chen Zhang committed
115
116
117


class Mamba2AttentionMetadataBuilder(
    BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]
):
    """Builds `Mamba2AttentionMetadata` for a batch.

    Starts from the common metadata computed by the base builder and, when
    the batch contains prefill requests, adds the chunk layout
    (`seq_idx_p`, `cu_chunk_seqlen_p`, `last_chunk_indices_p`).
    """

    metadata_cls = Mamba2AttentionMetadata

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Initialize the base builder and record the model's chunk size.

        All arguments are forwarded unchanged to the base-class constructor.

        Raises:
            AssertionError: if the model config reports no mamba chunk size.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        chunk_size = vllm_config.model_config.get_mamba_chunk_size()
        assert chunk_size is not None, (
            "chunk_size needs to be set in the model config for Mamba2 models"
        )
        self.chunk_size: int = chunk_size

    def _compute_chunk_metadata(
        self,
        num_prefills: int,
        num_computed_tokens_p_cpu: torch.Tensor,
        query_start_loc_p_cpu: torch.Tensor,
    ) -> tuple[list[int], list[int], list[int]]:
        """
        Compute chunk-specific metadata for Mamba2.

        The code below carefully constructs the chunks such that:
        1. Chunks contain tokens from a *single* sequence only.
        2. For every sequence, we are guaranteed that we can
           retrieve the mamba state *every* chunk_size tokens.
        Constraint (1) dramatically simplifies the mamba2 kernels.
        Constraint (2) dramatically simplifies the implementation
        of prefix caching for mamba2 (wip). We need to take care
        of the interaction with chunked prefill in order to
        satisfy constraint (2).

        Args:
            num_prefills: number of prefill requests to process.
            num_computed_tokens_p_cpu: per-prefill-request count of tokens
                that were already computed; indexed 0..num_prefills-1.
            query_start_loc_p_cpu: cumulative start offsets of the new
                tokens of each prefill request; indexed 0..num_prefills.

        Returns:
            A tuple `(cu_chunk_seqlen, seq_idx, last_chunk_indices)` of
            Python lists:
              - cu_chunk_seqlen: start offset of every chunk followed by the
                total number of new tokens (nchunks+1 entries).
              - seq_idx: request index of every chunk (nchunks entries).
              - last_chunk_indices: index of the last chunk of every request
                (num_prefills entries).
        """
        chunk_size = self.chunk_size
        # Pull the values to Python once instead of calling .item() two or
        # three times per request inside the loop.
        num_computed = num_computed_tokens_p_cpu[:num_prefills].tolist()
        query_start_loc = query_start_loc_p_cpu[: num_prefills + 1].tolist()

        cu_chunk_seqlen: list[int] = []
        seq_idx: list[int] = []
        last_chunk_indices: list[int] = []
        seqlen_pos = 0

        for req_idx in range(num_prefills):
            this_num_computed = num_computed[req_idx]
            this_new_tokens = query_start_loc[req_idx + 1] - query_start_loc[req_idx]

            # if computed tokens are not chunk-aligned, use the first
            # chunk to finish it off
            if this_num_computed % chunk_size != 0:
                seq_idx.append(req_idx)
                cu_chunk_seqlen.append(seqlen_pos)
                # how many tokens to finish the chunk?
                chunk_len = (
                    cdiv(this_num_computed, chunk_size) * chunk_size
                    - this_num_computed
                )
                # we can only use at most this_new_tokens
                chunk_len = min(chunk_len, this_new_tokens)
                seqlen_pos += chunk_len
                this_new_tokens -= chunk_len

            # The remaining tokens start on a chunk boundary: emit full
            # chunks plus at most one shorter trailing chunk.
            n_chunks = cdiv(this_new_tokens, chunk_size)
            for _ in range(n_chunks):
                seq_idx.append(req_idx)
                cu_chunk_seqlen.append(seqlen_pos)
                chunk_len = min(chunk_size, this_new_tokens)
                seqlen_pos += chunk_len
                this_new_tokens -= chunk_len

            assert this_new_tokens == 0
            last_chunk_indices.append(len(cu_chunk_seqlen) - 1)

        # Closing boundary so that chunk i spans
        # cu_chunk_seqlen[i]..cu_chunk_seqlen[i+1].
        cu_chunk_seqlen.append(seqlen_pos)

        return cu_chunk_seqlen, seq_idx, last_chunk_indices

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> Mamba2AttentionMetadata:
        """Build the Mamba2 metadata for one batch.

        Args:
            common_prefix_len: not used by this method.
            common_attn_metadata: batch-level metadata; provides the query
                start locations and the number of computed tokens.
            fast_build: not used by this method.

        Returns:
            A copy of the common metadata with `prep_initial_states`,
            `chunk_size` and the prefill chunk tensors filled in. The chunk
            tensors stay None when the batch has no prefill requests.
        """
        common = self._compute_common_metadata(common_attn_metadata)

        seq_idx_p = None
        cu_chunk_seqlen_p = None
        last_chunk_indices_p = None
        prep_initial_states = False

        # Compute seq_idx for prefill only
        if common.num_prefills > 0:
            prep_initial_states = (
                torch.any(common.has_initial_states_p).item()
                if common.has_initial_states_p is not None
                else False
            )

            num_reqs = common.num_reqs
            num_prefills = common.num_prefills
            num_decode_tokens = common.num_decode_tokens

            # The prefill requests are taken from the trailing num_prefills
            # entries of the per-request arrays.
            # NOTE(review): this presumes decodes are ordered before
            # prefills in the batch - confirm against the base builder.
            num_computed_tokens_cpu = (
                common_attn_metadata.compute_num_computed_tokens().cpu()
            )
            num_computed_tokens_p_cpu = num_computed_tokens_cpu[
                num_reqs - num_prefills : num_reqs
            ]
            # Shift the offsets by the number of decode tokens so they are
            # relative to the prefill tokens only.
            query_start_loc_p_cpu = (
                common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
                - num_decode_tokens
            )

            cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
                num_prefills,
                num_computed_tokens_p_cpu,
                query_start_loc_p_cpu,
            )

            # All chunk tensors live on the same device as query_start_loc.
            device = common_attn_metadata.query_start_loc.device
            seq_idx_p = torch.as_tensor(seq_idx, device=device, dtype=torch.int32)
            cu_chunk_seqlen_p = torch.as_tensor(
                cu_chunk_seqlen, device=device, dtype=torch.int32
            )
            last_chunk_indices_p = torch.as_tensor(
                last_chunk_indices, device=device, dtype=torch.int32
            )

        return replace(
            common,
            prep_initial_states=prep_initial_states,
            chunk_size=self.chunk_size,
            seq_idx_p=seq_idx_p,
            cu_chunk_seqlen_p=cu_chunk_seqlen_p,
            last_chunk_indices_p=last_chunk_indices_p,
        )