"""Attention layer with xFormers and PagedAttention."""
2
from dataclasses import dataclass
3
from typing import Any, Dict, List, Optional, Tuple, Type
Woosuk Kwon's avatar
Woosuk Kwon committed
4
5

import torch
6
from xformers import ops as xops
7
8
from xformers.ops.fmha.attn_bias import (AttentionBias,
                                         BlockDiagonalCausalMask,
9
                                         BlockDiagonalMask,
Woosuk Kwon's avatar
Woosuk Kwon committed
10
                                         LowerTriangularMaskWithTensorBias)
Woosuk Kwon's avatar
Woosuk Kwon committed
11

12
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
13
                                              AttentionLayer,
14
                                              AttentionMetadata, AttentionType)
15
16
17
18
from vllm.attention.backends.utils import (
    CommonAttentionState, CommonMetadataBuilder,
    get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
    is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set)
19
20
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
21
22
23
from vllm.logger import init_logger

# Module-level logger; XFormersImpl uses it to emit one-time warnings about
# options this backend does not honor (e.g. logits soft-capping).
logger = init_logger(__name__)


class XFormersBackend(AttentionBackend):
    """Attention backend built on xFormers and PagedAttention.

    Prompt tokens are handled by xFormers memory-efficient attention (or by
    PagedAttention's prefix kernel when a cached prefix exists) and decode
    tokens by PagedAttention (see `XFormersImpl.forward`). All methods are
    static: this class only hands out the implementation / metadata /
    builder / state classes and forwards KV-cache management to
    `PagedAttention`.
    """

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "XFORMERS"

    @staticmethod
    def get_impl_cls() -> Type["XFormersImpl"]:
        """Return the attention implementation class."""
        return XFormersImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the per-batch attention metadata class."""
        return XFormersMetadata

    @staticmethod
    def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
        """Return the class that builds `XFormersMetadata`."""
        return XFormersMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class (shared common implementation)."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape; delegated to PagedAttention."""
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap KV-cache blocks between two caches via PagedAttention.

        `src_to_dst` is forwarded unchanged to `PagedAttention.swap_blocks`,
        which takes a tensor mapping (the same convention as `copy_blocks`).
        """
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy KV-cache blocks within each cache via PagedAttention."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for the XFormers backend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """

    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # FIXME: It is for flash attn.
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]] = None

    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor] = None

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int] = None

    # Max number of query tokens among requests in the batch.
    max_decode_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None

    # Self-attention prefill/decode metadata cache
    _cached_prefill_metadata: Optional["XFormersMetadata"] = None
    _cached_decode_metadata: Optional["XFormersMetadata"] = None

    # Begin encoder attn & enc/dec cross-attn fields...

    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    encoder_seq_start_loc: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    def __post_init__(self):
        # Set during the execution of the first attention op.
        # It is a list because it is needed to set per prompt
        # when alibi slopes is used. It is because of the limitation
        # from xformer API.
        # will not appear in the __repr__ and __init__
        self.attn_bias: Optional[List[AttentionBias]] = None
        self.encoder_attn_bias: Optional[List[AttentionBias]] = None
        self.cross_attn_bias: Optional[List[AttentionBias]] = None

    @property
    def is_all_encoder_attn_metadata_set(self):
        '''
        All attention metadata required for encoder attention is set.
        '''
        return is_all_encoder_attn_metadata_set(self)

    @property
    def is_all_cross_attn_metadata_set(self):
        '''
        All attention metadata required for enc/dec cross-attention is set.

        Superset of encoder attention required metadata.
        '''
        return is_all_cross_attn_metadata_set(self)

    @property
    def prefill_metadata(self) -> Optional["XFormersMetadata"]:
        """Prefill-only view of this batch, built lazily and cached.

        Returns None when the batch has no prefill requests. Per-request
        fields are sliced to the first `num_prefills` requests and per-token
        fields to the first `num_prefill_tokens` tokens (prefills come first
        in the flattened batch). Encoder / cross-attention fields are shared
        unchanged.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            # Recover cached prefill-phase attention
            # metadata structure
            return self._cached_prefill_metadata

        assert ((self.seq_lens is not None)
                or (self.encoder_seq_lens is not None))
        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[:self.num_prefills + 1])
        # Cumulative offsets have num_prefills + 1 entries, like
        # query_start_loc above.
        seq_start_loc = (None if self.seq_start_loc is None else
                         self.seq_start_loc[:self.num_prefills + 1])
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[:self.num_prefill_tokens])
        seq_lens = (None if self.seq_lens is None else
                    self.seq_lens[:self.num_prefills])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[:self.num_prefills])
        context_lens_tensor = (None if self.context_lens_tensor is None else
                               self.context_lens_tensor[:self.num_prefills])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[:self.num_prefills])

        # Construct & cache prefill-phase attention metadata structure
        self._cached_prefill_metadata = XFormersMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["XFormersMetadata"]:
        """Decode-only view of this batch, built lazily and cached.

        Returns None when the batch has no decode tokens. Per-request
        fields are sliced from `num_prefills` onward and per-token fields
        from `num_prefill_tokens` onward (decodes follow the prefills in the
        flattened batch). `query_start_loc` is rebased so that it starts at 0.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            # Recover cached decode-phase attention
            # metadata structure
            return self._cached_decode_metadata

        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[self.num_prefill_tokens:])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[self.num_prefills:])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[self.num_prefills:])
        # Keep the tail of the cumulative query offsets that belongs to the
        # decode requests; it is rebased to start at 0 below. (Previously
        # this field was never passed, which made the rebase dead code.)
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[self.num_prefills:])

        # Construct & cache decode-phase attention metadata structure
        self._cached_decode_metadata = XFormersMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            max_decode_query_len=self.max_decode_query_len,
            query_start_loc=query_start_loc,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)

        # Batch may be composed of prefill|decodes, adjust query start indices
        # to refer to the start of decodes when the two are split apart.
        # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
        if self._cached_decode_metadata.query_start_loc is not None:
            qs = self._cached_decode_metadata.query_start_loc
            self._cached_decode_metadata.query_start_loc = qs - qs[0]
        return self._cached_decode_metadata


def _get_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_type: str,
) -> Optional[List[AttentionBias]]:
    '''
    Extract the cached attention-bias list from attention metadata
    according to attention type.

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:
    * The bias list stored for the given attention type, or None if it has
      not been set yet. Decoder and encoder-only self-attention share
      `attn_bias`.

    Raises:
    * AttributeError: if attn_type is not one of the known AttentionType
      values.
    '''
    if attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY):
        return attn_metadata.attn_bias
    if attn_type == AttentionType.ENCODER:
        return attn_metadata.encoder_attn_bias
    if attn_type == AttentionType.ENCODER_DECODER:
        return attn_metadata.cross_attn_bias
    raise AttributeError(f"Invalid attention type {str(attn_type)}")


def _set_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_bias: List[AttentionBias],
    attn_type: str,
) -> None:
    '''
    Update appropriate attention bias field of attention metadata,
    according to attention type.

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention
    * attn_bias: The desired attention bias list (one entry per prompt when
                 ALiBi is used, otherwise a single-element list)
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Raises:
    * AttributeError: if attn_type is not one of the known AttentionType
      values.
    '''
    if attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY):
        # Decoder and encoder-only self-attention share one bias slot.
        attn_metadata.attn_bias = attn_bias
    elif attn_type == AttentionType.ENCODER:
        attn_metadata.encoder_attn_bias = attn_bias
    elif attn_type == AttentionType.ENCODER_DECODER:
        attn_metadata.cross_attn_bias = attn_bias
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")


class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):
    """Metadata builder for the XFormers backend.

    All building logic is inherited from `CommonMetadataBuilder`; this
    subclass only names the metadata class it produces.
    """

    # Concrete metadata type instantiated by the inherited builder logic.
    _metadata_cls = XFormersMetadata


class XFormersImpl(AttentionImpl[XFormersMetadata]):
    """Attention implementation using xFormers for prompt tokens and
    PagedAttention kernels for cached-prefix prefill and decode.

    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Validate the configuration and store the attention parameters.

        `logits_soft_cap` is not applied by this backend; passing it only
        emits a one-time warning.

        Raises:
            ValueError: if block-sparse attention is requested, or if
                `head_size` is not supported by PagedAttention.
            AssertionError: if `num_heads` is not a multiple of
                `num_kv_heads`.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "XFormers does not support block-sparse attention.")
        if logits_soft_cap is not None:
            logger.warning_once("XFormers does not support logits soft cap. "
                                "Outputs may be slightly off.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        # Number of query heads that share each KV head (GQA/MQA grouping).
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.attn_type = attn_type

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: Optional[torch.Tensor],
        value: Optional[torch.Tensor],
        kv_cache: torch.Tensor,
        attn_metadata: "XFormersMetadata",
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with xFormers and PagedAttention.

        For decoder-only models: query, key and value must be non-None.

        For encoder/decoder models:
        * XFormersImpl.forward() may be invoked for both self- and cross-
          attention layers.
        * For self-attention: query, key and value must be non-None.
        * For cross-attention:
            * Query must be non-None
            * During prefill, key and value must be non-None; key and value
              get cached for use during decode.
            * During decode, key and value may be None, since:
              (1) key and value tensors were cached during prefill, and
              (2) cross-attention key and value tensors do not grow during
                  decode

        The attention type is taken from `self.attn_type` (set at
        construction). How it impacts forward() behavior:

            * DECODER: normal decoder-only behavior;
                use decoder self-attention block table
            * ENCODER: no KV caching; pass encoder sequence
                attributes (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len) to kernel, in lieu of decoder
                sequence attributes (seq_lens/seq_lens_tensor/max_seq_len).
                Used for encoder branch of encoder-decoder models.
            * ENCODER_ONLY: no kv_caching, uses the normal attention
                attributes (seq_lens/seq_lens_tensor/max_seq_len).
            * ENCODER_DECODER: cross-attention behavior;
                use cross-attention block table for caching KVs derived
                from encoder hidden states; since KV sequence lengths
                will match encoder sequence lengths, pass encoder sequence
                attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len)

        Args:
            layer: Attention layer; its `_k_scale` / `_v_scale` are passed
                to the KV-cache write and PagedAttention kernels.
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
            output: Accepted for interface compatibility but not used; a
                fresh output tensor is always allocated.

        Returns:
            shape = [num_tokens, num_heads * head_size]

        Raises:
            AttributeError: if the metadata required by the encoder or
                cross-attention type is not set.
        """
        attn_type = self.attn_type
        # Check that appropriate attention metadata attributes are
        # selected for the desired attention type
        if (attn_type == AttentionType.ENCODER
                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
            raise AttributeError("Encoder attention requires setting "
                                 "encoder metadata attributes.")

        elif (attn_type == AttentionType.ENCODER_DECODER
              and (not attn_metadata.is_all_cross_attn_metadata_set)):
            raise AttributeError("Encoder/decoder cross-attention "
                                 "requires setting cross-attention "
                                 "metadata attributes.")

        query = query.view(-1, self.num_heads, self.head_size)
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None

        # Self-attention vs. cross-attention will impact
        # which KV cache memory-mapping & which
        # seqlen datastructures we utilize

        if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
            # KV-cache during decoder-self- or
            # encoder-decoder-cross-attention, but not
            # during encoder attention.
            #
            # Even if there are no new key/value pairs to cache,
            # we still need to break out key_cache and value_cache
            # i.e. for later use by paged attention
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            if (key is not None) and (value is not None):

                if attn_type == AttentionType.ENCODER_DECODER:
                    # Update cross-attention KV cache (prefill-only)
                    # During cross-attention decode, key & value will be None,
                    # preventing this IF-statement branch from running
                    updated_slot_mapping = attn_metadata.cross_slot_mapping
                else:
                    # Update self-attention KV cache (prefill/decode)
                    updated_slot_mapping = attn_metadata.slot_mapping

                # Reshape the input keys and values and store them in the cache.
                # If kv_cache is not provided, the new key and value tensors are
                # not cached. This happens during the initial memory
                # profiling run.
                PagedAttention.write_to_paged_cache(
                    key, value, key_cache, value_cache, updated_slot_mapping,
                    self.kv_cache_dtype, layer._k_scale, layer._v_scale)
        (num_prefill_query_tokens, num_prefill_kv_tokens,
        num_decode_query_tokens) = \
            get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_query_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_query_tokens]
        if key is not None and value is not None:
            key = key[:num_prefill_kv_tokens]
            value = value[:num_prefill_kv_tokens]

        assert query.shape[0] == num_prefill_query_tokens
        assert decode_query.shape[0] == num_decode_query_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
                # normal attention.
                # block tables are empty if the prompt does not have a cached
                # prefix.
                out = self._run_memory_efficient_xformers_forward(
                    query, key, value, prefill_meta, attn_type=attn_type)
                assert out.shape == output[:num_prefill_query_tokens].shape
                output[:num_prefill_query_tokens] = out
            else:
                assert attn_type != AttentionType.ENCODER_ONLY, (
                    "Encoder-only models should not have prefix attention.")

                assert prefill_meta.query_start_loc is not None
                assert prefill_meta.max_query_len is not None

                # prefix-enabled attention
                # TODO(Hai) this triton kernel has regression issue (broke) to
                # deal with different data types between KV and FP8 KV cache,
                # to be addressed separately.
                out = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    self.kv_cache_dtype,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.query_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.context_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window,
                    layer._k_scale,
                    layer._v_scale,
                )
                assert output[:num_prefill_query_tokens].shape == out.shape
                output[:num_prefill_query_tokens] = out

        if decode_meta := attn_metadata.decode_metadata:
            assert attn_type != AttentionType.ENCODER_ONLY, (
                "Encoder-only models should not have decode metadata.")

            (
                seq_lens_arg,
                max_seq_len_arg,
                block_tables_arg,
            ) = get_seq_len_block_table_args(decode_meta, False, attn_type)

            output[num_prefill_query_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                block_tables_arg,
                seq_lens_arg,
                max_seq_len_arg,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                layer._k_scale,
                layer._v_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

    def _run_memory_efficient_xformers_forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: XFormersMetadata,
        attn_type: str = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Attention for 1D query of multiple prompts. Multiple prompt
        tokens are flattened in to `query` input.

        See https://facebookresearch.github.io/xformers/components/ops.html
        for API spec.

        The attention bias for `attn_type` is built on first use, stored on
        `attn_metadata` via `_set_attn_bias`, and reused by later calls
        that receive the same metadata.

        Args:
            query: shape = [num_prefill_tokens, num_heads, head_size]
            key: shape = [num_prefill_tokens, num_kv_heads, head_size]
            value: shape = [num_prefill_tokens, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
            attn_type: Select attention type, between encoder attention,
                       decoder self-attention, or encoder/decoder cross-
                       attention. Defaults to decoder self-attention,
                       which is the vLLM default generally

        Returns:
            Attention output with the same shape as `query`.

        Raises:
            ValueError: if `attn_type` is not a known attention type.
        """
        original_query = query
        if self.num_kv_heads != self.num_heads:
            # GQA/MQA requires the shape [B, M, G, H, K].
            # Note that the output also has the same shape (which is different
            # from a spec from the doc).
            query = query.view(query.shape[0], self.num_kv_heads,
                               self.num_queries_per_kv, query.shape[-1])
            key = key[:, :,
                      None, :].expand(key.shape[0], self.num_kv_heads,
                                      self.num_queries_per_kv, key.shape[-1])
            value = value[:, :,
                          None, :].expand(value.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          value.shape[-1])

        # Set attention bias if not provided. This typically happens at
        # the very attention layer of every iteration.
        # FIXME(woosuk): This is a hack.
        attn_bias = _get_attn_bias(attn_metadata, attn_type)
        if attn_bias is None:
            if self.alibi_slopes is None:

                # Cross attention block of decoder branch of encoder-decoder
                # model uses seq_lens for dec / encoder_seq_lens for enc
                if (attn_type == AttentionType.ENCODER_DECODER):
                    assert attn_metadata.seq_lens is not None
                    assert attn_metadata.encoder_seq_lens is not None

                    # Cross-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)

                # Encoder branch of encoder-decoder model uses
                # attn_metadata.encoder_seq_lens
                elif attn_type == AttentionType.ENCODER:

                    assert attn_metadata.encoder_seq_lens is not None

                    # Encoder self-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.encoder_seq_lens)

                # Self-attention block of encoder-only model just
                # uses the seq_lens directly.
                elif attn_type == AttentionType.ENCODER_ONLY:
                    assert attn_metadata.seq_lens is not None

                    # Encoder self-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.seq_lens)

                # Self-attention block of decoder branch just
                # uses the seq_lens directly
                elif attn_type == AttentionType.DECODER:
                    assert attn_metadata.seq_lens is not None

                    # Decoder self-attention mask is causal
                    attn_bias = BlockDiagonalCausalMask.from_seqlens(
                        attn_metadata.seq_lens)
                else:
                    # Format the message here: ValueError does not apply
                    # logging-style "%s" arguments.
                    raise ValueError(f"Unknown AttentionType: {attn_type}")

                if self.sliding_window is not None:
                    attn_bias = attn_bias.make_local_attention(
                        self.sliding_window)
                # Stored as a list to match the per-prompt ALiBi case below.
                attn_bias = [attn_bias]
            else:
                assert attn_type == AttentionType.DECODER
                assert attn_metadata.seq_lens is not None
                attn_bias = _make_alibi_bias(self.alibi_slopes,
                                             self.num_kv_heads, query.dtype,
                                             attn_metadata.seq_lens)

            _set_attn_bias(attn_metadata, attn_bias, attn_type)

        # No alibi slopes.
        # TODO(woosuk): Too many view operations. Let's try to reduce
        # them in the future for code readability.
        if self.alibi_slopes is None:
            # Add the batch dimension.
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
            out = xops.memory_efficient_attention_forward(
                query,
                key,
                value,
                attn_bias=attn_bias[0],
                p=0.0,
                scale=self.scale)
            return out.view_as(original_query)

        # Attention with alibi slopes.
        # FIXME(woosuk): Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short prompts.
        assert attn_metadata.seq_lens is not None
        output = torch.empty_like(original_query)
        start = 0
        for i, seq_len in enumerate(attn_metadata.seq_lens):
            end = start + seq_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=attn_bias[i],
                p=0.0,
                scale=self.scale)
            # TODO(woosuk): Unnecessary copy. Optimize.
            output[start:end].copy_(out.view_as(original_query[start:end]))
            start += seq_len
        return output


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[AttentionBias]:
    """Build one ALiBi attention bias per prompt.

    For a prompt of length L, head h gets the bias
    ``alibi_slopes[h] * (j - i)`` for query position i and key position j,
    wrapped in xFormers' ``LowerTriangularMaskWithTensorBias``.

    Args:
        alibi_slopes: Per-head slopes; ``alibi_slopes.shape[0]`` is taken as
            the number of query heads. The biases are allocated on its
            device.
        num_kv_heads: Number of KV heads. When it differs from the number of
            query heads, the head dimension of each bias is split into
            ``(num_kv_heads, num_heads // num_kv_heads)`` to match the
            GQA/MQA query layout used by the caller.
        dtype: dtype of the returned bias tensors.
        seq_lens: Length of every prompt in the batch.

    Returns:
        A list with one attention bias per entry of ``seq_lens``.
    """
    # Loop-invariant values.
    num_heads = alibi_slopes.shape[0]
    device = alibi_slopes.device
    slopes = alibi_slopes[:, None, None]

    attn_biases: List[AttentionBias] = []
    for seq_len in seq_lens:
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        # Matrix of relative positions: element [i, j] is (j - i).
        # Computed in float32 so the integer positions stay exact; building
        # them directly in a 16-bit dtype would round the positions
        # themselves (bfloat16 is only exact up to 256) before subtracting.
        # They are cast to `dtype` once, by the copy below.
        positions = torch.arange(seq_len, dtype=torch.float32, device=device)
        rel_pos = positions[None, :] - positions[:, None]

        # The storage's last dimension is padded to a multiple of 8 and only
        # the first seq_len columns are used (presumably an xFormers
        # alignment requirement on the bias stride -- confirm).
        padded_len = (seq_len + 7) // 8 * 8
        bias = torch.empty(
            1,  # batch size
            num_heads,
            seq_len,
            padded_len,
            device=device,
            dtype=dtype,
        )[:, :, :, :seq_len].copy_(rel_pos)
        bias.mul_(slopes)
        if num_heads != num_kv_heads:
            bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
        attn_biases.append(LowerTriangularMaskWithTensorBias(bias))

    return attn_biases