xformers.py 18.4 KB
Newer Older
1
"""Attention layer with xFormers and PagedAttention."""
2
from dataclasses import dataclass
3
from typing import Any, Dict, List, Optional, Tuple, Type
Woosuk Kwon's avatar
Woosuk Kwon committed
4
5

import torch
6
from xformers import ops as xops
7
8
from xformers.ops.fmha.attn_bias import (AttentionBias,
                                         BlockDiagonalCausalMask,
Woosuk Kwon's avatar
Woosuk Kwon committed
9
                                         LowerTriangularMaskWithTensorBias)
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
12
                                              AttentionMetadata)
13
14
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
15
from vllm.logger import init_logger
Woosuk Kwon's avatar
Woosuk Kwon committed
16

17
# Module-level logger, named after this module.
logger = init_logger(__name__)
18

19
20
21

class XFormersBackend(AttentionBackend):
    """Attention backend that pairs xFormers with PagedAttention.

    All methods are static. The KV-cache helpers forward directly to the
    corresponding `PagedAttention` static methods.
    """

    @staticmethod
    def get_name() -> str:
        """Return the name identifying this backend."""
        return "xformers"

    @staticmethod
    def get_impl_cls() -> Type["XFormersImpl"]:
        """Return the attention implementation class of this backend."""
        return XFormersImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "XFormersMetadata":
        """Construct an `XFormersMetadata` from the given arguments."""
        metadata = XFormersMetadata(*args, **kwargs)
        return metadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV cache shape as computed by `PagedAttention`."""
        cache_shape = PagedAttention.get_kv_cache_shape(
            num_blocks, block_size, num_kv_heads, head_size)
        return cache_shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Delegate a block swap between two KV caches to `PagedAttention`."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Delegate a block copy within the KV caches to `PagedAttention`."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for the xFormers attention backend.

    NOTE: Any Python object stored here is not updated when it is
    cuda-graph replayed. Values that need to change dynamically must be
    stored in a tensor, and that tensor has to be updated from the
    `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int]
    # FIXME: It is for flash attn.
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,). A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool
    # Lazily built views returned by `prefill_metadata` / `decode_metadata`.
    _cached_prefill_metadata: Optional["XFormersMetadata"] = None
    _cached_decode_metadata: Optional["XFormersMetadata"] = None

    def __post_init__(self):
        """Initialize `attn_bias`, which is not a dataclass field."""
        # Set during the execution of the first attention op.
        # It is a list because it has to be set per prompt when alibi
        # slopes are used, due to a limitation of the xformers API.
        # Being assigned here, it does not appear in __repr__ or __init__.
        self.attn_bias: Optional[List[AttentionBias]] = None

    @property
    def prefill_metadata(self) -> Optional["XFormersMetadata"]:
        """Metadata restricted to the prefill requests of this batch.

        Returns None when the batch has no prefills. The result is built
        on first access and cached on this instance.
        """
        num_prefills = self.num_prefills
        if num_prefills == 0:
            return None

        cached = self._cached_prefill_metadata
        if cached is not None:
            return cached

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.block_tables is not None

        num_prefill_tokens = self.num_prefill_tokens
        # Prefill entries come first, so they are selected by leading slices.
        prefill_seqs = slice(None, num_prefills)
        prefill_tokens = slice(None, num_prefill_tokens)
        prefill_locs = slice(None, num_prefills + 1)

        cached = XFormersMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[prefill_tokens],
            seq_lens=self.seq_lens[prefill_seqs],
            seq_lens_tensor=self.seq_lens_tensor[prefill_seqs],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[prefill_locs],
            seq_start_loc=None,
            context_lens_tensor=self.context_lens_tensor[prefill_seqs],
            block_tables=self.block_tables[prefill_seqs],
            use_cuda_graph=False,
        )
        self._cached_prefill_metadata = cached
        return cached

    @property
    def decode_metadata(self) -> Optional["XFormersMetadata"]:
        """Metadata restricted to the decode requests of this batch.

        Returns None when the batch has no decode tokens. The result is
        built on first access and cached on this instance.
        """
        num_decode_tokens = self.num_decode_tokens
        if num_decode_tokens == 0:
            return None

        cached = self._cached_decode_metadata
        if cached is not None:
            return cached

        assert self.block_tables is not None
        assert self.seq_lens_tensor is not None

        # Decode entries follow the prefill entries.
        decode_seqs = slice(self.num_prefills, None)
        decode_tokens = slice(self.num_prefill_tokens, None)

        cached = XFormersMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=num_decode_tokens,
            slot_mapping=self.slot_mapping[decode_tokens],
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[decode_seqs],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self.block_tables[decode_seqs],
            use_cuda_graph=self.use_cuda_graph,
        )
        self._cached_decode_metadata = cached
        return cached


class XFormersImpl(AttentionImpl[XFormersMetadata]):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Configure the attention layer.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head; must be one of the head sizes
                reported by `PagedAttention.get_supported_head_sizes()`.
            scale: Scale passed to the xFormers and PagedAttention kernels.
            num_kv_heads: Number of key/value heads; must divide
                `num_heads`.
            alibi_slopes: Optional per-head ALiBi slopes. Stored as a
                float32 tensor when given.
            sliding_window: Optional sliding-window size for local
                attention.
            kv_cache_dtype: KV cache dtype name forwarded to PagedAttention.
            blocksparse_params: Must be None; block-sparse attention is not
                supported by this backend.

        Raises:
            ValueError: If `blocksparse_params` is given or `head_size` is
                not supported.
        """
        # Raise a real exception: an `assert` would be stripped under -O and
        # would surface as AssertionError rather than the intended ValueError.
        if blocksparse_params is not None:
            raise ValueError(
                "XFormer does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: "XFormersMetadata",
        kv_scale: float = 1.0,
    ) -> torch.Tensor:
        """Forward pass with xFormers and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                If None, the new keys and values are not cached; this
                happens during the initial memory profiling run.
            attn_metadata: Metadata for attention.
            kv_scale: Scaling factor forwarded to the paged-cache write and
                to the decode kernel.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            ValueError: If the batch contains decode tokens but `kv_cache`
                is None.
        """
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache,
                                                attn_metadata.slot_mapping,
                                                self.kv_cache_dtype, kv_scale)

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                # normal attention.
                # block tables are empty if the prompt does not have a cached
                # prefix.
                out = self._run_memory_efficient_xformers_forward(
                    query, key, value, prefill_meta)
                assert out.shape == output[:num_prefill_tokens].shape
                output[:num_prefill_tokens] = out
            else:
                # prefix-enabled attention
                # TODO(Hai) this triton kernel has regression issue (broke) to
                # deal with different data types between KV and FP8 KV cache,
                # to be addressed separately.
                out = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.query_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.context_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window,
                )
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding reads keys/values from the paged cache only, so fail
            # with a clear message instead of an UnboundLocalError on
            # `key_cache` when no cache was provided.
            if kv_cache is None:
                raise ValueError(
                    "kv_cache must be provided when the batch contains "
                    "decode tokens.")
            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.seq_lens_tensor,
                decode_meta.max_decode_seq_len,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

    def _run_memory_efficient_xformers_forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: XFormersMetadata,
    ) -> torch.Tensor:
        """Attention for 1D query of multiple prompts. Multiple prompt
        tokens are flattened in to `query` input.

        See https://facebookresearch.github.io/xformers/components/ops.html
        for API spec.

        Args:
            query: shape = [num_prefill_tokens, num_heads, head_size]
            key: shape = [num_prefill_tokens, num_kv_heads, head_size]
            value: shape = [num_prefill_tokens, num_kv_heads, head_size]
            attn_metadata: Metadata for attention. Its `seq_lens` must be
                set; its `attn_bias` is populated here if it is still None.
        Returns:
            shape = [num_prefill_tokens, num_heads, head_size]
        """
        assert attn_metadata.seq_lens is not None
        original_query = query
        if self.num_kv_heads != self.num_heads:
            # GQA/MQA requires the shape [B, M, G, H, K].
            # Note that the output also has the same shape (which is different
            # from a spec from the doc).
            query = query.view(query.shape[0], self.num_kv_heads,
                               self.num_queries_per_kv, query.shape[-1])
            key = key[:, :,
                      None, :].expand(key.shape[0], self.num_kv_heads,
                                      self.num_queries_per_kv, key.shape[-1])
            value = value[:, :,
                          None, :].expand(value.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          value.shape[-1])

        # Set attention bias if not provided. This typically happens at
        # the very first attention layer of every iteration.
        # FIXME(woosuk): This is a hack.
        if attn_metadata.attn_bias is None:
            if self.alibi_slopes is None:
                attn_bias = BlockDiagonalCausalMask.from_seqlens(
                    attn_metadata.seq_lens)
                if self.sliding_window is not None:
                    attn_bias = attn_bias.make_local_attention(
                        self.sliding_window)
                attn_metadata.attn_bias = [attn_bias]
            else:
                attn_metadata.attn_bias = _make_alibi_bias(
                    self.alibi_slopes, self.num_kv_heads, query.dtype,
                    attn_metadata.seq_lens)

        # No alibi slopes.
        # TODO(woosuk): Too many view operations. Let's try to reduce
        # them in the future for code readability.
        if self.alibi_slopes is None:
            # Add the batch dimension.
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
            out = xops.memory_efficient_attention_forward(
                query,
                key,
                value,
                attn_bias=attn_metadata.attn_bias[0],
                p=0.0,
                scale=self.scale)
            return out.view_as(original_query)

        # Attention with alibi slopes.
        # FIXME(woosuk): Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short prompts.
        output = torch.empty_like(original_query)
        start = 0
        for i, seq_len in enumerate(attn_metadata.seq_lens):
            end = start + seq_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=attn_metadata.attn_bias[i],
                p=0.0,
                scale=self.scale)
            # TODO(woosuk): Unnecessary copy. Optimize.
            output[start:end].copy_(out.view_as(original_query[start:end]))
            start = end
        return output
Woosuk Kwon's avatar
Woosuk Kwon committed
427
428
429
430


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[AttentionBias]:
    """Build one ALiBi attention bias per prompt.

    Args:
        alibi_slopes: Per-head slopes; its first dimension is the number of
            query heads and its device is used for the bias tensors.
        num_kv_heads: Number of key/value heads. When it differs from the
            number of query heads, the head dimension of each bias is split
            into (num_kv_heads, num_heads // num_kv_heads).
        dtype: Dtype of the returned bias tensors.
        seq_lens: Length of each prompt in the batch.

    Returns:
        A list with one `LowerTriangularMaskWithTensorBias` per entry of
        `seq_lens`, in the same order.
    """
    device = alibi_slopes.device
    num_heads = alibi_slopes.shape[0]
    attn_biases: List[AttentionBias] = []
    for seq_len in seq_lens:
        # Build the positions directly on the target device and in float32.
        # Creating them in `dtype` would round large indices for fp16/bf16
        # *before* the subtraction, corrupting small relative distances on
        # long prompts; float32 keeps the differences exact and they are
        # cast to `dtype` only when copied into the bias below.
        positions = torch.arange(seq_len, dtype=torch.float32, device=device)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        # Matrix whose element (i, j) is j - i.
        relative_pos = positions[None, :] - positions[:, None]

        # Allocate with the last dimension rounded up to a multiple of 8 and
        # slice back to seq_len (presumably to satisfy the xFormers alignment
        # requirement for tensor biases -- confirm against the xFormers docs).
        padded_len = (seq_len + 7) // 8 * 8
        bias = torch.empty(
            1,  # batch size
            num_heads,
            seq_len,
            padded_len,
            device=device,
            dtype=dtype,
        )[:, :, :, :seq_len].copy_(relative_pos)
        bias.mul_(alibi_slopes[:, None, None])
        if num_heads != num_kv_heads:
            bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
        attn_biases.append(LowerTriangularMaskWithTensorBias(bias))

    return attn_biases