xformers.py 32.6 KB
Newer Older
1
"""Attention layer with xFormers and PagedAttention."""
2
from dataclasses import dataclass
3
from typing import Any, Dict, List, Optional, Tuple, Type
Woosuk Kwon's avatar
Woosuk Kwon committed
4
5

import torch
6
from xformers import ops as xops
7
8
from xformers.ops.fmha.attn_bias import (AttentionBias,
                                         BlockDiagonalCausalMask,
9
                                         BlockDiagonalMask,
Woosuk Kwon's avatar
Woosuk Kwon committed
10
                                         LowerTriangularMaskWithTensorBias)
Woosuk Kwon's avatar
Woosuk Kwon committed
11

12
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
13
                                              AttentionMetadata, AttentionType)
14
15
from vllm.attention.backends.utils import (CommonAttentionState,
                                           CommonMetadataBuilder)
16
17
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
18
from vllm.logger import init_logger
Woosuk Kwon's avatar
Woosuk Kwon committed
19

20
logger = init_logger(__name__)
21

22
23
24

class XFormersBackend(AttentionBackend):
    """Attention backend combining xFormers (prefill) with PagedAttention.

    Every KV-cache related operation is delegated to `PagedAttention`, so
    this class only wires together the impl, metadata, builder and state
    classes used by this backend.
    """

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "xformers"

    @staticmethod
    def get_impl_cls() -> Type["XFormersImpl"]:
        """Return the attention implementation class."""
        return XFormersImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the per-batch attention metadata class."""
        return XFormersMetadata

    @staticmethod
    def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
        """Return the builder class that produces `XFormersMetadata`."""
        return XFormersMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class (shared common implementation)."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape.

        The cache layout belongs to PagedAttention (which also reads and
        writes it), so the shape is taken from there verbatim.
        """
        shape = PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                  num_kv_heads, head_size)
        return shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Swap cache blocks between two KV caches via PagedAttention.

        NOTE(review): `src_to_dst` is annotated as Dict[int, int], while
        `copy_blocks` takes a tensor mapping; confirm which type
        `PagedAttention.swap_blocks` actually expects.
        """
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy cache blocks within the given KV caches via PagedAttention."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for the XFormers backend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """

    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # FIXME: It is for flash attn.
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]] = None

    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor] = None

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None

    # Self-attention prefill/decode metadata cache
    _cached_prefill_metadata: Optional["XFormersMetadata"] = None
    _cached_decode_metadata: Optional["XFormersMetadata"] = None

    # Begin encoder attn & enc/dec cross-attn fields...

    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    def __post_init__(self):
        # Set during the execution of the first attention op.
        # It is a list because it is needed to set per prompt
        # when alibi slopes is used. It is because of the limitation
        # from xformer API.
        # These are plain attributes rather than dataclass fields, so they
        # will not appear in the __repr__ and __init__.
        self.attn_bias: Optional[List[AttentionBias]] = None
        self.encoder_attn_bias: Optional[List[AttentionBias]] = None
        self.cross_attn_bias: Optional[List[AttentionBias]] = None

    @property
    def is_all_encoder_attn_metadata_set(self):
        '''
        All attention metadata required for encoder attention is set.
        '''
        return ((self.encoder_seq_lens is not None)
                and (self.encoder_seq_lens_tensor is not None)
                and (self.max_encoder_seq_len is not None))

    @property
    def is_all_cross_attn_metadata_set(self):
        '''
        All attention metadata required for enc/dec cross-attention is set.

        Superset of encoder attention required metadata.
        '''
        return (self.is_all_encoder_attn_metadata_set
                and (self.cross_slot_mapping is not None)
                and (self.cross_block_tables is not None))

    @property
    def prefill_metadata(self) -> Optional["XFormersMetadata"]:
        """Metadata restricted to the prefill part of the batch.

        Prefill sequences/tokens come first in the batch, so per-sequence
        fields are sliced to the first `num_prefills` entries and per-token
        fields to the first `num_prefill_tokens` entries. Encoder and
        cross-attention fields are passed through unchanged. The result is
        built once and cached on this object.

        Returns None when the batch contains no prefills.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            # Recover cached prefill-phase attention
            # metadata structure
            return self._cached_prefill_metadata

        assert ((self.seq_lens is not None)
                or (self.encoder_seq_lens is not None))
        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[:self.num_prefills + 1])
        seq_start_loc = (None if self.seq_start_loc is None else
                         self.seq_start_loc[:self.num_prefills + 1])
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[:self.num_prefill_tokens])
        seq_lens = (None if self.seq_lens is None else
                    self.seq_lens[:self.num_prefills])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[:self.num_prefills])
        context_lens_tensor = (None if self.context_lens_tensor is None else
                               self.context_lens_tensor[:self.num_prefills])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[:self.num_prefills])

        # Construct & cache prefill-phase attention metadata structure
        self._cached_prefill_metadata = XFormersMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            seq_start_loc=seq_start_loc,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            num_encoder_tokens=self.num_encoder_tokens,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["XFormersMetadata"]:
        """Metadata restricted to the decode part of the batch.

        Decode sequences/tokens follow the prefills, so per-sequence fields
        are sliced from `num_prefills` onward and the slot mapping from
        `num_prefill_tokens` onward. Encoder and cross-attention fields are
        passed through unchanged. The result is built once and cached.

        Returns None when the batch contains no decode tokens.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            # Recover cached decode-phase attention
            # metadata structure
            return self._cached_decode_metadata

        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[self.num_prefill_tokens:])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[self.num_prefills:])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[self.num_prefills:])

        # Construct & cache decode-phase attention metadata structure
        self._cached_decode_metadata = XFormersMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            num_encoder_tokens=self.num_encoder_tokens,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        return self._cached_decode_metadata


269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
def _get_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_type: AttentionType,
) -> Optional[List[AttentionBias]]:
    '''
    Extract appropriate attention bias from attention metadata
    according to attention type.

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:
    * Appropriate attention bias value given the attention type: the list
      stored in attn_bias / encoder_attn_bias / cross_attn_bias, or None
      if it has not been set yet

    Raises:
    * AttributeError: if attn_type is not one of DECODER, ENCODER or
      ENCODER_DECODER (the same contract as _set_attn_bias)
    '''

    if attn_type == AttentionType.DECODER:
        return attn_metadata.attn_bias
    elif attn_type == AttentionType.ENCODER:
        return attn_metadata.encoder_attn_bias
    elif attn_type == AttentionType.ENCODER_DECODER:
        return attn_metadata.cross_attn_bias
    else:
        # Previously any unknown type silently fell through to the
        # cross-attention bias; fail loudly instead.
        raise AttributeError(f"Invalid attention type {str(attn_type)}")


def _set_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_bias: List[AttentionBias],
    attn_type: AttentionType,
) -> None:
    '''
    Update appropriate attention bias field of attention metadata,
    according to attention type.

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention
    * attn_bias: The desired attention bias value; a list of biases, which
                 is what the attn_bias / encoder_attn_bias /
                 cross_attn_bias fields hold
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Raises:
    * AttributeError: if attn_type is not a recognized attention type
    '''

    if attn_type == AttentionType.DECODER:
        attn_metadata.attn_bias = attn_bias
    elif attn_type == AttentionType.ENCODER:
        attn_metadata.encoder_attn_bias = attn_bias
    elif attn_type == AttentionType.ENCODER_DECODER:
        attn_metadata.cross_attn_bias = attn_bias
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")


def _get_seq_len_block_table_args(
    attn_metadata: XFormersMetadata,
    is_prompt: bool,
    attn_type: AttentionType,
) -> tuple:
    '''
    Pick the sequence-length and block-table attributes that the given
    attention type should hand to the paged-attention kernels.

    * Decoder self-attn: decoder seq lens, prefill/decode max seq len
      (chosen by is_prompt) and the decoder block tables
    * Enc/dec cross-attn: encoder seq lens and the cross-attn block tables
    * Encoder attn: encoder seq lens and no block tables at all

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention op
    * is_prompt: True if prefill, False otherwise (only consulted for
                 decoder self-attention)
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * 3-tuple of (sequence-lengths tensor, max sequence length,
      block tables or None)

    Raises:

    * AttributeError: for an unrecognized attention type
    '''

    if attn_type == AttentionType.ENCODER:
        # Encoder attention never goes through the KV cache, so there is
        # no block table to return.
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len, None)

    if attn_type == AttentionType.ENCODER_DECODER:
        # Cross-attention KVs span the encoder sequence, and they live in
        # the dedicated "cross" block tables.
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len,
                attn_metadata.cross_block_tables)

    if attn_type != AttentionType.DECODER:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")

    # Decoder self-attention: the max length depends on the phase.
    phase_max_seq_len = (attn_metadata.max_prefill_seq_len
                         if is_prompt else attn_metadata.max_decode_seq_len)
    return (attn_metadata.seq_lens_tensor, phase_max_seq_len,
            attn_metadata.block_tables)


375
376
377
378
379
class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):
    """Metadata builder for the xFormers backend.

    All behavior is inherited from `CommonMetadataBuilder`; this subclass
    only pins the metadata class.
    """

    # NOTE(review): presumably read by CommonMetadataBuilder to decide
    # which metadata class to instantiate -- confirm in the base class.
    _metadata_cls = XFormersMetadata


380
class XFormersImpl(AttentionImpl[XFormersMetadata]):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        """Validate the configuration and store the attention parameters.

        Raises:
            ValueError: if block-sparse attention or logits soft capping is
                requested, or if head_size is not supported by
                PagedAttention.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "XFormers does not support block-sparse attention.")
        if logits_soft_cap is not None:
            raise ValueError(
                "XFormers does not support attention logits soft capping.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        # GQA/MQA: each KV head is shared by num_queries_per_kv query heads.
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor],
        value: Optional[torch.Tensor],
        kv_cache: Optional[torch.Tensor],
        attn_metadata: "XFormersMetadata",
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with xFormers and PagedAttention.

        For decoder-only models: query, key and value must be non-None.

        For encoder/decoder models:
        * XFormersImpl.forward() may be invoked for both self- and cross-
          attention layers.
        * For self-attention: query, key and value must be non-None.
        * For cross-attention:
            * Query must be non-None
            * During prefill, key and value must be non-None; key and value
              get cached for use during decode.
            * During decode, key and value may be None, since:
              (1) key and value tensors were cached during prefill, and
              (2) cross-attention key and value tensors do not grow during
                  decode

        A note on how the attn_type (attention type enum) argument impacts
        attention forward() behavior:

            * DECODER: normal decoder-only behavior;
                use decoder self-attention block table
            * ENCODER: no KV caching; pass encoder sequence
                attributes (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len) to kernel, in lieu of decoder
                sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
            * ENCODER_DECODER: cross-attention behavior;
                use cross-attention block table for caching KVs derived
                from encoder hidden states; since KV sequence lengths
                will match encoder sequence lengths, pass encoder sequence
                attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len)

        Only DECODER attention uses the prefix-enabled prefill kernel;
        ENCODER and ENCODER_DECODER prefill always run memory-efficient
        attention over the key/value tensors passed in.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
            k_scale, v_scale: Scaling factors forwarded unchanged to the
                PagedAttention cache-write, prefix and decode kernels.
            attn_type: Select attention type, between encoder attention,
                       decoder self-attention, or encoder/decoder cross-
                       attention. Defaults to decoder self-attention,
                       which is the vLLM default generally
        Returns:
            shape = [num_tokens, num_heads * head_size]

        Raises:
            AttributeError: if the metadata required by attn_type is unset.
            RuntimeError: if decode tokens are present but no KV cache was
                provided.
        """
        # Check that appropriate attention metadata attributes are
        # selected for the desired attention type
        if (attn_type == AttentionType.ENCODER
                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
            raise AttributeError("Encoder attention requires setting "
                                 "encoder metadata attributes.")
        elif (attn_type == AttentionType.ENCODER_DECODER
              and (not attn_metadata.is_all_cross_attn_metadata_set)):
            raise AttributeError("Encoder/decoder cross-attention "
                                 "requires setting cross-attention "
                                 "metadata attributes.")

        query = query.view(-1, self.num_heads, self.head_size)
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None

        # Self-attention vs. cross-attention will impact
        # which KV cache memory-mapping & which
        # seqlen datastructures we utilize.
        #
        # key_cache / value_cache stay None whenever no paged cache is in
        # play (encoder attention, or kv_cache is None during the memory
        # profiling run), so later code can test for that explicitly.
        key_cache: Optional[torch.Tensor] = None
        value_cache: Optional[torch.Tensor] = None
        if (attn_type != AttentionType.ENCODER and kv_cache is not None):
            # KV-cache during decoder-self- or
            # encoder-decoder-cross-attention, but not
            # during encoder attention.
            #
            # Even if there are no new key/value pairs to cache,
            # we still need to break out key_cache and value_cache
            # i.e. for later use by paged attention
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            if (key is not None) and (value is not None):

                if attn_type == AttentionType.ENCODER_DECODER:
                    # Update cross-attention KV cache (prefill-only)
                    # During cross-attention decode, key & value will be None,
                    # preventing this IF-statement branch from running
                    updated_slot_mapping = attn_metadata.cross_slot_mapping
                else:
                    # Update self-attention KV cache (prefill/decode)
                    updated_slot_mapping = attn_metadata.slot_mapping

                # Reshape the input keys and values and store them in the cache.
                # If kv_cache is not provided, the new key and value tensors are
                # not cached. This happens during the initial memory
                # profiling run.
                PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                    value_cache,
                                                    updated_slot_mapping,
                                                    self.kv_cache_dtype,
                                                    k_scale, v_scale)

        if attn_type != AttentionType.ENCODER:
            # Decoder self-attention supports chunked prefill.
            # Encoder/decoder cross-attention requires no chunked
            # prefill (100% prefill or 100% decode tokens, no mix)
            num_prefill_tokens = attn_metadata.num_prefill_tokens
            num_decode_tokens = attn_metadata.num_decode_tokens
        else:
            # Encoder attention - chunked prefill is not applicable;
            # take the token count from num_encoder_tokens and treat them
            # as 100% prefill tokens
            assert attn_metadata.num_encoder_tokens is not None
            num_prefill_tokens = attn_metadata.num_encoder_tokens
            num_decode_tokens = 0

        if attn_type == AttentionType.DECODER:
            # Only enforce this shape-constraint for decoder
            # self-attention
            assert key.shape[0] == num_prefill_tokens + num_decode_tokens
            assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        if key is not None and value is not None:
            key = key[:num_prefill_tokens]
            value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            prefill_block_tables = prefill_meta.block_tables
            # The prefix-enabled kernel reads earlier KVs through the decoder
            # self-attention block tables, so it only applies to decoder
            # self-attention with a live KV cache and a non-empty block table
            # (block tables are empty if the prompt does not have a cached
            # prefix). Encoder attention never splits the KV cache (key_cache
            # is None), and cross-attention prefill receives its key/value
            # tensors directly, so both always use normal attention.
            use_prefix_attention = (attn_type == AttentionType.DECODER
                                    and key_cache is not None
                                    and prefill_block_tables is not None
                                    and prefill_block_tables.numel() > 0)
            if not use_prefix_attention:
                # normal attention.
                out = self._run_memory_efficient_xformers_forward(
                    query, key, value, prefill_meta, attn_type=attn_type)
                assert out.shape == output[:num_prefill_tokens].shape
                output[:num_prefill_tokens] = out
            else:
                assert prefill_meta.query_start_loc is not None
                assert prefill_meta.max_query_len is not None

                # prefix-enabled attention
                # TODO(Hai) this triton kernel has regression issue (broke) to
                # deal with different data types between KV and FP8 KV cache,
                # to be addressed separately.
                out = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    self.kv_cache_dtype,
                    key_cache,
                    value_cache,
                    prefill_block_tables,
                    prefill_meta.query_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.context_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window,
                    k_scale,
                    v_scale,
                )
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out

        # Encoder attention has no decode phase (num_decode_tokens was forced
        # to 0 above), so decode metadata is only consulted for the others.
        decode_meta = (None if attn_type == AttentionType.ENCODER else
                       attn_metadata.decode_metadata)
        if decode_meta is not None:
            if key_cache is None or value_cache is None:
                raise RuntimeError(
                    "Decode-phase attention requires a KV cache, but "
                    "kv_cache is None.")

            (
                seq_lens_arg,
                max_seq_len_arg,
                block_tables_arg,
            ) = _get_seq_len_block_table_args(decode_meta, False, attn_type)

            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                block_tables_arg,
                seq_lens_arg,
                max_seq_len_arg,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                k_scale,
                v_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

    def _run_memory_efficient_xformers_forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: XFormersMetadata,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Attention for 1D query of multiple prompts. Multiple prompt
        tokens are flattened into the `query` input.

        See https://facebookresearch.github.io/xformers/components/ops.html
        for API spec.

        Args:
            query: shape = [num_prefill_tokens, num_heads, head_size]
            key: shape = [num_prefill_tokens, num_kv_heads, head_size]
            value: shape = [num_prefill_tokens, num_kv_heads, head_size]
            attn_metadata: Metadata for attention. Its bias field for
                attn_type is populated on first use and reused afterwards.
            attn_type: Select attention type, between encoder attention,
                       decoder self-attention, or encoder/decoder cross-
                       attention. Defaults to decoder self-attention,
                       which is the vLLM default generally
        Returns:
            Tensor with the same shape as `query`.

        Raises:
            NotImplementedError: if ALiBi slopes are configured and
                attn_type is not decoder self-attention.
        """
        if (self.alibi_slopes is not None
                and attn_type != AttentionType.DECODER):
            # The ALiBi path below builds causal biases from, and slices
            # query/key/value by, the decoder `seq_lens`. That is wrong for
            # encoder attention (non-causal, encoder lengths) and for
            # cross-attention (key length differs from query length).
            raise NotImplementedError(
                "ALiBi is only supported for decoder self-attention in the "
                "xFormers backend.")

        original_query = query
        if self.num_kv_heads != self.num_heads:
            # GQA/MQA requires the shape [B, M, G, H, K].
            # Note that the output also has the same shape (which is different
            # from a spec from the doc).
            query = query.view(query.shape[0], self.num_kv_heads,
                               self.num_queries_per_kv, query.shape[-1])
            key = key[:, :,
                      None, :].expand(key.shape[0], self.num_kv_heads,
                                      self.num_queries_per_kv, key.shape[-1])
            value = value[:, :,
                          None, :].expand(value.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          value.shape[-1])

        # Set attention bias if not provided. This typically happens at
        # the very first attention layer of every iteration.
        # FIXME(woosuk): This is a hack.
        attn_bias = _get_attn_bias(attn_metadata, attn_type)
        if attn_bias is None:
            if self.alibi_slopes is None:
                if (attn_type == AttentionType.ENCODER_DECODER):
                    assert attn_metadata.seq_lens is not None
                    assert attn_metadata.encoder_seq_lens is not None

                    # Default enc/dec cross-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)
                elif attn_type == AttentionType.ENCODER:
                    assert attn_metadata.encoder_seq_lens is not None

                    # Default encoder self-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.encoder_seq_lens)
                else:
                    assert attn_metadata.seq_lens is not None

                    # Default decoder self-attention mask is causal
                    attn_bias = BlockDiagonalCausalMask.from_seqlens(
                        attn_metadata.seq_lens)
                if self.sliding_window is not None:
                    attn_bias = attn_bias.make_local_attention(
                        self.sliding_window)
                attn_bias = [attn_bias]
            else:
                assert attn_metadata.seq_lens is not None
                attn_bias = _make_alibi_bias(self.alibi_slopes,
                                             self.num_kv_heads, query.dtype,
                                             attn_metadata.seq_lens)

            _set_attn_bias(attn_metadata, attn_bias, attn_type)

        # No alibi slopes.
        # TODO(woosuk): Too many view operations. Let's try to reduce
        # them in the future for code readability.
        if self.alibi_slopes is None:
            # Add the batch dimension.
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
            out = xops.memory_efficient_attention_forward(
                query,
                key,
                value,
                attn_bias=attn_bias[0],
                p=0.0,
                scale=self.scale)
            return out.view_as(original_query)

        # Attention with alibi slopes.
        # FIXME(woosuk): Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short prompts.
        assert attn_metadata.seq_lens is not None
        output = torch.empty_like(original_query)
        start = 0
        for i, seq_len in enumerate(attn_metadata.seq_lens):
            end = start + seq_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=attn_bias[i],
                p=0.0,
                scale=self.scale)
            # TODO(woosuk): Unnecessary copy. Optimize.
            output[start:end].copy_(out.view_as(original_query[start:end]))
            start += seq_len
        return output


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[AttentionBias]:
    """Build one causal ALiBi attention bias per sequence.

    Args:
        alibi_slopes: Per-head ALiBi slopes, indexed by head along dim 0.
            The bias tensors are allocated on this tensor's device.
        num_kv_heads: Number of KV heads. When it differs from the number of
            slopes (GQA/MQA), the head dimension of each bias is unflattened
            to [num_kv_heads, num_heads // num_kv_heads].
        dtype: dtype of the returned bias tensors.
        seq_lens: Length of each sequence; one bias is built per entry.

    Returns:
        One LowerTriangularMaskWithTensorBias per sequence, wrapping a
        tensor of shape [1, num_heads, seq_len, seq_len] (or the 5-D
        unflattened form for GQA/MQA).
    """
    attn_biases: List[AttentionBias] = []
    num_heads = alibi_slopes.shape[0]
    for seq_len in seq_lens:
        # Build the positions in float32 on the slopes' device. Creating them
        # directly in a 16-bit `dtype` rounds the absolute positions before
        # subtracting: bfloat16 holds integers exactly only up to 256 and
        # float16 only up to 2048, so longer prompts got distorted relative
        # distances. With float32 the distance j - i is exact and is rounded
        # to `dtype` only once. Building on-device also avoids a host-to-
        # device copy per sequence.
        positions = torch.arange(seq_len,
                                 dtype=torch.float32,
                                 device=alibi_slopes.device)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        # Element [i, j] is j - i (key position minus query position).
        relative_positions = positions[None, :] - positions[:, None]

        # The last dimension is allocated padded to a multiple of 8 and then
        # sliced back to seq_len. NOTE(review): presumably to satisfy
        # xFormers' alignment requirement for tensor biases -- confirm.
        padded_len = (seq_len + 7) // 8 * 8
        bias = torch.empty(
            1,  # batch size
            num_heads,
            seq_len,
            padded_len,
            device=alibi_slopes.device,
            dtype=dtype,
        )[:, :, :, :seq_len].copy_(relative_positions)
        bias.mul_(alibi_slopes[:, None, None])
        if num_heads != num_kv_heads:
            bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
        attn_biases.append(LowerTriangularMaskWithTensorBias(bias))

    return attn_biases