"""Block-sparse attention backend.

Prefill runs through ``LocalStridedBlockSparseAttn``; decode runs through
``PagedAttention.forward_decode`` with the block-sparse parameters.
"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Type

import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer,
                                              AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (CommonAttentionState,
                                           CommonMetadataBuilder)
from vllm.attention.ops.blocksparse_attention.interface import (
    LocalStridedBlockSparseAttn, get_head_sliding_step)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm import _custom_ops as ops
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)


@dataclass
class BlocksparseParams:
    """Parameters of the local + vertical-stride block-sparse pattern.

    The six leading fields are supplied by the caller. ``head_sliding_step``
    and ``active_head_range`` are derived in ``__post_init__`` from the
    tensor-parallel world size and rank.
    """

    max_seqlen: int

    # Number of q heads on this tensor-parallel rank/partition.
    num_heads: int
    # Number of kv heads on this tensor-parallel rank/partition.
    num_kv_heads: int

    # Block size used for blocksparse attention, i.e. the unit in which
    # `local_blocks` and `vert_stride` are expressed.
    block_size: int

    # Number of blocks for local attention, i.e. number of locally attended
    # tokens / `sparse_block_size`.
    local_blocks: int

    # Attend to one block per every `vert_stride` blocks; controls sparsity.
    vert_stride: int

    # Whether to use the same vertical stride offset for all heads, i.e.
    # attend to the same block of tokens on all heads.
    # By default it is False, i.e. attention on the non-local blocks depends
    # on the `head_idx`, that is on blocks satisfying
    # `(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0`
    # where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`,
    #       `block_idx = position_id // sparse_block_size`.
    # See `..ops.blocksparse_attention.utils:get_sparse_attn_mask`
    # for more detail.
    homo_head: bool = False

    # Whether all q heads within a group attend to the same kv offsets.
    homo_head_group: bool = False

    # Derived from `homo_head` and `homo_head_group`.
    head_sliding_step: int = field(init=False)

    # [start, end) range of global q-head indices owned by this TP rank.
    active_head_range: Tuple = field(init=False)

    def __post_init__(self):
        """Check the user-supplied fields and fill in the derived ones."""
        assert self.block_size > 0
        assert self.local_blocks >= 0
        assert self.vert_stride >= 1
        assert self.num_heads % self.num_kv_heads == 0

        world_size = get_tensor_model_parallel_world_size()
        rank = get_tensor_model_parallel_rank()

        if self.homo_head:
            # Every head uses the same offset.
            step = 0
        elif self.homo_head_group:
            # A negative value indicates sliding along kv heads, i.e. the
            # q heads of one group are homogeneous.
            step = -get_head_sliding_step(world_size * self.num_kv_heads,
                                          self.vert_stride)
        else:
            step = get_head_sliding_step(world_size * self.num_heads,
                                         self.vert_stride)
        self.head_sliding_step = step

        first_head = rank * self.num_heads
        self.active_head_range = (first_head, first_head + self.num_heads)


class BlocksparseFlashAttentionBackend(AttentionBackend):
    """Backend descriptor for block-sparse attention.

    Exposes the implementation, metadata, builder and state classes of this
    backend and delegates every KV-cache operation to ``PagedAttention``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "BLOCK_SPARSE_FLASH_ATTN"

    @staticmethod
    def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
        """Return the attention implementation class."""
        return BlocksparseFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the attention metadata class."""
        return BlocksparseFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return BlocksparseFlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class (the shared common one)."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape as computed by ``PagedAttention``."""
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Forward a block swap request to ``PagedAttention.swap_blocks``."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        """Forward a block copy request to ``PagedAttention.copy_blocks``."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class BlocksparseFlashAttentionMetadata(AttentionMetadata):
    """A copy of Metadata for FlashAttentionBackend,
    to avoid having to install flash_attn.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int]
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]

    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # Max number of query tokens among the requests in the batch.
    max_decode_query_len: Optional[int] = None

    # Lazily built views returned by `prefill_metadata` / `decode_metadata`.
    _cached_prefill_metadata: Optional[
        "BlocksparseFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional[
        "BlocksparseFlashAttentionMetadata"] = None

    # Passed through unchanged to both the prefill and the decode view.
    # NOTE(review): nothing in this file reads it — confirm the intended
    # contents and consumer.
    block_tables_list: Optional[List[int]] = None

    @property
    def prefill_metadata(
            self) -> Optional["BlocksparseFlashAttentionMetadata"]:
        """Return metadata restricted to the prefill part of the batch.

        Returns None when the batch has no prefills. The result is built on
        first access and cached on the instance afterwards. Per-sequence
        fields keep the first ``num_prefills`` entries and ``slot_mapping``
        keeps the first ``num_prefill_tokens`` entries.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.block_tables is not None
        assert self.seq_start_loc is not None

        self._cached_prefill_metadata = BlocksparseFlashAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            # The cumulative-length tensors have one more entry than there
            # are sequences, hence the `+ 1`.
            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
            block_tables=self.block_tables[:self.num_prefills],
            use_cuda_graph=False,
            block_tables_list=self.block_tables_list,
        )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
        """Return metadata restricted to the decode part of the batch.

        Returns None when the batch has no decode tokens. The result is built
        on first access and cached on the instance afterwards. Per-sequence
        tensors drop the first ``num_prefills`` entries and ``slot_mapping``
        drops the first ``num_prefill_tokens`` entries; the prefill-only
        fields are set to None / 0.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.block_tables is not None
        assert self.seq_lens_tensor is not None

        self._cached_decode_metadata = BlocksparseFlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self.block_tables[self.num_prefills:],
            use_cuda_graph=self.use_cuda_graph,
            block_tables_list=self.block_tables_list,
        )
        return self._cached_decode_metadata


class BlocksparseFlashAttentionMetadataBuilder(
        CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]):
    """Metadata builder for the block-sparse attention backend.

    Defines no behaviour of its own: everything is inherited from
    ``CommonMetadataBuilder``; this subclass only sets ``_metadata_cls``.
    """

    # Metadata class tied to this builder (presumably the class the base
    # builder instantiates — confirm in CommonMetadataBuilder).
    _metadata_cls = BlocksparseFlashAttentionMetadata


class BlocksparseFlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Validate the configuration and build the block-sparse operator.

        Args:
            num_heads: Number of query heads on this tensor-parallel rank.
            head_size: Size of each head; must be one of the head sizes
                reported by ``PagedAttention.get_supported_head_sizes()``.
            scale: Softmax scale, stored as a float.
            num_kv_heads: Number of kv heads on this rank; ``None`` falls
                back to ``num_heads``.
            alibi_slopes: Must be None.
            sliding_window: Must be None.
            kv_cache_dtype: Stored and forwarded to the paged-cache ops.
            blocksparse_params: Keyword arguments for ``BlocksparseParams``.
                Required. ``num_heads`` / ``num_kv_heads`` are filled in
                from the arguments above when absent. The dict is not
                modified.
            logits_soft_cap: Must be None.
            attn_type: Only ``AttentionType.DECODER`` is accepted.

        Raises:
            AssertionError: If ``blocksparse_params`` is None, if any of
                ``alibi_slopes`` / ``sliding_window`` / ``logits_soft_cap``
                is set, or if ``num_heads`` is not a multiple of the number
                of kv heads.
            ValueError: If ``head_size`` is not supported by PagedAttention.
            NotImplementedError: If ``attn_type`` is not the decoder type.
        """
        # NOTE: these checks are `assert` statements. They raise
        # AssertionError (with the ValueError instance as its message) and
        # are skipped entirely when Python runs with -O.
        assert blocksparse_params is not None
        assert alibi_slopes is None, ValueError(
            "Alibi not support for blocksparse flash attention.")
        assert sliding_window is None, ValueError(
            "sliding_window is invalid for blocksparse attention.")
        assert logits_soft_cap is None, ValueError(
            "logits_soft_cap is invalid for blocksparse attention.")

        # Fill in the defaults on a private copy so that the dict owned by
        # the caller is left untouched.
        params = dict(blocksparse_params)
        params.setdefault("num_heads", num_heads)
        params.setdefault("num_kv_heads", num_kv_heads or num_heads)
        self.blocksparse_params = BlocksparseParams(**params)
        self.kv_cache_dtype = kv_cache_dtype

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.alibi_slopes = alibi_slopes
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.local_blocks = self.blocksparse_params.local_blocks
        self.vert_stride = self.blocksparse_params.vert_stride
        self.sparse_block_size = self.blocksparse_params.block_size
        self.head_sliding_step = self.blocksparse_params.head_sliding_step

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        # The sparse operator is configured with the head count across all
        # tensor-parallel ranks plus the range of heads owned by this rank.
        total_num_heads = num_heads * self.tp_size
        self.bs_attn = LocalStridedBlockSparseAttn(
            total_num_heads,
            self.blocksparse_params.max_seqlen,
            self.blocksparse_params.local_blocks,
            self.blocksparse_params.vert_stride,
            self.blocksparse_params.block_size,
            homo_head=self.blocksparse_params.homo_head,
            active_head_range=self.blocksparse_params.active_head_range,
        )

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "BlocksparseFlashAttentionImpl")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: BlocksparseFlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            layer: Attention layer; only its ``_k_scale`` and ``_v_scale``
                attributes are read here.
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
            output: Not written into. It is replaced by the result of the
                prefill or decode branch whenever one of them runs.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache.numel() > 0:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.

            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if prefill_meta := attn_metadata.prefill_metadata:

            # Prompt run.
            # normal attention
            # When block_tables are not filled, it means q and k are the
            # prompt, and they have the same length.

            assert kv_cache.numel() == 0 \
                    or prefill_meta.block_tables is None \
                    or prefill_meta.block_tables.numel() == 0, \
                "Does not support prefix-enabled attention."

            output = self.bs_attn(
                q=query,
                k=key,
                v=value,
                cu_seqlens_q=prefill_meta.seq_start_loc,
                cu_seqlens_k=prefill_meta.seq_start_loc,
                sm_scale=self.scale,
            )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            # NOTE: key_cache / value_cache are only bound when kv_cache is
            # non-empty, so this branch requires an allocated KV cache.
            output = PagedAttention.forward_decode(
                query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.seq_lens_tensor,
                self.blocksparse_params.max_seqlen,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                layer._k_scale,
                layer._v_scale,
                tp_rank=self.tp_rank,
                blocksparse_local_blocks=self.local_blocks,
                blocksparse_vert_stride=self.vert_stride,
                blocksparse_block_size=self.sparse_block_size,
                blocksparse_head_sliding_step=self.head_sliding_step,
            )

        assert output is not None
        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)