xformers.py 32.4 KB
Newer Older
1
"""Attention layer with xFormers and PagedAttention."""
2
from dataclasses import dataclass
3
from typing import Any, Dict, List, Optional, Tuple, Type
Woosuk Kwon's avatar
Woosuk Kwon committed
4
5

import torch
6
from xformers import ops as xops
7
8
from xformers.ops.fmha.attn_bias import (AttentionBias,
                                         BlockDiagonalCausalMask,
9
                                         BlockDiagonalMask,
Woosuk Kwon's avatar
Woosuk Kwon committed
10
                                         LowerTriangularMaskWithTensorBias)
Woosuk Kwon's avatar
Woosuk Kwon committed
11

12
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
13
                                              AttentionMetadata, AttentionType)
14
from vllm.attention.backends.utils import CommonMetadataBuilder
15
16
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
17
from vllm.logger import init_logger
Woosuk Kwon's avatar
Woosuk Kwon committed
18

19
logger = init_logger(__name__)
20

21
22
23

class XFormersBackend(AttentionBackend):
    """Attention backend built on xFormers (prefill) and PagedAttention.

    Every method is a static accessor: it either names the classes that
    make up the backend or forwards KV-cache management to PagedAttention.
    """

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        backend_name = "xformers"
        return backend_name

    @staticmethod
    def get_impl_cls() -> Type["XFormersImpl"]:
        """Return the class implementing the attention computation."""
        impl_cls = XFormersImpl
        return impl_cls

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the per-batch metadata class consumed by the impl."""
        metadata_cls = XFormersMetadata
        return metadata_cls

    @staticmethod
    def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
        """Return the builder class that constructs the metadata."""
        builder_cls = XFormersMetadataBuilder
        return builder_cls

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape.

        The cache layout is owned by PagedAttention, so this simply
        defers to it.
        """
        shape = PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                  num_kv_heads, head_size)
        return shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Delegate block swapping between two caches to PagedAttention."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Delegate in-cache block copies to PagedAttention."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for the XFormers backend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """

    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # FIXME: It is for flash attn.
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int

    # Whether or not if cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]] = None

    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor] = None

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None

    # Self-attention prefill/decode metadata cache
    _cached_prefill_metadata: Optional["XFormersMetadata"] = None
    _cached_decode_metadata: Optional["XFormersMetadata"] = None

    # Begin encoder attn & enc/dec cross-attn fields...

    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    def __post_init__(self):
        # Set during the execution of the first attention op.
        # It is a list because it is needed to set per prompt
        # when alibi slopes is used. It is because of the limitation
        # from xformer API.
        # will not appear in the __repr__ and __init__
        self.attn_bias: Optional[List[AttentionBias]] = None
        self.encoder_attn_bias: Optional[List[AttentionBias]] = None
        self.cross_attn_bias: Optional[List[AttentionBias]] = None

    @property
    def is_all_encoder_attn_metadata_set(self):
        '''
        All attention metadata required for encoder attention is set.
        '''
        return ((self.encoder_seq_lens is not None)
                and (self.encoder_seq_lens_tensor is not None)
                and (self.max_encoder_seq_len is not None))

    @property
    def is_all_cross_attn_metadata_set(self):
        '''
        All attention metadata required for enc/dec cross-attention is set.

        Superset of encoder attention required metadata.
        '''
        return (self.is_all_encoder_attn_metadata_set
                and (self.cross_slot_mapping is not None)
                and (self.cross_block_tables is not None))

    @property
    def prefill_metadata(self) -> Optional["XFormersMetadata"]:
        """Metadata restricted to the prefill part of the batch.

        Returns None when the batch holds no prefills. The first call
        slices the per-sequence/per-token fields down to the leading
        `num_prefills` sequences / `num_prefill_tokens` tokens and caches
        the result; later calls return the cached object. Encoder and
        cross-attention fields are carried over unchanged.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            # Recover cached prefill-phase attention
            # metadata structure
            return self._cached_prefill_metadata

        assert ((self.seq_lens is not None)
                or (self.encoder_seq_lens is not None))
        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None.
        # Prefill entries come first in the flattened batch, so they are
        # selected with leading slices.
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[:self.num_prefills + 1])
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[:self.num_prefill_tokens])
        seq_lens = (None if self.seq_lens is None else
                    self.seq_lens[:self.num_prefills])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[:self.num_prefills])
        context_lens_tensor = (None if self.context_lens_tensor is None else
                               self.context_lens_tensor[:self.num_prefills])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[:self.num_prefills])

        # Construct & cache prefill-phase attention metadata structure
        self._cached_prefill_metadata = XFormersMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            # Carried over like the other encoder fields so the
            # sub-metadata does not silently report None here.
            num_encoder_tokens=self.num_encoder_tokens,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["XFormersMetadata"]:
        """Metadata restricted to the decode part of the batch.

        Returns None when the batch holds no decode tokens. The first call
        slices the per-sequence/per-token fields to the entries after the
        prefill part and caches the result; later calls return the cached
        object. Encoder and cross-attention fields are carried over
        unchanged.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            # Recover cached decode-phase attention
            # metadata structure
            return self._cached_decode_metadata

        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None.
        # Decode entries follow the prefill entries in the flattened batch.
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[self.num_prefill_tokens:])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[self.num_prefills:])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[self.num_prefills:])

        # Construct & cache decode-phase attention metadata structure
        self._cached_decode_metadata = XFormersMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            # Carried over like the other encoder fields so the
            # sub-metadata does not silently report None here.
            num_encoder_tokens=self.num_encoder_tokens,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        return self._cached_decode_metadata


264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
def _get_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_type: AttentionType,
) -> Optional[List[AttentionBias]]:
    '''
    Extract appropriate attention bias from attention metadata
    according to attention type.

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:
    * The attention-bias list stored for the given attention type
      (None if it has not been set yet)

    Raises:
    * AttributeError: if attn_type is not a recognized attention type
      (same contract as _set_attn_bias)
    '''

    if attn_type == AttentionType.DECODER:
        return attn_metadata.attn_bias
    elif attn_type == AttentionType.ENCODER:
        return attn_metadata.encoder_attn_bias
    elif attn_type == AttentionType.ENCODER_DECODER:
        return attn_metadata.cross_attn_bias
    else:
        # Previously any unknown type fell through to the cross-attention
        # bias; reject it explicitly instead of returning the wrong field.
        raise AttributeError(f"Invalid attention type {str(attn_type)}")


def _set_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_bias: List[Optional[AttentionBias]],
    attn_type: AttentionType,
) -> None:
    '''
    Store an attention bias on the metadata field that belongs to the
    given attention type.

    Arguments:

    * attn_metadata: Attention metadata structure to update in place
    * attn_bias: The bias value to store
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Raises:

    * AttributeError: for an unrecognized attention type
    '''

    if attn_type == AttentionType.DECODER:
        attn_metadata.attn_bias = attn_bias
        return

    if attn_type == AttentionType.ENCODER:
        attn_metadata.encoder_attn_bias = attn_bias
        return

    if attn_type != AttentionType.ENCODER_DECODER:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")

    attn_metadata.cross_attn_bias = attn_bias


def _get_seq_len_block_table_args(
    attn_metadata: XFormersMetadata,
    is_prompt: bool,
    attn_type: AttentionType,
) -> tuple:
    '''
    Pick the (sequence-lengths tensor, max sequence length, block tables)
    triple that the paged-attention kernels should receive for the given
    attention type.

    * DECODER: decoder self-attention fields; the max length is taken from
      the prefill or the decode side depending on is_prompt
    * ENCODER_DECODER: encoder sequence lengths (cross-attention KVs have
      the encoder's length) plus the dedicated cross-attention block tables
    * ENCODER: encoder sequence lengths and no block tables

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention op
    * is_prompt: True if prefill, False otherwise (only consulted for DECODER)
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * (sequence-lengths tensor, max sequence-length scalar,
       block tables or None)

    Raises:

    * AttributeError: for an unrecognized attention type
    '''

    if attn_type == AttentionType.ENCODER:
        # Encoder attention has no associated block tables.
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len, None)

    if attn_type == AttentionType.ENCODER_DECODER:
        # KV lengths follow the encoder; KVs live in the "cross" tables.
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len,
                attn_metadata.cross_block_tables)

    if attn_type != AttentionType.DECODER:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")

    longest = (attn_metadata.max_prefill_seq_len
               if is_prompt else attn_metadata.max_decode_seq_len)
    return (attn_metadata.seq_lens_tensor, longest,
            attn_metadata.block_tables)


370
371
372
373
374
class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):
    """Metadata builder for the xFormers backend.

    All building logic is inherited from CommonMetadataBuilder; this
    subclass only sets `_metadata_cls` to XFormersMetadata (presumably the
    type the common builder instantiates -- confirm in
    vllm.attention.backends.utils).
    """

    _metadata_cls = XFormersMetadata


375
class XFormersImpl(AttentionImpl[XFormersMetadata]):
    """Attention implementation: xFormers for prefill, PagedAttention for
    KV-cache writes, prefix-cached prefill and decode.

    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        """Validate the configuration and store it.

        Raises:
            ValueError: if block-sparse attention or logits soft capping
                is requested, or if head_size is not supported by
                PagedAttention.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "XFormers does not support block-sparse attention.")
        if logits_soft_cap is not None:
            raise ValueError(
                "XFormers does not support attention logits soft capping.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        # GQA/MQA: every KV head serves the same number of query heads.
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor],
        value: Optional[torch.Tensor],
        kv_cache: Optional[torch.Tensor],
        attn_metadata: "XFormersMetadata",
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with xFormers and PagedAttention.

        For decoder-only models: query, key and value must be non-None.

        For encoder/decoder models:
        * XFormersImpl.forward() may be invoked for both self- and cross-
          attention layers.
        * For self-attention: query, key and value must be non-None.
        * For cross-attention:
            * Query must be non-None
            * During prefill, key and value must be non-None; key and value
              get cached for use during decode.
            * During decode, key and value may be None, since:
              (1) key and value tensors were cached during prefill, and
              (2) cross-attention key and value tensors do not grow during
                  decode

        A note on how the attn_type (attention type enum) argument impacts
        attention forward() behavior:

            * DECODER: normal decoder-only behavior;
                use decoder self-attention block table
            * ENCODER: no KV caching; pass encoder sequence
                attributes (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len) to kernel, in lieu of decoder
                sequence attributes (seq_lens/seq_lens_tensor/max_seq_len).
                Prefill always uses plain memory-efficient attention.
            * ENCODER_DECODER: cross-attention behavior;
                use cross-attention block table for caching KVs derived
                from encoder hidden states; since KV sequence lengths
                will match encoder sequence lengths, pass encoder sequence
                attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len)

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
            k_scale: Key scale forwarded to the PagedAttention cache-write,
                prefix and decode kernels (presumably for FP8 KV caches --
                confirm).
            v_scale: Value scale, forwarded the same way as k_scale.
            attn_type: Select attention type, between encoder attention,
                       decoder self-attention, or encoder/decoder cross-
                       attention. Defaults to decoder self-attention,
                       which is the vLLM default generally
        Returns:
            shape = [num_tokens, num_heads * head_size]

        Raises:
            AttributeError: if the metadata lacks the encoder or
                cross-attention fields required by attn_type.
        """

        # Check that appropriate attention metadata attributes are
        # selected for the desired attention type
        if (attn_type == AttentionType.ENCODER
                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
            raise AttributeError("Encoder attention requires setting "
                                 "encoder metadata attributes.")
        elif (attn_type == AttentionType.ENCODER_DECODER
              and (not attn_metadata.is_all_cross_attn_metadata_set)):
            raise AttributeError("Encoder/decoder cross-attention "
                                 "requires setting cross-attention "
                                 "metadata attributes.")

        query = query.view(-1, self.num_heads, self.head_size)
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None

        # Self-attention vs. cross-attention will impact
        # which KV cache memory-mapping & which
        # seqlen datastructures we utilize.
        #
        # If kv_cache is not provided, the new key and value tensors are
        # not cached. This happens during the initial memory profiling run.
        if (attn_type != AttentionType.ENCODER and kv_cache is not None):
            # KV-cache during decoder-self- or
            # encoder-decoder-cross-attention, but not
            # during encoder attention.
            #
            # Even if there are no new key/value pairs to cache,
            # we still need to break out key_cache and value_cache
            # i.e. for later use by paged attention
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            if (key is not None) and (value is not None):

                if attn_type == AttentionType.ENCODER_DECODER:
                    # Update cross-attention KV cache (prefill-only)
                    # During cross-attention decode, key & value will be None,
                    # preventing this IF-statement branch from running
                    updated_slot_mapping = attn_metadata.cross_slot_mapping
                else:
                    # Update self-attention KV cache (prefill/decode)
                    updated_slot_mapping = attn_metadata.slot_mapping

                # Reshape the input keys and values and store them in the
                # cache.
                PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                    value_cache,
                                                    updated_slot_mapping,
                                                    self.kv_cache_dtype,
                                                    k_scale, v_scale)

        if attn_type != AttentionType.ENCODER:
            # Decoder self-attention supports chunked prefill.
            # Encoder/decoder cross-attention requires no chunked
            # prefill (100% prefill or 100% decode tokens, no mix)
            num_prefill_tokens = attn_metadata.num_prefill_tokens
            num_decode_tokens = attn_metadata.num_decode_tokens
        else:
            # Encoder attention - chunked prefill is not applicable;
            # take the token count from num_encoder_tokens and treat
            # them as 100% prefill tokens
            assert attn_metadata.num_encoder_tokens is not None
            num_prefill_tokens = attn_metadata.num_encoder_tokens
            num_decode_tokens = 0

        if attn_type == AttentionType.DECODER:
            # Only enforce this shape-constraint for decoder
            # self-attention
            assert key.shape[0] == num_prefill_tokens + num_decode_tokens
            assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        if key is not None and value is not None:
            key = key[:num_prefill_tokens]
            value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            # Encoder attention never splits out key_cache/value_cache
            # above, so it must never take the prefix-enabled branch
            # (which reads them); checking it first also avoids calling
            # .numel() on block tables encoder metadata may not carry.
            if (attn_type == AttentionType.ENCODER or kv_cache is None
                    or prefill_meta.block_tables.numel() == 0):
                # normal attention.
                # block tables are empty if the prompt does not have a cached
                # prefix.
                out = self._run_memory_efficient_xformers_forward(
                    query, key, value, prefill_meta, attn_type=attn_type)
                assert out.shape == output[:num_prefill_tokens].shape
                output[:num_prefill_tokens] = out
            else:

                assert prefill_meta.query_start_loc is not None
                assert prefill_meta.max_query_len is not None

                # prefix-enabled attention
                # TODO(Hai) this triton kernel has regression issue (broke) to
                # deal with different data types between KV and FP8 KV cache,
                # to be addressed separately.
                out = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    self.kv_cache_dtype,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.query_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.context_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window,
                    k_scale,
                    v_scale,
                )
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out

        if decode_meta := attn_metadata.decode_metadata:

            (
                seq_lens_arg,
                max_seq_len_arg,
                block_tables_arg,
            ) = _get_seq_len_block_table_args(decode_meta, False, attn_type)

            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                block_tables_arg,
                seq_lens_arg,
                max_seq_len_arg,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                k_scale,
                v_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

    def _run_memory_efficient_xformers_forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: XFormersMetadata,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Attention for 1D query of multiple prompts. Multiple prompt
        tokens are flattened in to `query` input.

        See https://facebookresearch.github.io/xformers/components/ops.html
        for API spec.

        Args:
            query: shape = [num_prefill_tokens, num_heads, head_size]
            key: shape = [num_prefill_tokens, num_kv_heads, head_size]
            value: shape = [num_prefill_tokens, num_kv_heads, head_size]
            attn_metadata: Metadata for attention. Its attention-bias
                field for attn_type is filled in on first use and reused
                afterwards.
            attn_type: Select attention type, between encoder attention,
                       decoder self-attention, or encoder/decoder cross-
                       attention. Defaults to decoder self-attention,
                       which is the vLLM default generally

        Returns:
            A tensor shaped like `query`.
        """

        original_query = query
        if self.num_kv_heads != self.num_heads:
            # GQA/MQA requires the shape [B, M, G, H, K].
            # Note that the output also has the same shape (which is different
            # from a spec from the doc).
            query = query.view(query.shape[0], self.num_kv_heads,
                               self.num_queries_per_kv, query.shape[-1])
            key = key[:, :,
                      None, :].expand(key.shape[0], self.num_kv_heads,
                                      self.num_queries_per_kv, key.shape[-1])
            value = value[:, :,
                          None, :].expand(value.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          value.shape[-1])

        # Set attention bias if not provided. This typically happens at
        # the very attention layer of every iteration.
        # FIXME(woosuk): This is a hack.
        attn_bias = _get_attn_bias(attn_metadata, attn_type)
        if attn_bias is None:
            if self.alibi_slopes is None:
                if (attn_type == AttentionType.ENCODER_DECODER):
                    assert attn_metadata.seq_lens is not None
                    assert attn_metadata.encoder_seq_lens is not None

                    # Default enc/dec cross-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)
                elif attn_type == AttentionType.ENCODER:
                    assert attn_metadata.encoder_seq_lens is not None

                    # Default encoder self-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.encoder_seq_lens)
                else:
                    assert attn_metadata.seq_lens is not None

                    # Default decoder self-attention mask is causal
                    attn_bias = BlockDiagonalCausalMask.from_seqlens(
                        attn_metadata.seq_lens)
                if self.sliding_window is not None:
                    attn_bias = attn_bias.make_local_attention(
                        self.sliding_window)
                attn_bias = [attn_bias]
            else:
                assert attn_metadata.seq_lens is not None
                attn_bias = _make_alibi_bias(self.alibi_slopes,
                                             self.num_kv_heads, query.dtype,
                                             attn_metadata.seq_lens)

            _set_attn_bias(attn_metadata, attn_bias, attn_type)

        # No alibi slopes.
        # TODO(woosuk): Too many view operations. Let's try to reduce
        # them in the future for code readability.
        if self.alibi_slopes is None:
            # Add the batch dimension.
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
            out = xops.memory_efficient_attention_forward(
                query,
                key,
                value,
                attn_bias=attn_bias[0],
                p=0.0,
                scale=self.scale)
            return out.view_as(original_query)

        # Attention with alibi slopes.
        # FIXME(woosuk): Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short prompts.
        assert attn_metadata.seq_lens is not None
        output = torch.empty_like(original_query)
        start = 0
        for i, seq_len in enumerate(attn_metadata.seq_lens):
            end = start + seq_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=attn_bias[i],
                p=0.0,
                scale=self.scale)
            # TODO(woosuk): Unnecessary copy. Optimize.
            output[start:end].copy_(out.view_as(original_query[start:end]))
            start += seq_len
        return output
Woosuk Kwon's avatar
Woosuk Kwon committed
763
764
765
766


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[AttentionBias]:
    """Build one ALiBi attention bias per sequence.

    Args:
        alibi_slopes: One slope per attention head; its first dimension
            gives the number of query heads and its device is used for
            the bias tensors.
        num_kv_heads: Number of KV heads. When it differs from the number
            of query heads, the head dimension of each bias is unflattened
            into (num_kv_heads, heads_per_kv) to match the GQA/MQA layout.
        dtype: dtype of the produced bias tensors.
        seq_lens: Length of every sequence; one bias is produced per entry.

    Returns:
        A list with one LowerTriangularMaskWithTensorBias per sequence.
    """
    attn_biases: List[AttentionBias] = []
    num_heads = alibi_slopes.shape[0]
    for seq_len in seq_lens:
        # Build positions in float32 rather than `dtype`: low-precision
        # dtypes cannot hold every integer position exactly (bfloat16 only
        # up to 256, float16 up to 2048), which would corrupt the relative
        # distances near the diagonal for long sequences. Creating them on
        # the slopes' device also avoids a host-to-device copy.
        positions = torch.arange(seq_len,
                                 dtype=torch.float32,
                                 device=alibi_slopes.device)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        # Calculate a matrix where each element represents ith element- jth
        # element.
        rel_pos = positions[None, :] - positions[:, None]

        # Allocate the last dimension rounded up to a multiple of 8 and
        # slice back to seq_len. NOTE(review): presumably an alignment
        # requirement of xformers tensor biases -- confirm.
        padded_len = (seq_len + 7) // 8 * 8
        bias = torch.empty(
            1,  # batch size
            num_heads,
            seq_len,
            padded_len,
            device=alibi_slopes.device,
            dtype=dtype,
        )[:, :, :, :seq_len].copy_(rel_pos)
        bias.mul_(alibi_slopes[:, None, None])
        if num_heads != num_kv_heads:
            bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
        attn_biases.append(LowerTriangularMaskWithTensorBias(bias))

    return attn_biases