xformers.py 15.8 KB
Newer Older
1
"""Attention layer with xFormers and PagedAttention."""
2
3
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
Woosuk Kwon's avatar
Woosuk Kwon committed
4
5

import torch
6
from xformers import ops as xops
7
8
from xformers.ops.fmha.attn_bias import (AttentionBias,
                                         BlockDiagonalCausalMask,
Woosuk Kwon's avatar
Woosuk Kwon committed
9
                                         LowerTriangularMaskWithTensorBias)
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
12
13
                                              AttentionMetadata,
                                              AttentionMetadataPerStage)
14
15
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
16
from vllm.logger import init_logger
Woosuk Kwon's avatar
Woosuk Kwon committed
17

18
# Module-level logger, named after this module.
logger = init_logger(__name__)
19

20
21
22

class XFormersBackend(AttentionBackend):
    """Attention backend that pairs xFormers kernels with PagedAttention.

    Every method is static: it either names/constructs the pieces of this
    backend or delegates KV cache bookkeeping to `PagedAttention`.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "xformers"

    @staticmethod
    def get_impl_cls() -> Type["XFormersImpl"]:
        """Return the attention implementation class of this backend."""
        return XFormersImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "XFormersMetadata":
        """Build an `XFormersMetadata`, forwarding all arguments as-is."""
        return XFormersMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV cache shape as computed by `PagedAttention`."""
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Delegate a block swap between two KV caches to PagedAttention."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Delegate a block copy over `kv_caches` to PagedAttention."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
    """Metadata for XFormersBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Maximum query length in the batch.
    max_query_len: Optional[int]
    # FIXME: It is for flash attn.
    # Maximum sequence length in the batch.
    max_seq_len: Optional[int]
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    subquery_start_loc: Optional[torch.Tensor]
    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    def __post_init__(self):
        # Set during the execution of the first attention op.
        # It is a list because it needs to be set per prompt when alibi
        # slopes are used, due to a limitation of the xformers API.
        # Because it is assigned here rather than declared as a field, it
        # does not appear in the generated __repr__ and __init__.
        self.attn_bias: Optional[List[AttentionBias]] = None


class XFormersImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
    ) -> None:
        """Store the attention configuration and validate it.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head.
            scale: Attention scale; stored as a Python float.
            num_kv_heads: Number of KV heads. Defaults to `num_heads`.
            alibi_slopes: Optional per-head ALiBi slopes; stored as a
                float32 tensor when given.
            sliding_window: Optional local attention window size.
            kv_cache_dtype: dtype tag forwarded to the PagedAttention ops.

        Raises:
            AssertionError: If `num_heads` is not a multiple of the number
                of KV heads.
            ValueError: If `head_size` is not supported by PagedAttention.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata[XFormersMetadata],
        kv_scale: float = 1.0,
    ) -> torch.Tensor:
        """Forward pass with xFormers and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: shape =
                [2, num_blocks, block_size * num_kv_heads * head_size],
                or None (e.g. during the initial memory profiling run).
            attn_metadata: Metadata for attention.
            kv_scale: Scale forwarded to the paged cache write and to the
                decode kernel.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            ValueError: If decode tokens are present but `kv_cache` is None.
        """
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache,
                                                attn_metadata.slot_mapping,
                                                self.kv_cache_dtype, kv_scale)

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                # normal attention.
                # block tables are empty if the prompt does not have a cached
                # prefix.
                out = self._run_memory_efficient_xformers_forward(
                    query, key, value, prefill_meta)
                assert out.shape == output[:num_prefill_tokens].shape
                output[:num_prefill_tokens] = out
            else:
                # prefix-enabled attention
                # TODO(Hai) this triton kernel has regression issue (broke) to
                # deal with different data types between KV and FP8 KV cache,
                # to be addressed separately.
                out = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.subquery_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.context_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window,
                )
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out

        if decode_meta := attn_metadata.decode_metadata:
            if kv_cache is None:
                # key_cache / value_cache are only bound when a KV cache was
                # passed in; fail with a clear message instead of an
                # UnboundLocalError below.
                raise ValueError(
                    "Decoding requires a KV cache, but kv_cache is None.")
            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.seq_lens_tensor,
                decode_meta.max_seq_len,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

    def _run_memory_efficient_xformers_forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: XFormersMetadata,
    ) -> torch.Tensor:
        """Attention for 1D query of multiple prompts. Multiple prompt
        tokens are flattened in to `query` input.

        See https://facebookresearch.github.io/xformers/components/ops.html
        for API spec.

        Args:
            query: shape = [num_prefill_tokens, num_heads, head_size]
            key: shape = [num_prefill_tokens, num_kv_heads, head_size]
            value: shape = [num_prefill_tokens, num_kv_heads, head_size]
            attn_metadata: Metadata for attention. Its `attn_bias` is
                populated here when it is still None.
        Returns:
            A tensor with the same shape as the incoming `query`.
        """
        assert attn_metadata.seq_lens is not None
        original_query = query
        if self.num_kv_heads != self.num_heads:
            # GQA/MQA requires the shape [B, M, G, H, K].
            # Note that the output also has the same shape (which is different
            # from a spec from the doc).
            query = query.view(query.shape[0], self.num_kv_heads,
                               self.num_queries_per_kv, query.shape[-1])
            key = key[:, :,
                      None, :].expand(key.shape[0], self.num_kv_heads,
                                      self.num_queries_per_kv, key.shape[-1])
            value = value[:, :,
                          None, :].expand(value.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          value.shape[-1])

        # Set attention bias if not provided. This typically happens at
        # the very first attention layer of every iteration.
        # FIXME(woosuk): This is a hack.
        if attn_metadata.attn_bias is None:
            if self.alibi_slopes is None:
                attn_bias = BlockDiagonalCausalMask.from_seqlens(
                    attn_metadata.seq_lens)
                if self.sliding_window is not None:
                    attn_bias = attn_bias.make_local_attention(
                        self.sliding_window)
                attn_metadata.attn_bias = [attn_bias]
            else:
                attn_metadata.attn_bias = _make_alibi_bias(
                    self.alibi_slopes, self.num_kv_heads, query.dtype,
                    attn_metadata.seq_lens)

        # No alibi slopes.
        # TODO(woosuk): Too many view operations. Let's try to reduce
        # them in the future for code readability.
        if self.alibi_slopes is None:
            # Add the batch dimension.
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
            out = xops.memory_efficient_attention_forward(
                query,
                key,
                value,
                attn_bias=attn_metadata.attn_bias[0],
                p=0.0,
                scale=self.scale)
            return out.view_as(original_query)

        # Attention with alibi slopes.
        # FIXME(woosuk): Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short prompts.
        output = torch.empty_like(original_query)
        start = 0
        for i, seq_len in enumerate(attn_metadata.seq_lens):
            end = start + seq_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=attn_metadata.attn_bias[i],
                p=0.0,
                scale=self.scale)
            # TODO(woosuk): Unnecessary copy. Optimize.
            output[start:end].copy_(out.view_as(original_query[start:end]))
            start += seq_len
        return output


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[AttentionBias]:
    """Build one ALiBi attention bias per prompt.

    Args:
        alibi_slopes: Per-head slopes; the size of its first dimension is
            used as the number of query heads.
        num_kv_heads: Number of KV heads. When it differs from the number
            of query heads, the head dimension of each bias is split into
            (num_kv_heads, num_heads // num_kv_heads).
        dtype: dtype of the bias tensors.
        seq_lens: Length of each prompt.

    Returns:
        A list with one `LowerTriangularMaskWithTensorBias` per entry of
        `seq_lens`, in the same order.
    """
    num_heads = alibi_slopes.shape[0]
    attn_biases: List[AttentionBias] = []
    for seq_len in seq_lens:
        bias = torch.arange(seq_len, dtype=dtype)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        # Relative-position matrix: entry (i, j) is j - i.
        bias = bias[None, :] - bias[:, None]

        # Allocate the last dimension rounded up to a multiple of 8 and
        # keep only the first seq_len columns as a view.
        # NOTE(review): presumably an xformers alignment requirement for
        # tensor biases -- TODO confirm.
        padded_len = (seq_len + 7) // 8 * 8
        bias = torch.empty(
            1,  # batch size
            num_heads,
            seq_len,
            padded_len,
            device=alibi_slopes.device,
            dtype=dtype,
        )[:, :, :, :seq_len].copy_(bias)
        bias.mul_(alibi_slopes[:, None, None])
        if num_heads != num_kv_heads:
            bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
        attn_biases.append(LowerTriangularMaskWithTensorBias(bias))

    return attn_biases