xformers.py 32.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""Attention layer with xFormers and PagedAttention."""
3
from dataclasses import dataclass
4
from typing import Any, Dict, List, Optional, Tuple, Type
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6

import torch
7
from xformers import ops as xops
8
9
from xformers.ops.fmha.attn_bias import (AttentionBias,
                                         BlockDiagonalCausalMask,
10
                                         BlockDiagonalMask,
Woosuk Kwon's avatar
Woosuk Kwon committed
11
                                         LowerTriangularMaskWithTensorBias)
Woosuk Kwon's avatar
Woosuk Kwon committed
12

13
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
14
                                              AttentionLayer,
15
                                              AttentionMetadata, AttentionType)
16
17
18
19
from vllm.attention.backends.utils import (
    CommonAttentionState, CommonMetadataBuilder,
    get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
    is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set)
20
21
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
22
23
24
from vllm.logger import init_logger

logger = init_logger(__name__)
25

26
27
28

class XFormersBackend(AttentionBackend):
    """Attention backend pairing xFormers memory-efficient attention with
    vLLM's PagedAttention KV-cache layout and kernels."""

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "XFORMERS"

    @staticmethod
    def get_impl_cls() -> Type["XFormersImpl"]:
        """Return the attention implementation class for this backend."""
        return XFormersImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the per-batch attention metadata class."""
        return XFormersMetadata

    @staticmethod
    def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
        """Return the class that builds XFormersMetadata for a batch."""
        return XFormersMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention-state class (shared common implementation)."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape; delegated to PagedAttention,
        since this backend stores its cache in the PagedAttention layout."""
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap KV-cache blocks between two caches.

        ``src_to_dst`` is the block-index mapping tensor; it is forwarded
        unchanged to ``PagedAttention.swap_blocks`` (annotated as a tensor
        for consistency with ``copy_blocks`` below, which passes its mapping
        to the same PagedAttention helper family as a tensor).
        """
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy KV-cache blocks within each cache in ``kv_caches`` according
        to the ``src_to_dists`` mapping; delegated to PagedAttention."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for the XFormers backend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """

    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # FIXME: It is for flash attn.
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int

    # Whether or not if cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]] = None

    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor] = None

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int] = None

    # Max number of query tokens among request in the batch.
    max_decode_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None

    # Self-attention prefill/decode metadata cache
    _cached_prefill_metadata: Optional["XFormersMetadata"] = None
    _cached_decode_metadata: Optional["XFormersMetadata"] = None

    # Begin encoder attn & enc/dec cross-attn fields...

    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    encoder_seq_start_loc: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    def __post_init__(self):
        # Set during the execution of the first attention op.
        # It is a list because it is needed to set per prompt
        # when alibi slopes is used. It is because of the limitation
        # from xformer API.
        # will not appear in the __repr__ and __init__
        self.attn_bias: Optional[List[AttentionBias]] = None
        self.encoder_attn_bias: Optional[List[AttentionBias]] = None
        self.cross_attn_bias: Optional[List[AttentionBias]] = None

    @property
    def is_all_encoder_attn_metadata_set(self):
        '''
        All attention metadata required for encoder attention is set.
        '''
        return is_all_encoder_attn_metadata_set(self)

    @property
    def is_all_cross_attn_metadata_set(self):
        '''
        All attention metadata required for enc/dec cross-attention is set.

        Superset of encoder attention required metadata.
        '''
        return is_all_cross_attn_metadata_set(self)

    @property
    def prefill_metadata(self) -> Optional["XFormersMetadata"]:
        """Metadata restricted to the prefill part of the batch.

        Returns None when the batch has no prefills. The result is built once
        and cached in ``_cached_prefill_metadata``; per-sequence fields are
        sliced to the first ``num_prefills`` sequences and ``slot_mapping`` to
        the first ``num_prefill_tokens`` tokens.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            # Recover cached prefill-phase attention
            # metadata structure
            return self._cached_prefill_metadata

        assert ((self.seq_lens is not None)
                or (self.encoder_seq_lens is not None))
        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[:self.num_prefills + 1])
        seq_start_loc = (None if self.seq_start_loc is None else
                         self.seq_start_loc[:self.num_prefills + 1])
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[:self.num_prefill_tokens])
        seq_lens = (None if self.seq_lens is None else
                    self.seq_lens[:self.num_prefills])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[:self.num_prefills])
        context_lens_tensor = (None if self.context_lens_tensor is None else
                               self.context_lens_tensor[:self.num_prefills])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[:self.num_prefills])

        # Construct & cache prefill-phase attention metadata structure
        self._cached_prefill_metadata = XFormersMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["XFormersMetadata"]:
        """Metadata restricted to the decode part of the batch.

        Returns None when the batch has no decode tokens. The result is built
        once and cached in ``_cached_decode_metadata``; per-sequence fields
        are sliced past the first ``num_prefills`` sequences and
        ``slot_mapping`` past the first ``num_prefill_tokens`` tokens.
        ``query_start_loc`` (when present) is rebased so it starts at 0.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            # Recover cached decode-phase attention
            # metadata structure
            return self._cached_decode_metadata

        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[self.num_prefill_tokens:])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[self.num_prefills:])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[self.num_prefills:])
        # Decode sequences follow the prefills, so their query boundaries are
        # the tail of query_start_loc. Previously this field was never passed
        # on, which left the rebasing step below as dead code.
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[self.num_prefills:])

        # Construct & cache decode-phase attention metadata structure
        self._cached_decode_metadata = XFormersMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=query_start_loc,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)

        # Batch may be composed of prefill|decodes, adjust query start indices
        # to refer to the start of decodes when the two are split apart.
        # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
        qs = self._cached_decode_metadata.query_start_loc
        if qs is not None and qs.numel() > 0:
            self._cached_decode_metadata.query_start_loc = qs - qs[0]
        return self._cached_decode_metadata


292
293
def _get_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_type: str,
) -> Optional[List[AttentionBias]]:
    '''
    Extract appropriate attention bias from attention metadata
    according to attention type.

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:
    * The stored list of attention biases for the given attention type:
      ``attn_bias`` for DECODER / ENCODER_ONLY, ``encoder_attn_bias`` for
      ENCODER, ``cross_attn_bias`` for ENCODER_DECODER. None if it has not
      been set yet.

    Raises:
    * AttributeError: if attn_type is none of the types above.
    '''
    # Decoder self-attention and encoder-only self-attention share one slot.
    if (attn_type == AttentionType.DECODER
            or attn_type == AttentionType.ENCODER_ONLY):
        return attn_metadata.attn_bias
    if attn_type == AttentionType.ENCODER:
        return attn_metadata.encoder_attn_bias
    if attn_type == AttentionType.ENCODER_DECODER:
        return attn_metadata.cross_attn_bias
    raise AttributeError(f"Invalid attention type {str(attn_type)}")
319
320
321
322
323


def _set_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_bias: List[AttentionBias],
    attn_type: str,
) -> None:
    '''
    Update appropriate attention bias field of attention metadata,
    according to attention type.

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention
    * attn_bias: The desired attention bias value (a list of biases: a single
                 entry for mask-based biases, one per prompt for ALiBi)
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Raises:
    * AttributeError: if attn_type is not DECODER, ENCODER_ONLY, ENCODER
                      or ENCODER_DECODER.
    '''
    # Decoder self-attention and encoder-only self-attention share one slot.
    if (attn_type == AttentionType.DECODER
            or attn_type == AttentionType.ENCODER_ONLY):
        attn_metadata.attn_bias = attn_bias
    elif attn_type == AttentionType.ENCODER:
        attn_metadata.encoder_attn_bias = attn_bias
    elif attn_type == AttentionType.ENCODER_DECODER:
        attn_metadata.cross_attn_bias = attn_bias
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")


349
350
351
352
353
class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):
    """Metadata builder for the XFormers backend.

    Adds no logic of its own: it only sets ``_metadata_cls`` so that the
    shared CommonMetadataBuilder (presumably) constructs XFormersMetadata
    instances.
    """

    _metadata_cls = XFormersMetadata


354
class XFormersImpl(AttentionImpl[XFormersMetadata]):
    """Attention implementation backed by xFormers and PagedAttention.

    Prefill without a cached prefix runs xFormers memory-efficient attention;
    prefill with a cached prefix runs PagedAttention.forward_prefix; decode
    runs PagedAttention.forward_decode.

    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Validate the configuration and store it on the instance.

        Raises:
            ValueError: if block-sparse params are given, or if ``head_size``
                is not supported by PagedAttention.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "XFormers does not support block-sparse attention.")
        if logits_soft_cap is not None:
            logger.warning_once("XFormers does not support logits soft cap. "
                                "Outputs may be slightly off.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        # GQA/MQA: each KV head serves a fixed group of query heads.
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.attn_type = attn_type

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: Optional[torch.Tensor],
        value: Optional[torch.Tensor],
        kv_cache: torch.Tensor,
        attn_metadata: "XFormersMetadata",
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with xFormers and PagedAttention.

        For decoder-only models: query, key and value must be non-None.

        For encoder/decoder models:
        * XFormersImpl.forward() may be invoked for both self- and cross-
          attention layers.
        * For self-attention: query, key and value must be non-None.
        * For cross-attention:
            * Query must be non-None
            * During prefill, key and value must be non-None; key and value
              get cached for use during decode.
            * During decode, key and value may be None, since:
              (1) key and value tensors were cached during prefill, and
              (2) cross-attention key and value tensors do not grow during
                  decode

        The attention type is taken from ``self.attn_type`` (set in
        ``__init__``) and impacts forward() behavior as follows:

            * DECODER: normal decoder-only behavior;
                use decoder self-attention block table
            * ENCODER: no KV caching; pass encoder sequence
                attributes (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len) to kernel, in lieu of decoder
                sequence attributes (seq_lens/seq_lens_tensor/max_seq_len).
                Used for encoder branch of encoder-decoder models.
            * ENCODER_ONLY: no kv_caching, uses the normal attention
                attributes (seq_lens/seq_lens_tensor/max_seq_len).
            * ENCODER_DECODER: cross-attention behavior;
                use cross-attention block table for caching KVs derived
                from encoder hidden states; since KV sequence lengths
                will match encoder sequence lengths, pass encoder sequence
                attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len)

        Args:
            layer: Attention layer; its ``_k_scale`` / ``_v_scale`` are passed
                to the KV-cache write and the PagedAttention kernels.
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
            output: Not used; a fresh output buffer is always allocated.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            AttributeError: if the encoder / cross-attention metadata that
                ``self.attn_type`` requires is not set.
        """
        attn_type = self.attn_type
        # Check that appropriate attention metadata attributes are
        # selected for the desired attention type
        if (attn_type == AttentionType.ENCODER
                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
            raise AttributeError("Encoder attention requires setting "
                                 "encoder metadata attributes.")
        elif (attn_type == AttentionType.ENCODER_DECODER
              and (not attn_metadata.is_all_cross_attn_metadata_set)):
            raise AttributeError("Encoder/decoder cross-attention "
                                 "requires setting cross-attention "
                                 "metadata attributes.")

        query = query.view(-1, self.num_heads, self.head_size)
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None

        # Self-attention vs. cross-attention will impact
        # which KV cache memory-mapping & which
        # seqlen datastructures we utilize

        if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
            # KV-cache during decoder-self- or
            # encoder-decoder-cross-attention, but not
            # during encoder attention.
            #
            # Even if there are no new key/value pairs to cache,
            # we still need to break out key_cache and value_cache
            # i.e. for later use by paged attention
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            if (key is not None) and (value is not None):

                if attn_type == AttentionType.ENCODER_DECODER:
                    # Update cross-attention KV cache (prefill-only)
                    # During cross-attention decode, key & value will be None,
                    # preventing this IF-statement branch from running
                    updated_slot_mapping = attn_metadata.cross_slot_mapping
                else:
                    # Update self-attention KV cache (prefill/decode)
                    updated_slot_mapping = attn_metadata.slot_mapping

                # Reshape the input keys and values and store them in the cache.
                # If kv_cache is not provided, the new key and value tensors are
                # not cached. This happens during the initial memory
                # profiling run.
                PagedAttention.write_to_paged_cache(
                    key, value, key_cache, value_cache, updated_slot_mapping,
                    self.kv_cache_dtype, layer._k_scale, layer._v_scale)

        (num_prefill_query_tokens, num_prefill_kv_tokens,
        num_decode_query_tokens) = \
            get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_query_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_query_tokens]
        if key is not None and value is not None:
            key = key[:num_prefill_kv_tokens]
            value = value[:num_prefill_kv_tokens]

        assert query.shape[0] == num_prefill_query_tokens
        assert decode_query.shape[0] == num_decode_query_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            # The prefix kernel reads key_cache/value_cache, which are only
            # split out above for non-encoder attention with a non-empty
            # cache; every other case must use plain attention (previously an
            # ENCODER layer with non-empty block tables hit an unbound
            # key_cache here). Block tables are empty (or absent) if the
            # prompt does not have a cached prefix.
            use_prefix_attention = (
                attn_type != AttentionType.ENCODER
                and kv_cache.numel() > 0
                and prefill_meta.block_tables is not None
                and prefill_meta.block_tables.numel() > 0)
            if not use_prefix_attention:
                # normal attention.
                out = self._run_memory_efficient_xformers_forward(
                    query, key, value, prefill_meta, attn_type=attn_type)
                assert out.shape == output[:num_prefill_query_tokens].shape
                output[:num_prefill_query_tokens] = out
            else:
                assert attn_type != AttentionType.ENCODER_ONLY, (
                    "Encoder-only models should not have prefix attention.")

                assert prefill_meta.query_start_loc is not None
                assert prefill_meta.max_query_len is not None

                # prefix-enabled attention
                # TODO(Hai) this triton kernel has regression issue (broke) to
                # deal with different data types between KV and FP8 KV cache,
                # to be addressed separately.
                out = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    self.kv_cache_dtype,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.query_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window,
                    layer._k_scale,
                    layer._v_scale,
                )
                assert output[:num_prefill_query_tokens].shape == out.shape
                output[:num_prefill_query_tokens] = out

        if decode_meta := attn_metadata.decode_metadata:
            assert attn_type != AttentionType.ENCODER_ONLY, (
                "Encoder-only models should not have decode metadata.")

            (
                seq_lens_arg,
                max_seq_len_arg,
                block_tables_arg,
            ) = get_seq_len_block_table_args(decode_meta, False, attn_type)

            output[num_prefill_query_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                block_tables_arg,
                seq_lens_arg,
                max_seq_len_arg,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                layer._k_scale,
                layer._v_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

    def _run_memory_efficient_xformers_forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: XFormersMetadata,
        attn_type: str = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Attention for 1D query of multiple prompts. Multiple prompt
        tokens are flattened in to `query` input.

        See https://facebookresearch.github.io/xformers/components/ops.html
        for API spec.

        The attention bias is built lazily on first use and stored on
        ``attn_metadata`` (via ``_set_attn_bias``) for reuse by later layers.

        Args:
            query: shape = [num_prefill_tokens, num_heads, head_size]
            key: shape = [num_prefill_tokens, num_kv_heads, head_size]
            value: shape = [num_prefill_tokens, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
            attn_type: Select attention type, between encoder attention,
                       decoder self-attention, or encoder/decoder cross-
                       attention. Defaults to decoder self-attention,
                       which is the vLLM default generally
        Returns:
            Tensor with the same shape as ``query``.
        Raises:
            ValueError: if ``attn_type`` is not a known AttentionType.
        """

        original_query = query
        if self.num_kv_heads != self.num_heads:
            # GQA/MQA requires the shape [B, M, G, H, K].
            # Note that the output also has the same shape (which is different
            # from a spec from the doc).
            query = query.view(query.shape[0], self.num_kv_heads,
                               self.num_queries_per_kv, query.shape[-1])
            key = key[:, :,
                      None, :].expand(key.shape[0], self.num_kv_heads,
                                      self.num_queries_per_kv, key.shape[-1])
            value = value[:, :,
                          None, :].expand(value.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          value.shape[-1])

        # Set attention bias if not provided. This typically happens at
        # the very attention layer of every iteration.
        # FIXME(woosuk): This is a hack.
        attn_bias = _get_attn_bias(attn_metadata, attn_type)
        if attn_bias is None:
            if self.alibi_slopes is None:

                # Cross attention block of decoder branch of encoder-decoder
                # model uses seq_lens for dec / encoder_seq_lens for enc
                if (attn_type == AttentionType.ENCODER_DECODER):
                    assert attn_metadata.seq_lens is not None
                    assert attn_metadata.encoder_seq_lens is not None

                    # Cross-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)

                # Encoder branch of encoder-decoder model uses
                # attn_metadata.encoder_seq_lens
                elif attn_type == AttentionType.ENCODER:

                    assert attn_metadata.encoder_seq_lens is not None

                    # Encoder self-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.encoder_seq_lens)

                # Self-attention block of encoder-only model just
                # uses the seq_lens directly.
                elif attn_type == AttentionType.ENCODER_ONLY:
                    assert attn_metadata.seq_lens is not None

                    # Encoder self-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.seq_lens)

                # Self-attention block of decoder branch just
                # uses the seq_lens directly
                elif attn_type == AttentionType.DECODER:
                    assert attn_metadata.seq_lens is not None

                    # Decoder self-attention mask is causal
                    attn_bias = BlockDiagonalCausalMask.from_seqlens(
                        attn_metadata.seq_lens)
                else:
                    # Previously written with logging-style "%s" arguments,
                    # which ValueError does not interpolate.
                    raise ValueError(f"Unknown AttentionType: {attn_type}")

                if self.sliding_window is not None:
                    attn_bias = attn_bias.make_local_attention(
                        self.sliding_window)
                attn_bias = [attn_bias]
            else:
                assert attn_type == AttentionType.DECODER
                assert attn_metadata.seq_lens is not None
                attn_bias = _make_alibi_bias(self.alibi_slopes,
                                             self.num_kv_heads, query.dtype,
                                             attn_metadata.seq_lens)

            _set_attn_bias(attn_metadata, attn_bias, attn_type)

        # No alibi slopes.
        # TODO(woosuk): Too many view operations. Let's try to reduce
        # them in the future for code readability.
        if self.alibi_slopes is None:
            # Add the batch dimension.
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
            out = xops.memory_efficient_attention_forward(
                query,
                key,
                value,
                attn_bias=attn_bias[0],
                p=0.0,
                scale=self.scale)
            return out.view_as(original_query)

        # Attention with alibi slopes.
        # FIXME(woosuk): Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short prompts.
        assert attn_metadata.seq_lens is not None
        output = torch.empty_like(original_query)
        start = 0
        for i, seq_len in enumerate(attn_metadata.seq_lens):
            end = start + seq_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=attn_bias[i],
                p=0.0,
                scale=self.scale)
            # TODO(woosuk): Unnecessary copy. Optimize.
            output[start:end].copy_(out.view_as(original_query[start:end]))
            start += seq_len
        return output
Woosuk Kwon's avatar
Woosuk Kwon committed
758
759
760
761


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[AttentionBias]:
    """Build one ALiBi attention bias per prompt.

    Args:
        alibi_slopes: Per-head slopes; ``alibi_slopes.shape[0]`` is used as
            the number of query heads and the biases are allocated on
            ``alibi_slopes.device``.
        num_kv_heads: Number of KV heads. When it differs from the number of
            query heads, the head dimension of each bias is split into
            ``(num_kv_heads, num_heads // num_kv_heads)``.
        dtype: dtype of the produced bias tensors.
        seq_lens: Length of each prompt.

    Returns:
        One ``LowerTriangularMaskWithTensorBias`` per entry of ``seq_lens``,
        in the same order.
    """
    biases: List[AttentionBias] = []
    for length in seq_lens:
        positions = torch.arange(length, dtype=dtype)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        # relative[i, j] = positions[j] - positions[i].
        relative = positions[None, :] - positions[:, None]

        head_count = alibi_slopes.shape[0]
        # The last dimension is allocated rounded up to a multiple of 8 and
        # then sliced back to `length` (presumably an alignment requirement
        # of the xformers tensor-bias kernels — TODO confirm).
        aligned_len = (length + 7) // 8 * 8
        storage = torch.empty(
            1,  # batch size
            head_count,
            length,
            aligned_len,
            device=alibi_slopes.device,
            dtype=dtype,
        )
        head_bias = storage[:, :, :, :length].copy_(relative)

        # Scale each head's relative-position matrix by that head's slope.
        head_bias.mul_(alibi_slopes[:, None, None])
        if head_count != num_kv_heads:
            head_bias = head_bias.unflatten(
                1, (num_kv_heads, head_count // num_kv_heads))
        biases.append(LowerTriangularMaskWithTensorBias(head_bias))

    return biases