placeholder_attn.py 15.9 KB
Newer Older
1
from collections import defaultdict
2
from dataclasses import dataclass
3
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
4
5
6
7
8
9
10

import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata,
                                              AttentionMetadataBuilder)
from vllm.attention.backends.utils import CommonAttentionState
11
from vllm.multimodal import MultiModalPlaceholderMap
12
13

if TYPE_CHECKING:
14
15
    from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                          ModelInputForGPUWithSamplingMetadata)
16

17
# Placeholder attention backend for models like Mamba and pooling models that
# lack attention.


class PlaceholderAttentionBackend(AttentionBackend):
    """Placeholder backend for when no attention is needed.

    The accessor methods return the companion placeholder classes, and the
    KV-cache related hooks are no-ops or return a fixed dummy value.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "NO_ATTENTION"

    @staticmethod
    def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
        """Return the attention implementation class for this backend."""
        return PlaceholderAttentionImpl

    @staticmethod
    def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
        """Return the metadata builder class for this backend."""
        return PlaceholderAttentionMetadataBuilder

    @staticmethod
    def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
        """Return the metadata class for this backend."""
        return PlaceholderAttentionMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class for this backend."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return a fixed dummy KV cache shape.

        All arguments are ignored; the same five-element shape of ones is
        returned regardless of the requested geometry.
        """
        return (1, 1, 1, 1, 1)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Do nothing; the arguments are left untouched."""
        return

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Do nothing; the arguments are left untouched."""
        return


@dataclass
class PlaceholderAttentionMetadata(AttentionMetadata):
    """Attention metadata for prefill and decode batched together."""
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # Maximum query length in the batch.
    max_query_len: Optional[int]

    # Maximum number of query tokens among the requests in the batch.
    max_decode_query_len: Optional[int]

    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]

    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # Lazily-built views returned by `prefill_metadata` / `decode_metadata`.
    _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
    _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        """Return the metadata restricted to the prefill requests.

        Returns None when the batch holds no prefills. The result is built
        once and cached on this instance.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.seq_start_loc is not None

        # Placeholders
        slot_mapping = torch.empty(0)
        block_tables = torch.empty(0)

        mm_index_maps = self.multi_modal_placeholder_index_maps
        # Prefill entries occupy the first `num_prefills` positions, so each
        # per-sequence field is sliced from the front.
        self._cached_prefill_metadata = PlaceholderAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=mm_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            max_decode_query_len=0,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
            block_tables=block_tables,
            use_cuda_graph=False,
        )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        """Return the metadata restricted to the decode requests.

        Returns None when the batch holds no decode tokens. The result is
        built once and cached on this instance.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.seq_lens_tensor is not None

        # Placeholders
        slot_mapping = torch.empty(0)
        block_tables = torch.empty(0)

        # Decode entries follow the prefill entries, so the tensor is sliced
        # from `num_prefills` onwards.
        self._cached_decode_metadata = PlaceholderAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_decode_query_len=self.max_decode_query_len,
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
        )
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.

        Args:
            model_input: Model input whose `input_tokens` tensor is
                overwritten in place with the sampled tokens (when given).
            sampled_token_ids: Tokens sampled in the previous step, or None
                to leave `model_input.input_tokens` untouched.
            block_size: Unused by this implementation.
            num_seqs: Number of sequences, possibly padded (see below).
            num_queries: Actual number of requests in the batch.
            turn_prefills_into_decodes: Must be False; not supported here.
        """
        # When using cudagraph, the num_seqs is padded to the next captured
        # batch sized, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        assert not turn_prefills_into_decodes, \
            ("Multi-Step + Chunked-Prefill is not supported for attention-free "
             "models. turn_prefills_into_decodes is a "
             "Multi-Step + Chunked-Prefill specific parameter.")

        assert self.seq_lens is not None
        assert self.max_decode_seq_len == max(self.seq_lens)

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs

        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)

        # Update sequences, masking off entries greater than num_queries
        device = self.seq_lens_tensor.device
        mask = torch.arange(self.seq_lens_tensor.size(0),
                            device=device) < num_queries
        self.seq_lens_tensor += mask.to(self.seq_lens_tensor.dtype)
        if sampled_token_ids is not None:
            model_input.input_tokens.masked_scatter_(
                mask, sampled_token_ids[:num_queries])

253
254
255
256
257

class PlaceholderAttentionMetadataBuilder(
        AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
    """Collects per-batch statistics and builds the placeholder metadata."""

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        """Store the input builder and the runner it exposes."""
        self.input_builder = input_builder
        self.runner = input_builder.runner

    def prepare(self):
        """Reset every per-batch accumulator to its empty state."""
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Add a sequence group to the metadata.

        For every sequence the context length is appended. Prompt sequences
        additionally extend the multi-modal placeholder maps and bump the
        prefill counters; decode sequences bump the decode token counter and
        record the current sequence length.

        `chunked_prefill_enabled` is accepted but not used here.
        """
        is_prompt = inter_data.is_prompt

        # All seven iterables stay in the zip so the iteration count is still
        # bounded by the shortest one; unused values get an underscore name.
        for (_seq_id, token_len, seq_len, curr_seq_len, query_len,
             context_len, _curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)

            if is_prompt:
                mm_maps = inter_data.multi_modal_placeholder_maps
                if mm_maps:
                    for modality, placeholders in mm_maps.items():
                        self.multimodal_placeholder_maps[modality].extend(
                            placeholders)

                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    f"seq_len: {seq_len}, context_len: {context_len}, "
                    f"query_len: {query_len}")
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.

        Returns:
            A PlaceholderAttentionMetadata whose slot mapping and block
            tables are empty placeholder tensors.

        Raises:
            ValueError: If the model config defines `attn_logit_softcapping`.
        """
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        logits_soft_cap = getattr(self.runner.model_config.hf_config,
                                  "attn_logit_softcapping", None)
        if logits_soft_cap is not None:
            raise ValueError(
                "Please use Flashinfer backend for models with logits_soft_cap"
                " (i.e., Gemma-2). Otherwise, the output might be wrong."
                " Set Flashinfer backend by "
                "export VLLM_ATTENTION_BACKEND=FLASHINFER.")

        max_query_len = max(query_lens)
        # Decode entries follow the prefill entries; fall back to 1 when the
        # batch has no decode requests.
        max_decode_query_len = max(query_lens[self.num_prefills:], default=1)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens

        if use_captured_graph:
            # With a captured graph the decode token count is the (maybe
            # padded) batch size.
            num_decode_tokens = batch_size

        assert max_query_len > 0, f"query_lens: {query_lens}"

        context_lens_tensor = torch.tensor(self.context_lens,
                                           dtype=torch.int,
                                           device=device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=device)
        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=device)
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }
        # Write the running totals into slots 1.. so slot 0 stays zero.
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        # Placeholders
        slot_mapping = torch.empty(0)
        block_tables = torch.empty(0)

        return PlaceholderAttentionMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=True,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_decode_query_len=max_decode_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )


class PlaceholderAttentionImpl(AttentionImpl):
    """Attention implementation stub: construction is a no-op and
    `forward` is not implemented."""

    def __init__(self, *args, **kwargs) -> None:
        """Accept any arguments and ignore them."""
        return None

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Always raise NotImplementedError."""
        raise NotImplementedError