torch_sdpa.py 26.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
""" Attention layer with torch scaled_dot_product_attention
    and PagedAttention."""
from dataclasses import dataclass
6
from typing import Any, Dict, List, Optional, Tuple, Type
7
8
9
10

import torch
from torch.nn.functional import scaled_dot_product_attention

11
12
# yapf conflicts with isort for this block
# yapf: disable
13
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
14
                                              AttentionLayer,
15
16
                                              AttentionMetadata,
                                              AttentionMetadataBuilder,
17
18
19
                                              AttentionType,
                                              is_quantized_kv_cache)
# yapf: enable
20
from vllm.attention.backends.utils import CommonAttentionState
21
from vllm.attention.ops.ipex_attn import PagedAttention, _use_ipex
22
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
23
24
from vllm.logger import init_logger
from vllm.utils import make_tensor_with_pad
25
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
26

27
28
# Module-level logger; TorchSDPABackendImpl uses it for one-time warnings
# about options this backend ignores (logits soft cap, irope).
logger = init_logger(__name__)

29
30
31

class TorchSDPABackend(AttentionBackend):
    """CPU attention backend built on ``torch`` SDPA and ``PagedAttention``.

    Only static methods live here: the backend identifier, the companion
    implementation / metadata / builder / state classes, and KV-cache
    operations that are forwarded unchanged to ``PagedAttention``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "TORCH_SDPA"

    # -- Companion classes ---------------------------------------------

    @staticmethod
    def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
        """Return the class that executes the attention forward pass."""
        return TorchSDPABackendImpl

    @staticmethod
    def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]:
        """Return the class that assembles per-batch attention metadata."""
        return TorchSDPAMetadataBuilder

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the per-batch metadata class (``TorchSDPAMetadata``)."""
        return TorchSDPAMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention-state helper class shared by backends."""
        return CommonAttentionState

    # -- KV-cache operations, delegated to PagedAttention ---------------

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape computed by PagedAttention."""
        geometry = (num_blocks, block_size, num_kv_heads, head_size)
        return PagedAttention.get_kv_cache_shape(*geometry)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Delegate block swapping between two caches to PagedAttention."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Delegate in-cache block copies to PagedAttention."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Per-batch attention metadata for TorchSDPABackend.

    On top of the fields inherited from the common and paged metadata base
    classes, it holds the inputs of the chunked-prefill path and the
    encoder / encoder-decoder cross-attention bookkeeping.
    """
    # Whether the batch was built with chunked prefill enabled; the forward
    # pass uses it to choose the prefill kernel.
    chunked_prefill: bool
    seq_lens: Optional[List[int]] = None  # For non-chunked prefill

    # Chunked-prefill inputs (None when chunked prefill is not in use).
    max_query_len: Optional[int] = None
    max_kv_len: Optional[int] = None
    query_start_loc: Optional[torch.Tensor] = None
    kv_start_loc: Optional[torch.Tensor] = None
    prefill_block_tables: Optional[torch.Tensor] = None

    # Encoder attention & encoder/decoder cross-attention fields.
    # Per-sequence encoder lengths, as a list and as a tensor.
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # Longest encoder sequence in the batch.
    max_encoder_seq_len: Optional[int] = None

    # Total number of encoder input tokens.
    num_encoder_tokens: Optional[int] = None

    # Cross-attention KV-cache slot mapping and block tables.
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    def __post_init__(self):
        # Lazily built attention masks, one list per attention type. The
        # first attention op that needs them fills them in through
        # set_attn_bias(); later calls sharing this metadata reuse them.
        # They are lists because a mask may be needed per prompt (e.g. with
        # ALiBi). Being plain attributes rather than dataclass fields, they
        # are excluded from __init__ and __repr__.
        self.attn_bias: Optional[List[torch.Tensor]] = None
        self.encoder_attn_bias: Optional[List[torch.Tensor]] = None
        self.cross_attn_bias: Optional[List[torch.Tensor]] = None

    @staticmethod
    def _reads_decoder_fields(attn_type: str) -> bool:
        """True for attention types that use the decoder-side fields."""
        return attn_type in (AttentionType.DECODER,
                             AttentionType.ENCODER_ONLY)

    @property
    def is_all_encoder_attn_metadata_set(self):
        '''
        True when every field required for encoder attention is populated.
        '''
        required = (self.encoder_seq_lens, self.encoder_seq_lens_tensor,
                    self.max_encoder_seq_len)
        return all(item is not None for item in required)

    @property
    def is_all_cross_attn_metadata_set(self):
        '''
        True when the encoder-attention fields and, additionally, the
        cross-attention slot mapping and block tables are populated.
        '''
        if not self.is_all_encoder_attn_metadata_set:
            return False
        return (self.cross_slot_mapping is not None
                and self.cross_block_tables is not None)

    @property
    def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
        """This object when the batch has prefill tokens, otherwise None."""
        return None if self.num_prefill_tokens == 0 else self

    @property
    def decode_metadata(self) -> Optional["TorchSDPAMetadata"]:
        """This object when the batch has decode tokens, otherwise None."""
        return None if self.num_decode_tokens == 0 else self

    def get_seq_lens(
        self,
        attn_type: str,
    ):
        '''
        Pick the query and key/value sequence lengths for ``attn_type``.

        * Decoder self-attention / encoder-only: ``seq_lens`` for both.
        * Encoder attention: ``encoder_seq_lens`` for both.
        * Encoder/decoder cross-attention: ``seq_lens`` for the query and
          ``encoder_seq_lens`` for key/value.

        Returns:
            ``(seq_lens_q, seq_lens_kv)``

        Raises:
            AttributeError: for an unknown ``attn_type``.
        '''
        if self._reads_decoder_fields(attn_type):
            return self.seq_lens, self.seq_lens
        if attn_type == AttentionType.ENCODER:
            return self.encoder_seq_lens, self.encoder_seq_lens
        if attn_type == AttentionType.ENCODER_DECODER:
            return self.seq_lens, self.encoder_seq_lens
        raise AttributeError(f"Invalid attention type {str(attn_type)}")

    def get_attn_bias(
        self,
        attn_type: str,
    ) -> Optional[List[torch.Tensor]]:
        '''
        Return the cached mask list that belongs to ``attn_type``
        (None until it has been set).

        Raises:
            AttributeError: for an unknown ``attn_type``.
        '''
        if self._reads_decoder_fields(attn_type):
            return self.attn_bias
        if attn_type == AttentionType.ENCODER:
            return self.encoder_attn_bias
        if attn_type == AttentionType.ENCODER_DECODER:
            return self.cross_attn_bias
        raise AttributeError(f"Invalid attention type {str(attn_type)}")

    def set_attn_bias(
        self,
        attn_bias: List[torch.Tensor],
        attn_type: str,
    ) -> None:
        '''
        Store ``attn_bias`` in the mask cache that belongs to ``attn_type``.

        Raises:
            AttributeError: for an unknown ``attn_type``.
        '''
        if self._reads_decoder_fields(attn_type):
            self.attn_bias = attn_bias
        elif attn_type == AttentionType.ENCODER:
            self.encoder_attn_bias = attn_bias
        elif attn_type == AttentionType.ENCODER_DECODER:
            self.cross_attn_bias = attn_bias
        else:
            raise AttributeError(f"Invalid attention type {str(attn_type)}")

    def get_seq_len_block_table_args(
        self,
        attn_type: str,
    ) -> tuple:
        '''
        Select the sequence-length and block-table arguments used by paged
        (decode) attention for ``attn_type``.

        * Decoder self-attention / encoder-only: ``seq_lens_tensor``,
          ``max_decode_seq_len`` and ``block_tables``.
        * Encoder/decoder cross-attention: encoder lengths plus the
          cross-attention block tables (cross KVs match encoder length).
        * Encoder attention: encoder lengths and no block tables.

        Returns:
            ``(seq_lens_tensor, max_seq_len, block_tables_or_None)``

        Raises:
            AttributeError: for an unknown ``attn_type``.
        '''
        if self._reads_decoder_fields(attn_type):
            return (self.seq_lens_tensor, self.max_decode_seq_len,
                    self.block_tables)
        if attn_type == AttentionType.ENCODER_DECODER:
            return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
                    self.cross_block_tables)
        if attn_type == AttentionType.ENCODER:
            return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
                    None)
        raise AttributeError(f"Invalid attention type {str(attn_type)}")

286

287
288
289
290
class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
    """Builds :class:`TorchSDPAMetadata` from the CPU model runner's inputs."""

    def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
        # The runner's input builder; its ``input_data`` is fetched again in
        # prepare() for each batch.
        self.input_builder = input_builder
        self.chunked_prefill = input_builder.chunked_prefill

    def prepare(self):
        """Keep a reference to the input data of the batch being built."""
        self.input_data = self.input_builder.input_data

    @staticmethod
    def _int32_cpu(values: List[int]) -> torch.Tensor:
        """Convert a list of ints into a 1-D int32 CPU tensor."""
        return torch.tensor(values, dtype=torch.int32, device="cpu")

    @staticmethod
    def _pad_block_tables(tables) -> torch.Tensor:
        """Zero-pad ragged per-sequence block tables into an int32 tensor."""
        return make_tensor_with_pad(tables,
                                    pad=0,
                                    dtype=torch.int32,
                                    device="cpu")

    @staticmethod
    def _start_offsets(lens: torch.Tensor, count: int) -> torch.Tensor:
        """Return ``count + 1`` int32 offsets ``[0, l0, l0 + l1, ...]``."""
        offsets = torch.zeros(count + 1, dtype=torch.int32, device="cpu")
        torch.cumsum(lens, dim=0, dtype=torch.int32, out=offsets[1:])
        return offsets

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:
        """Assemble the attention metadata for the current batch.

        The first ``num_prefills`` entries of ``seq_lens`` / ``query_lens``
        belong to prefill sequences. ``cuda_graph_pad_size`` and
        ``batch_size`` belong to the builder interface and are not used by
        this backend.
        """
        data = self.input_data
        n_prefills = data.num_prefills
        prefill_seq_lens = seq_lens[:n_prefills]
        prefill_query_lens = query_lens[:n_prefills]
        slot_mapping = torch.tensor(data.slot_mapping,
                                    dtype=torch.long,
                                    device="cpu")

        # Varlen inputs; only populated for chunked prefill with prefill work.
        prefill_block_tables = query_start_loc = kv_start_loc = None
        max_query_len = max_kv_len = None
        if self.chunked_prefill and data.num_prefill_tokens != 0:
            prefill_block_tables = self._pad_block_tables(
                data.prefill_block_tables)
            query_start_loc = self._start_offsets(
                self._int32_cpu(prefill_query_lens), n_prefills)
            kv_start_loc = self._start_offsets(
                self._int32_cpu(prefill_seq_lens), n_prefills)
            max_query_len = max(prefill_query_lens)
            max_kv_len = max(prefill_seq_lens)

        # Paged-attention inputs. Without decode tokens, block_tables is an
        # empty placeholder and seq_lens_tensor carries the prefill lengths.
        if data.num_decode_tokens == 0:
            block_tables = torch.tensor([])
            seq_lens_tensor = self._int32_cpu(data.seq_lens[:n_prefills])
        else:
            seq_lens_tensor = self._int32_cpu(data.seq_lens[n_prefills:])
            block_tables = self._pad_block_tables(data.decode_block_tables)

        # Multi-modal placeholder index maps, keyed by modality.
        placeholder_index_maps = None
        if len(data.multi_modal_inputs_list) > 0:
            placeholder_index_maps = {
                modality: placeholder_map.index_map()
                for modality, placeholder_map in
                data.multi_modal_placeholder_maps.items()
            }

        return TorchSDPAMetadata(
            chunked_prefill=self.chunked_prefill,
            num_prefills=n_prefills,
            num_prefill_tokens=data.num_prefill_tokens,
            num_decode_tokens=data.num_decode_tokens,
            slot_mapping=slot_mapping,
            seq_lens=prefill_seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_seq_len=data.max_decode_seq_len,
            block_tables=block_tables,
            max_query_len=max_query_len,
            max_kv_len=max_kv_len,
            query_start_loc=query_start_loc,
            kv_start_loc=kv_start_loc,
            prefill_block_tables=prefill_block_tables,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
        )


394
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
    """Attention implementation used by TorchSDPABackend.

    Prefill runs per-sequence ``torch`` SDPA, or, when the batch was built
    with chunked prefill, IPEX's ``flash_attn_varlen_func`` over the paged
    KV cache. Decode runs ``PagedAttention.forward_decode``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        use_irope: bool = False,
    ) -> None:
        """Validate the layer configuration and store its settings.

        Args:
            num_heads: number of query heads.
            head_size: size of each head; must be supported by
                PagedAttention.
            scale: softmax scale (stored as ``float``).
            num_kv_heads: number of key/value heads; must divide
                ``num_heads``.
            alibi_slopes: optional per-head ALiBi slopes.
            sliding_window: optional sliding-window size.
            kv_cache_dtype: KV-cache dtype string.
            blocksparse_params: must be None (block-sparse is unsupported).
            logits_soft_cap: accepted but ignored (a warning is logged).
            attn_type: one of the ``AttentionType`` values.
            use_irope: accepted but ignored (a warning is logged).

        Raises:
            ValueError: block-sparse attention was requested, or
                ``head_size`` is not supported by PagedAttention.
            NotImplementedError: a quantized KV cache was requested without
                intel_extension_for_pytorch support.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "Torch SPDA does not support block-sparse attention.")
        if logits_soft_cap is not None:
            logger.warning_once("Torch SPDA does not support logits soft cap. "
                                "Outputs may be slightly off.")
        if use_irope:
            logger.warning_once(
                "Using irope in Torch SPDA is not supported yet, it will fall"
                " back to global attention for long context.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        # ALiBi and sliding windows need explicit per-sequence masks; the
        # chunked-prefill path asserts that none are required.
        self.need_mask = (self.alibi_slopes is not None
                          or self.sliding_window is not None)

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        if is_quantized_kv_cache(kv_cache_dtype) and not _use_ipex:
            raise NotImplementedError(
                "Torch SDPA backend FP8 KV cache requires "
                "intel_extension_for_pytorch support.")
        self.attn_type = attn_type

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: TorchSDPAMetadata,  # type: ignore
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with torch SDPA and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
            output: unused; a fresh buffer is always allocated and returned.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            AttributeError: encoder or cross-attention metadata is missing
                for the corresponding attention type.
        """
        attn_type = self.attn_type
        if (attn_type == AttentionType.ENCODER
                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
            raise AttributeError("Encoder attention requires setting "
                                 "encoder metadata attributes.")
        elif (attn_type == AttentionType.ENCODER_DECODER
              and (not attn_metadata.is_all_cross_attn_metadata_set)):
            raise AttributeError("Encoder/decoder cross-attention "
                                 "requires setting cross-attention "
                                 "metadata attributes.")

        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None

        if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
            # KV-cache during decoder-self- or
            # encoder-decoder-cross-attention, but not
            # during encoder attention.
            #
            # Even if there are no new key/value pairs to cache,
            # we still need to break out key_cache and value_cache
            # i.e. for later use by paged attention
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            if (key is not None) and (value is not None):
                if attn_type == AttentionType.ENCODER_DECODER:
                    # Update cross-attention KV cache (prefill-only)
                    # During cross-attention decode, key & value will be None,
                    # preventing this IF-statement branch from running
                    updated_slot_mapping = attn_metadata.cross_slot_mapping
                else:
                    # Update self-attention KV cache (prefill/decode)
                    updated_slot_mapping = attn_metadata.slot_mapping

                PagedAttention.write_to_paged_cache(
                    key, value, key_cache, value_cache, updated_slot_mapping,
                    self.kv_cache_dtype, layer._k_scale, layer._v_scale)

        if attn_type != AttentionType.ENCODER:
            # Decoder self-attention supports chunked prefill.
            # Encoder/decoder cross-attention requires no chunked
            # prefill (100% prefill or 100% decode tokens, no mix)
            num_prefill_tokens = attn_metadata.num_prefill_tokens
            num_decode_tokens = attn_metadata.num_decode_tokens
        else:
            # Encoder attention - chunked prefill is not applicable;
            # derive token-count from query shape & and treat them
            # as 100% prefill tokens
            assert attn_metadata.num_encoder_tokens is not None
            num_prefill_tokens = attn_metadata.num_encoder_tokens
            num_decode_tokens = 0

        if attn_type == AttentionType.DECODER:
            # Only enforce this shape-constraint for decoder
            # self-attention
            assert key.shape[0] == num_prefill_tokens + num_decode_tokens
            assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        # Allocated exactly once: the prefill path fills the leading prefill
        # token rows and the decode path fills the rows after them.
        output = torch.empty_like(query)
        if prefill_meta := attn_metadata.prefill_metadata:
            assert attn_metadata.seq_lens is not None
            if not prefill_meta.chunked_prefill:
                self._run_sdpa_forward(output,
                                       query,
                                       key,
                                       value,
                                       prefill_meta,
                                       attn_type=attn_type)
            else:
                # prefix-enabled attention
                assert not self.need_mask
                import intel_extension_for_pytorch.llm.modules as ipex_modules
                # Writes into the slice of the buffer allocated above; no
                # second allocation is needed.
                ipex_modules.PagedAttention.flash_attn_varlen_func(
                    output[:prefill_meta.num_prefill_tokens, :, :],
                    query[:prefill_meta.num_prefill_tokens, :, :],
                    key_cache,
                    value_cache,
                    prefill_meta.query_start_loc,
                    prefill_meta.kv_start_loc,
                    prefill_meta.max_query_len,
                    prefill_meta.max_kv_len,
                    self.scale,
                    True,
                    prefill_meta.prefill_block_tables,
                    self.alibi_slopes,
                )

        if decode_meta := attn_metadata.decode_metadata:
            assert attn_type != AttentionType.ENCODER_ONLY, (
                "Encoder-only models should not have decode metadata.")
            # Decoding run.
            (
                seq_lens_arg,
                max_seq_len_arg,
                block_tables_arg,
            ) = decode_meta.get_seq_len_block_table_args(attn_type)

            PagedAttention.forward_decode(
                output[attn_metadata.num_prefill_tokens:, :, :],
                query[attn_metadata.num_prefill_tokens:, :, :],
                key_cache,
                value_cache,
                block_tables_arg,
                seq_lens_arg,
                max_seq_len_arg,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                layer._k_scale,
                layer._v_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

    def _run_sdpa_forward(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: TorchSDPAMetadata,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Non-chunked prefill: run torch SDPA once per sequence.

        ``query`` is ``[num_tokens, num_heads, head_size]`` and
        ``key``/``value`` are ``[num_tokens, num_kv_heads, head_size]``, as
        reshaped by forward(). Results are written into ``output`` in place,
        one token range per sequence.

        Masks are built on first use (ALiBi, sliding window, or ``None``
        entries) and cached on ``attn_metadata`` for later calls.
        """
        # Grouped-query attention: expand KV heads to match query heads.
        if self.num_kv_heads != self.num_heads:
            key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
            value = value.repeat_interleave(self.num_queries_per_kv, dim=1)

        attn_masks = attn_metadata.get_attn_bias(attn_type)
        if attn_masks is None:
            if self.alibi_slopes is not None:
                attn_masks = _make_alibi_bias(
                    self.alibi_slopes, query.dtype,
                    attn_metadata.seq_lens)  # type: ignore
            elif self.sliding_window is not None:
                assert attn_metadata.seq_lens is not None
                attn_masks = _make_sliding_window_bias(
                    attn_metadata.seq_lens, self.sliding_window,
                    query.dtype)  # type: ignore
            else:
                seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
                attn_masks = [None] * len(seq_lens)
            attn_metadata.set_attn_bias(attn_masks, attn_type)

        # [tokens, heads, dim] -> [heads, tokens, dim] for SDPA.
        query = query.movedim(0, query.dim() - 2)
        key = key.movedim(0, key.dim() - 2)
        value = value.movedim(0, value.dim() - 2)

        causal_attn = (attn_type == AttentionType.DECODER)

        seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
        start_q, start_kv = 0, 0
        for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv,
                                               attn_masks):
            end_q = start_q + seq_len_q
            end_kv = start_kv + seq_len_kv
            # An explicit mask already encodes causality, so is_causal is
            # only used when no mask is given.
            sub_out = scaled_dot_product_attention(
                query[None, :, start_q:end_q, :],
                key[None, :, start_kv:end_kv, :],
                value[None, :, start_kv:end_kv, :],
                attn_mask=mask,
                dropout_p=0.0,
                is_causal=causal_attn and mask is None,
                scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0)
            output[start_q:end_q, :, :] = sub_out
            start_q, start_kv = end_q, end_kv

646
647
648
649

def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[torch.Tensor]:
    """Build a causal ALiBi attention bias for every sequence.

    Args:
        alibi_slopes: 1-D tensor holding one slope per attention head.
        dtype: dtype of the returned biases (callers pass the query dtype).
        seq_lens: length of each sequence; one bias is built per entry.

    Returns:
        One tensor of shape ``[1, num_heads, seq_len, seq_len]`` per
        sequence, created on ``alibi_slopes.device``. Entry ``[0, h, i, j]``
        is ``alibi_slopes[h] * (j - i)`` for ``j <= i`` and ``-inf`` for
        ``j > i`` (future positions are masked out).
    """
    num_heads = alibi_slopes.shape[0]
    device = alibi_slopes.device
    # Compute in at least float32 and cast to ``dtype`` once at the end.
    # Building positions directly in a low-precision dtype would round the
    # relative distances before scaling; bfloat16, for example, represents
    # integers exactly only up to 256.
    compute_dtype = torch.promote_types(dtype, torch.float32)
    slopes = alibi_slopes.to(compute_dtype)[:, None, None]  # [H, 1, 1]

    attn_biases: List[torch.Tensor] = []
    for seq_len in seq_lens:
        positions = torch.arange(seq_len, dtype=compute_dtype, device=device)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        distance = positions[None, :] - positions[:, None]  # [S, S]
        bias = (slopes * distance).unsqueeze(0)  # [1, H, S, S]
        assert bias.shape[1] == num_heads
        # -inf strictly above the diagonal: no attention to future tokens.
        causal_mask = torch.full((1, seq_len, seq_len),
                                 float("-inf"),
                                 dtype=compute_dtype,
                                 device=device).triu_(diagonal=1)
        attn_biases.append((bias + causal_mask).to(dtype))

    return attn_biases


def _make_sliding_window_bias(
    seq_lens: List[int],
    window_size: Optional[int],
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    """Return one additive causal (optionally windowed) mask per sequence.

    Each mask has shape ``[1, seq_len, seq_len]`` and dtype ``dtype``.
    Entry ``[0, i, j]`` is ``0`` when query ``i`` may attend to key ``j``
    and ``-inf`` otherwise. Key ``j`` is visible when ``j <= i`` and, if
    ``window_size`` is given, also ``j > i - window_size``.
    """
    biases: List[torch.Tensor] = []
    for n in seq_lens:
        # Causal part: keep the lower triangle, diagonal included.
        keep = torch.ones((1, n, n), dtype=torch.bool).tril()
        if window_size is not None:
            # Window part: drop keys more than window_size - 1 steps back.
            keep = keep.triu(1 - window_size)
        blocked = torch.logical_not(keep)
        zeros = torch.zeros((1, n, n), dtype=dtype)
        biases.append(zeros.masked_fill(blocked, float("-inf")))
    return biases