xformers.py 33.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer with xFormers and PagedAttention."""
4
from dataclasses import dataclass
5
from typing import Any, Dict, List, Optional, Tuple, Type
Woosuk Kwon's avatar
Woosuk Kwon committed
6
7

import torch
8
from xformers import ops as xops
9
10
from xformers.ops.fmha.attn_bias import (AttentionBias,
                                         BlockDiagonalCausalMask,
11
                                         BlockDiagonalMask,
Woosuk Kwon's avatar
Woosuk Kwon committed
12
                                         LowerTriangularMaskWithTensorBias)
Woosuk Kwon's avatar
Woosuk Kwon committed
13

14
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
15
                                              AttentionLayer,
16
                                              AttentionMetadata, AttentionType)
17
18
19
20
from vllm.attention.backends.utils import (
    CommonAttentionState, CommonMetadataBuilder,
    get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
    is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set)
21
22
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
23
24
25
from vllm.logger import init_logger

logger = init_logger(__name__)
26

27
28
29

class XFormersBackend(AttentionBackend):
    """Attention backend that pairs xFormers kernels with PagedAttention.

    All hooks are static. The implementation, metadata, builder and state
    classes are returned directly. KV-cache shape queries and block
    movement are forwarded unchanged to :class:`PagedAttention`.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        backend_name = "XFORMERS"
        return backend_name

    @staticmethod
    def get_impl_cls() -> Type["XFormersImpl"]:
        """Return the class that performs the attention computation."""
        impl = XFormersImpl
        return impl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the per-batch attention metadata class."""
        metadata = XFormersMetadata
        return metadata

    @staticmethod
    def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
        """Return the class that builds :class:`XFormersMetadata`."""
        builder = XFormersMetadataBuilder
        return builder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention-state class (shared common implementation)."""
        state = CommonAttentionState
        return state

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape, as computed by PagedAttention."""
        cache_shape = PagedAttention.get_kv_cache_shape(
            num_blocks, block_size, num_kv_heads, head_size)
        return cache_shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Move cache blocks from ``src_kv_cache`` into ``dst_kv_cache``.

        ``src_to_dst`` is handed to PagedAttention unchanged.
        """
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy cache blocks within ``kv_caches`` as described by
        ``src_to_dists`` (delegated to PagedAttention)."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for the XFormers backend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, they should be stored in a tensor. The tensor has to be
    updated from the `CUDAGraphRunner.forward` API.
    """

    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # FIXME: It is for flash attn.
    # Maximum sequence length among the prefill batch. 0 if there are
    # decoding requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among the decode batch. 0 if there are
    # prefill requests only.
    max_decode_seq_len: int

    # Whether CUDA graph is enabled.
    # CUDA graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]] = None

    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor] = None

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int] = None

    # Max number of query tokens among requests in the batch.
    max_decode_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None

    # Self-attention prefill/decode metadata cache
    _cached_prefill_metadata: Optional["XFormersMetadata"] = None
    _cached_decode_metadata: Optional["XFormersMetadata"] = None

    # Begin encoder attn & enc/dec cross-attn fields...

    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    encoder_seq_start_loc: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    def __post_init__(self):
        # Set during the execution of the first attention op.
        # It is a list because it is needed to set per prompt
        # when alibi slopes is used. It is because of the limitation
        # from xformer API.
        # These are plain attributes, not dataclass fields, so they
        # do not appear in __repr__ or __init__.
        self.attn_bias: Optional[List[AttentionBias]] = None
        self.encoder_attn_bias: Optional[List[AttentionBias]] = None
        self.cross_attn_bias: Optional[List[AttentionBias]] = None

    @property
    def is_all_encoder_attn_metadata_set(self):
        '''
        All attention metadata required for encoder attention is set.
        '''
        return is_all_encoder_attn_metadata_set(self)

    @property
    def is_all_cross_attn_metadata_set(self):
        '''
        All attention metadata required for enc/dec cross-attention is set.

        Superset of encoder attention required metadata.
        '''
        return is_all_cross_attn_metadata_set(self)

    @property
    def prefill_metadata(self) -> Optional["XFormersMetadata"]:
        """Metadata restricted to the prefill portion of the batch.

        Returns None when the batch has no prefills. The first call builds
        the object by slicing the per-sequence / per-token fields to the
        leading prefill entries. Later calls return the cached object.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            # Recover cached prefill-phase attention
            # metadata structure
            return self._cached_prefill_metadata

        assert ((self.seq_lens is not None)
                or (self.encoder_seq_lens is not None))
        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[:self.num_prefills + 1])
        seq_start_loc = (None if self.seq_start_loc is None else
                         self.seq_start_loc[:self.num_prefills + 1])
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[:self.num_prefill_tokens])
        seq_lens = (None if self.seq_lens is None else
                    self.seq_lens[:self.num_prefills])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[:self.num_prefills])
        context_lens_tensor = (None if self.context_lens_tensor is None else
                               self.context_lens_tensor[:self.num_prefills])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[:self.num_prefills])

        # Construct & cache prefill-phase attention metadata structure
        self._cached_prefill_metadata = XFormersMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["XFormersMetadata"]:
        """Metadata restricted to the decode portion of the batch.

        Returns None when there are no decode tokens. The first call builds
        the object by slicing the per-sequence / per-token fields past the
        prefill entries. Later calls return the cached object.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            # Recover cached decode-phase attention
            # metadata structure
            return self._cached_decode_metadata

        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[self.num_prefill_tokens:])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[self.num_prefills:])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[self.num_prefills:])
        # Keep the (num_decodes + 1) cumulative offsets that belong to the
        # decodes. Previously this was never forwarded, so the rebase below
        # always saw None and was dead code.
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[self.num_prefills:])

        # Construct & cache decode-phase attention metadata structure
        self._cached_decode_metadata = XFormersMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=query_start_loc,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)

        # Batch may be composed of prefill|decodes, adjust query start indices
        # to refer to the start of decodes when the two are split apart.
        # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
        if self._cached_decode_metadata.query_start_loc is not None:
            qs = self._cached_decode_metadata.query_start_loc
            self._cached_decode_metadata.query_start_loc = qs - qs[0]
        return self._cached_decode_metadata


293
294
def _get_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_type: str,
) -> Optional[AttentionBias]:
    """Fetch the attention bias slot of ``attn_metadata`` for ``attn_type``.

    Decoder self-attention and encoder-only attention share the
    ``attn_bias`` slot. Encoder attention uses ``encoder_attn_bias``, and
    encoder/decoder cross-attention uses ``cross_attn_bias``.

    Args:
        attn_metadata: metadata object holding the bias slots.
        attn_type: an ``AttentionType`` value.

    Returns:
        The value currently stored in the selected slot.

    Raises:
        AttributeError: if ``attn_type`` is not one of the handled types.
    """
    if attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY):
        return attn_metadata.attn_bias
    if attn_type == AttentionType.ENCODER:
        return attn_metadata.encoder_attn_bias
    if attn_type == AttentionType.ENCODER_DECODER:
        return attn_metadata.cross_attn_bias
    raise AttributeError(f"Invalid attention type {str(attn_type)}")
320
321
322
323
324


def _set_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_bias: List[Optional[AttentionBias]],
    attn_type: str,
) -> None:
    """Store ``attn_bias`` in the slot of ``attn_metadata`` for ``attn_type``.

    Slot selection mirrors the lookup side: ``attn_bias`` for decoder
    self-attention and encoder-only attention, ``encoder_attn_bias`` for
    encoder attention, and ``cross_attn_bias`` for encoder/decoder
    cross-attention.

    Args:
        attn_metadata: metadata object to update in place.
        attn_bias: bias value to store.
        attn_type: an ``AttentionType`` value.

    Raises:
        AttributeError: if ``attn_type`` is not one of the handled types
            (nothing is written in that case).
    """
    if attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY):
        slot_name = "attn_bias"
    elif attn_type == AttentionType.ENCODER:
        slot_name = "encoder_attn_bias"
    elif attn_type == AttentionType.ENCODER_DECODER:
        slot_name = "cross_attn_bias"
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")
    setattr(attn_metadata, slot_name, attn_bias)


350
351
352
353
354
class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):
    """Metadata builder for the XFormers backend.

    Adds no logic of its own. It inherits from ``CommonMetadataBuilder`` and
    only sets the metadata class to instantiate.
    """

    # Class the inherited builder is expected to instantiate.
    _metadata_cls = XFormersMetadata


355
class XFormersImpl(AttentionImpl[XFormersMetadata]):
356
357
    """
    If the input tensors contain prompt tokens, the layout is as follows:
358
359
    |<--------------- num_prefill_tokens ----------------->|	
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
360
361

    Otherwise, the layout is as follows:	
362
363
    |<----------------- num_decode_tokens ------------------>|	
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
364
365
366
367
368
369

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.
370
371
372
373
374
375
376
377
378

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
379
    """
Woosuk Kwon's avatar
Woosuk Kwon committed
380

Woosuk Kwon's avatar
Woosuk Kwon committed
381
382
383
384
385
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        """Configure the XFormers attention implementation.

        Raises:
            NotImplementedError: if KV sharing is requested.
            ValueError: if block-sparse params are given, or ``head_size``
                is not supported by PagedAttention.
            AssertionError: if ``num_heads`` is not a multiple of
                ``num_kv_heads``.
        """
        # Hard rejections come first.
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported in V0.")
        if blocksparse_params is not None:
            raise ValueError(
                "XFormers does not support block-sparse attention.")

        # Options this backend accepts but does not honor: warn once.
        if logits_soft_cap is not None:
            logger.warning_once("XFormers does not support logits soft cap. "
                                "Outputs may be slightly off.")
        if use_irope:
            logger.warning_once(
                "Using irope in XFormers is not supported yet, it will fall"
                " back to global attention for long context.")

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        # ALiBi slopes are kept as a float32 tensor (or None when unused).
        self.alibi_slopes = (None if alibi_slopes is None else torch.tensor(
            alibi_slopes, dtype=torch.float32))
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        # Each KV head serves a whole group of query heads (GQA/MQA).
        queries_per_kv, leftover_heads = divmod(self.num_heads,
                                                self.num_kv_heads)
        assert leftover_heads == 0
        self.num_queries_per_kv = queries_per_kv

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.attn_type = attn_type

Woosuk Kwon's avatar
Woosuk Kwon committed
429
430
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: Optional[torch.Tensor],
        value: Optional[torch.Tensor],
        kv_cache: torch.Tensor,
        attn_metadata: "XFormersMetadata",
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with xFormers and PagedAttention.

        For decoder-only models: query, key and value must be non-None.

        For encoder/decoder models:
        * XFormersImpl.forward() may be invoked for both self- and cross-
          attention layers.
        * For self-attention: query, key and value must be non-None.
        * For cross-attention:
            * Query must be non-None
            * During prefill, key and value must be non-None; key and value
              get cached for use during decode.
            * During decode, key and value may be None, since:
              (1) key and value tensors were cached during prefill, and
              (2) cross-attention key and value tensors do not grow during
                  decode

        The attention type comes from ``self.attn_type`` (set in the
        constructor). It affects behavior as follows:

            * DECODER: normal decoder-only behavior;
                use decoder self-attention block table
            * ENCODER: no KV caching; pass encoder sequence
                attributes (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len) to kernel, in lieu of decoder
                sequence attributes (seq_lens/seq_lens_tensor/max_seq_len).
                Used for encoder branch of encoder-decoder models.
            * ENCODER_ONLY: no kv_caching, uses the normal attention
                attributes (seq_lens/seq_lens_tensor/max_seq_len).
            * ENCODER_DECODER: cross-attention behavior;
                use cross-attention block table for caching KVs derived
                from encoder hidden states; since KV sequence lengths
                will match encoder sequence lengths, pass encoder sequence
                attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len)

        Args:
            layer: attention layer; its ``_k_scale`` / ``_v_scale`` are
                passed to the KV-cache write and the PagedAttention kernels.
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
            output: accepted for interface compatibility but not used; a
                fresh output tensor is always allocated and returned.
            output_scale: must be None (fused output quantization is not
                supported).

        Returns:
            shape = [num_tokens, num_heads * head_size]

        Raises:
            NotImplementedError: if ``output_scale`` is given.
            AttributeError: if the metadata required by the attention type
                is not set.
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for XFormersImpl")

        attn_type = self.attn_type
        # Check that appropriate attention metadata attributes are
        # selected for the desired attention type
        if (attn_type == AttentionType.ENCODER
                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
            raise AttributeError("Encoder attention requires setting "
                                 "encoder metadata attributes.")
        elif (attn_type == AttentionType.ENCODER_DECODER
              and (not attn_metadata.is_all_cross_attn_metadata_set)):
            raise AttributeError("Encoder/decoder cross-attention "
                                 "requires setting cross-attention "
                                 "metadata attributes.")

        query = query.view(-1, self.num_heads, self.head_size)
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None

        # Self-attention vs. cross-attention will impact
        # which KV cache memory-mapping & which
        # seqlen datastructures we utilize.
        #
        # key_cache / value_cache are bound ONLY when the KV cache is in use:
        # never for encoder attention, and not during the profiling run
        # (empty kv_cache). Every later use of them must be gated on this.
        uses_kv_cache = (attn_type != AttentionType.ENCODER
                         and kv_cache.numel() > 0)
        if uses_kv_cache:
            # Even if there are no new key/value pairs to cache,
            # we still need to break out key_cache and value_cache
            # i.e. for later use by paged attention
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            if (key is not None) and (value is not None):

                if attn_type == AttentionType.ENCODER_DECODER:
                    # Update cross-attention KV cache (prefill-only)
                    # During cross-attention decode, key & value will be None,
                    # preventing this IF-statement branch from running
                    updated_slot_mapping = attn_metadata.cross_slot_mapping
                else:
                    # Update self-attention KV cache (prefill/decode)
                    updated_slot_mapping = attn_metadata.slot_mapping

                # Reshape the input keys and values and store them in the
                # cache.
                PagedAttention.write_to_paged_cache(
                    key, value, key_cache, value_cache, updated_slot_mapping,
                    self.kv_cache_dtype, layer._k_scale, layer._v_scale)

        (num_prefill_query_tokens, num_prefill_kv_tokens,
         num_decode_query_tokens) = \
            get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)

        # NOTE: the caller-supplied `output` buffer is not used.
        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_query_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_query_tokens]
        if key is not None and value is not None:
            key = key[:num_prefill_kv_tokens]
            value = value[:num_prefill_kv_tokens]

        assert query.shape[0] == num_prefill_query_tokens
        assert decode_query.shape[0] == num_decode_query_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            prefill_block_tables = prefill_meta.block_tables
            # Use normal attention when there is no cached prefix to attend
            # over: the KV cache is unused (encoder attention / profiling)
            # or the block tables are absent or empty. This also guarantees
            # the prefix branch below only runs when key_cache/value_cache
            # have been bound.
            if (not uses_kv_cache or prefill_block_tables is None
                    or prefill_block_tables.numel() == 0):
                out = self._run_memory_efficient_xformers_forward(
                    query, key, value, prefill_meta, attn_type=attn_type)
                assert out.shape == output[:num_prefill_query_tokens].shape
                output[:num_prefill_query_tokens] = out
            else:
                assert attn_type != AttentionType.ENCODER_ONLY, (
                    "Encoder-only models should not have prefix attention.")

                assert prefill_meta.query_start_loc is not None
                assert prefill_meta.max_query_len is not None

                # prefix-enabled attention
                # TODO(Hai) this triton kernel has regression issue (broke) to
                # deal with different data types between KV and FP8 KV cache,
                # to be addressed separately.
                out = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    self.kv_cache_dtype,
                    key_cache,
                    value_cache,
                    prefill_block_tables,
                    prefill_meta.query_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window,
                    layer._k_scale,
                    layer._v_scale,
                )
                assert output[:num_prefill_query_tokens].shape == out.shape
                output[:num_prefill_query_tokens] = out

        if decode_meta := attn_metadata.decode_metadata:
            assert attn_type != AttentionType.ENCODER_ONLY, (
                "Encoder-only models should not have decode metadata.")

            (
                seq_lens_arg,
                max_seq_len_arg,
                block_tables_arg,
            ) = get_seq_len_block_table_args(decode_meta, False, attn_type)

            output[num_prefill_query_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                block_tables_arg,
                seq_lens_arg,
                max_seq_len_arg,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                layer._k_scale,
                layer._v_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

635
    def _run_memory_efficient_xformers_forward(
636
637
638
639
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
640
        attn_metadata: XFormersMetadata,
641
        attn_type: str = AttentionType.DECODER,
642
643
644
645
    ) -> torch.Tensor:
        """Attention for 1D query of multiple prompts. Multiple prompt
        tokens are flattened in to `query` input.

646
647
648
        See https://facebookresearch.github.io/xformers/components/ops.html
        for API spec.

649
        Args:
650
651
652
653
            output: shape = [num_prefill_tokens, num_heads, head_size]
            query: shape = [num_prefill_tokens, num_heads, head_size]
            key: shape = [num_prefill_tokens, num_kv_heads, head_size]
            value: shape = [num_prefill_tokens, num_kv_heads, head_size]
654
            attn_metadata: Metadata for attention.
655
656
657
658
            attn_type: Select attention type, between encoder attention,
                       decoder self-attention, or encoder/decoder cross-
                       attention. Defaults to decoder self-attention,
                       which is the vLLM default generally
659
        """
660

661
662
663
664
665
666
667
668
669
670
671
672
673
674
        original_query = query
        if self.num_kv_heads != self.num_heads:
            # GQA/MQA requires the shape [B, M, G, H, K].
            # Note that the output also has the same shape (which is different
            # from a spec from the doc).
            query = query.view(query.shape[0], self.num_kv_heads,
                               self.num_queries_per_kv, query.shape[-1])
            key = key[:, :,
                      None, :].expand(key.shape[0], self.num_kv_heads,
                                      self.num_queries_per_kv, key.shape[-1])
            value = value[:, :,
                          None, :].expand(value.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          value.shape[-1])
675

676
677
678
        # Set attention bias if not provided. This typically happens at
        # the very attention layer of every iteration.
        # FIXME(woosuk): This is a hack.
679
680
        attn_bias = _get_attn_bias(attn_metadata, attn_type)
        if attn_bias is None:
681
            if self.alibi_slopes is None:
682
683
684

                # Cross attention block of decoder branch of encoder-decoder
                # model uses seq_lens for dec / encoder_seq_lens for enc
685
686
687
688
                if (attn_type == AttentionType.ENCODER_DECODER):
                    assert attn_metadata.seq_lens is not None
                    assert attn_metadata.encoder_seq_lens is not None

689
                    # Cross-attention mask is non-causal
690
                    attn_bias = BlockDiagonalMask.from_seqlens(
691
692
693
                        attn_metadata.seq_lens,
                        attn_metadata.encoder_seq_lens,
                        device=query.device)
694
695
696

                # Encoder branch of encoder-decoder model uses
                # attn_metadata.encoder_seq_lens
697
                elif attn_type == AttentionType.ENCODER:
698

699
700
                    assert attn_metadata.encoder_seq_lens is not None

701
                    # Encoder self-attention mask is non-causal
702
                    attn_bias = BlockDiagonalMask.from_seqlens(
703
                        attn_metadata.encoder_seq_lens, device=query.device)
704
705
706
707

                # Self-attention block of encoder-only model just
                # uses the seq_lens directly.
                elif attn_type == AttentionType.ENCODER_ONLY:
708
709
                    assert attn_metadata.seq_lens is not None

710
711
                    # Encoder self-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
712
                        attn_metadata.seq_lens, device=query.device)
713
714
715
716
717
718
719

                # Self-attention block of decoder branch just
                # uses the seq_lens directly
                elif attn_type == AttentionType.DECODER:
                    assert attn_metadata.seq_lens is not None

                    # Decoder self-attention mask is causal
720
                    attn_bias = BlockDiagonalCausalMask.from_seqlens(
721
                        attn_metadata.seq_lens, device=query.device)
722
723
724
                else:
                    raise ValueError("Unknown AttentionType: %s", attn_type)

725
726
727
                if self.sliding_window is not None:
                    attn_bias = attn_bias.make_local_attention(
                        self.sliding_window)
728
                attn_bias = [attn_bias]
729
            else:
730
                assert attn_type == AttentionType.DECODER
731
732
733
734
735
736
                assert attn_metadata.seq_lens is not None
                attn_bias = _make_alibi_bias(self.alibi_slopes,
                                             self.num_kv_heads, query.dtype,
                                             attn_metadata.seq_lens)

            _set_attn_bias(attn_metadata, attn_bias, attn_type)
737
738
739
740
741

        # No alibi slopes.
        # TODO(woosuk): Too many view operations. Let's try to reduce
        # them in the future for code readability.
        if self.alibi_slopes is None:
742
            # Add the batch dimension.
743
744
745
746
747
748
749
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
            out = xops.memory_efficient_attention_forward(
                query,
                key,
                value,
750
                attn_bias=attn_bias[0],
751
                p=0.0,
752
                scale=self.scale)
753
            return out.view_as(original_query)
754
755
756
757
758

        # Attention with alibi slopes.
        # FIXME(woosuk): Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short prompts.
759
        assert attn_metadata.seq_lens is not None
760
        output = torch.empty_like(original_query)
761
        start = 0
762
763
        for i, seq_len in enumerate(attn_metadata.seq_lens):
            end = start + seq_len
764
765
766
767
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
768
                attn_bias=attn_bias[i],
769
                p=0.0,
770
                scale=self.scale)
771
            # TODO(woosuk): Unnecessary copy. Optimize.
772
            output[start:end].copy_(out.view_as(original_query[start:end]))
773
            start += seq_len
774
        return output
Woosuk Kwon's avatar
Woosuk Kwon committed
775
776
777
778


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[AttentionBias]:
    """Build one ALiBi attention bias per sequence for xFormers.

    For each sequence length ``L`` in ``seq_lens`` this builds a tensor of
    shape ``[1, num_heads, L, L]`` whose entry ``[0, h, i, j]`` equals
    ``(j - i) * alibi_slopes[h]``. It then wraps that tensor in a
    ``LowerTriangularMaskWithTensorBias``, which (per the xFormers docs)
    combines a causal mask with the additive tensor bias.

    Args:
        alibi_slopes: Per-head ALiBi slopes. ``num_heads`` is taken from
            ``alibi_slopes.shape[0]``, and every bias is allocated on
            ``alibi_slopes.device``.
        num_kv_heads: Number of KV heads. NOTE(review): not read by this
            implementation; kept so existing callers keep working.
        dtype: Dtype of the generated bias tensors.
        seq_lens: Length of each sequence in the batch.

    Returns:
        One attention bias per entry of ``seq_lens``, in the same order.
    """
    attn_biases: List[AttentionBias] = []
    device = alibi_slopes.device
    num_heads = alibi_slopes.shape[0]
    for seq_len in seq_lens:
        # Build the relative-position matrix directly on the target device.
        # This avoids materializing a seq_len x seq_len tensor on the host
        # and copying it to the device for every sequence.
        positions = torch.arange(seq_len, dtype=dtype, device=device)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        # rel_pos[i, j] = j - i (ith row, jth column).
        rel_pos = positions[None, :] - positions[:, None]

        # The backing storage's last dim is padded up to a multiple of 8
        # (presumably an xFormers alignment requirement for tensor biases —
        # TODO confirm). The bias itself is a view of the first seq_len
        # columns; the [seq_len, seq_len] matrix broadcasts across heads.
        padded_len = (seq_len + 7) // 8 * 8
        bias = torch.empty(
            1,  # batch size
            num_heads,
            seq_len,
            padded_len,
            device=device,
            dtype=dtype,
        )[:, :, :, :seq_len].copy_(rel_pos)
        # Scale each head's matrix by that head's slope.
        bias.mul_(alibi_slopes[:, None, None])
        attn_biases.append(LowerTriangularMaskWithTensorBias(bias))

    return attn_biases