xformers.py 17.6 KB
Newer Older
1
2
"""Attention layer with xFormers and PagedAttention."""
import importlib
3
4
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6

import torch
7
from xformers import ops as xops
8
9
from xformers.ops.fmha.attn_bias import (AttentionBias,
                                         BlockDiagonalCausalMask,
Woosuk Kwon's avatar
Woosuk Kwon committed
10
                                         LowerTriangularMaskWithTensorBias)
Woosuk Kwon's avatar
Woosuk Kwon committed
11

12
13
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata)
14
15
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
16
from vllm.logger import init_logger
17
from vllm.utils import is_hip
Woosuk Kwon's avatar
Woosuk Kwon committed
18

19
logger = init_logger(__name__)
20

21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125

class XFormersBackend(AttentionBackend):
    """Backend descriptor for the xFormers attention path.

    The implementation class is `XFormersImpl`, the metadata class is
    `XFormersMetadata`, and every KV-cache operation is forwarded unchanged
    to `PagedAttention`.
    """

    @staticmethod
    def get_impl_cls() -> Type["XFormersImpl"]:
        """Return the attention implementation class of this backend."""
        return XFormersImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "XFormersMetadata":
        """Construct an `XFormersMetadata` from the given arguments."""
        metadata = XFormersMetadata(*args, **kwargs)
        return metadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape as computed by `PagedAttention`."""
        cache_shape = PagedAttention.get_kv_cache_shape(
            num_blocks,
            block_size,
            num_kv_heads,
            head_size,
        )
        return cache_shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Forward a block swap between two KV caches to `PagedAttention`."""
        PagedAttention.swap_blocks(
            src_kv_cache,
            dst_kv_cache,
            src_to_dst,
        )

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        """Forward a block copy within the KV caches to `PagedAttention`."""
        PagedAttention.copy_blocks(
            kv_caches,
            src_to_dists,
        )


@dataclass
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for the XFormers backend.

    NOTE: Any Python object stored here is not updated when a CUDA graph is
    replayed. Values that need to change dynamically must be stored in a
    tensor, and that tensor has to be updated from the
    `CUDAGraphRunner.forward` API.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
    # is 16, the three tokens are stored at offset 3 of block 2, offset 2 of
    # block 0, and offset 1 of block 1, respectively (0-indexed offsets).
    slot_mapping: torch.Tensor
    # (batch_size,). The prompt length per sequence. None if it is a decoding.
    prompt_lens: Optional[List[int]]
    # prompt_lens stored as a tensor.
    prompt_lens_tensor: Optional[torch.Tensor]
    # The number of prompt tokens. Doesn't include padding.
    num_prompt_tokens: int
    # The number of generation tokens. Doesn't include padding.
    num_generation_tokens: int

    # NOTE(sang): Definition of context_len, subquery_len, and seqlen.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seqlen ----------------------|
    #                                   |- subquery_len -|

    # WARNING(sang): context_len has a different definition depending on
    # whether it is prefill or decoding. For prefill, it doesn't include the
    # new tokens. For decoding, it includes the new token.

    # Maximum subquery length in the batch.
    max_subquery_len: Optional[int]
    # FIXME: It is for flash attn.
    # Maximum prompt length in the batch.
    max_prompt_len: Optional[int]
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery lengths
    # are [4, 6], it is [0, 4, 10].
    subquery_start_loc: Optional[torch.Tensor]
    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence lengths
    # are [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]

    # Whether CUDA graph is enabled.
    # CUDA graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    def __post_init__(self) -> None:
        # `attn_bias` is set during the execution of the first attention op.
        # It is a list because, when ALiBi slopes are used, one bias has to
        # be built per prompt (a limitation of the xFormers API).
        # Because it is assigned here rather than declared as a dataclass
        # field, it does not appear in the generated __repr__ or __init__.
        self.attn_bias: Optional[List[AttentionBias]] = None


class XFormersImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens --------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        """Store the attention configuration and validate the head size.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Softmax scaling factor applied to the attention scores.
            num_kv_heads: Number of key/value heads. Defaults to `num_heads`.
                `num_heads` must be divisible by it.
            alibi_slopes: Optional per-head ALiBi slopes. Stored as a
                float32 tensor when given.
            sliding_window: Optional local-attention window size.

        Raises:
            ValueError: If `head_size` is not supported by PagedAttention.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        # AMD Radeon 7900 series (gfx1100) currently does not support xFormers
        # nor FlashAttention. As a temporary workaround, we use naive PyTorch
        # implementation of attention.
        self.use_naive_attention = _check_use_naive_attention()

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: XFormersMetadata,
    ) -> torch.Tensor:
        """Forward pass with xFormers and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        num_tokens, hidden_size = query.shape
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache,
                                                attn_metadata.slot_mapping,
                                                attn_metadata.kv_cache_dtype)

        if attn_metadata.is_prompt:
            # Prompt run.
            if kv_cache is None or attn_metadata.block_tables.numel() == 0:
                # normal attention.
                # block tables are empty if the prompt does not have a cached
                # prefix.
                if self.num_kv_heads != self.num_heads:
                    # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
                    # project the key and value tensors to the desired number of
                    # heads.
                    # TODO(woosuk): Use MQA/GQA kernels for higher performance.
                    query = query.view(query.shape[0], self.num_kv_heads,
                                       self.num_queries_per_kv,
                                       query.shape[-1])
                    key = key[:, :,
                              None, :].expand(key.shape[0], self.num_kv_heads,
                                              self.num_queries_per_kv,
                                              key.shape[-1])
                    value = value[:, :,
                                  None, :].expand(value.shape[0],
                                                  self.num_kv_heads,
                                                  self.num_queries_per_kv,
                                                  value.shape[-1])

                if self.use_naive_attention:
                    # The naive kernel works on [tokens, heads, head_size]
                    # tensors with the same head count for q, k and v. Undo
                    # the MQA/GQA grouping applied above; `reshape`
                    # materializes the stride-0 expanded key/value, which
                    # `view` cannot do. Head h = kv_head * num_queries_per_kv
                    # + group for all three tensors, so they stay aligned.
                    # For plain MHA these reshapes are no-ops.
                    # NOTE: this path applies neither ALiBi slopes nor the
                    # sliding window.
                    query = query.reshape(-1, self.num_heads, self.head_size)
                    key = key.reshape(-1, self.num_heads, self.head_size)
                    value = value.reshape(-1, self.num_heads, self.head_size)

                    output = torch.empty_like(query)
                    start = 0
                    for prompt_len in attn_metadata.prompt_lens:
                        end = start + prompt_len
                        out = _naive_masked_attention(
                            query[start:end],
                            key[start:end],
                            value[start:end],
                            self.num_heads,
                            # k/v already carry one head per query head.
                            self.num_heads,
                            self.head_size,
                            self.scale,
                        )
                        # TODO(woosuk): Unnecessary copy. Optimize.
                        output[start:end].copy_(out)
                        start = end

                    return output.reshape(num_tokens, hidden_size)

                output = self._run_memory_efficient_xformers_forward(
                    query, key, value, attn_metadata)
            else:
                # prefix-enabled attention
                output = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.block_tables,
                    attn_metadata.subquery_start_loc,
                    attn_metadata.prompt_lens_tensor,
                    attn_metadata.context_lens,
                    attn_metadata.max_subquery_len,
                    self.alibi_slopes,
                )
        else:
            # Decoding run. This reads `key_cache`/`value_cache`, so it
            # requires `kv_cache` to be provided.
            output = PagedAttention.forward_decode(
                query,
                key_cache,
                value_cache,
                attn_metadata.block_tables,
                attn_metadata.context_lens,
                attn_metadata.max_context_len,
                attn_metadata.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

    def _run_memory_efficient_xformers_forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: XFormersMetadata,
    ) -> torch.Tensor:
        """Attention for 1D query of multiple prompts. Multiple prompt
        tokens are flattened in to `query` input.

        Args:
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.

        For MQA/GQA, `forward` passes all three tensors in the grouped
        layout [num_prompt_tokens, num_kv_heads, num_queries_per_kv,
        head_size] instead.

        Returns:
            The attention output. Without ALiBi it has the shape of `query`
            with a leading batch dimension of 1; with ALiBi it has the shape
            of `query`.
        """
        # Set attention bias if not provided. This typically happens at
        # the very first attention layer of every iteration; the bias is
        # stored on the metadata so that later calls reuse it.
        # FIXME(woosuk): This is a hack.
        if attn_metadata.attn_bias is None:
            if self.alibi_slopes is None:
                attn_bias = BlockDiagonalCausalMask.from_seqlens(
                    attn_metadata.prompt_lens)
                if self.sliding_window is not None:
                    attn_bias = attn_bias.make_local_attention(
                        self.sliding_window)
                attn_metadata.attn_bias = [attn_bias]
            else:
                # NOTE: the sliding window is not applied on the ALiBi path.
                attn_metadata.attn_bias = _make_alibi_bias(
                    self.alibi_slopes, self.num_kv_heads, query.dtype,
                    attn_metadata.prompt_lens)

        # On ROCm (HIP) the FlashAttention forward op is selected explicitly;
        # otherwise xFormers picks the op itself.
        op = xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if (
            is_hip()) else None
        # No alibi slopes.
        # TODO(woosuk): Too many view operations. Let's try to reduce
        # them in the future for code readability.
        if self.alibi_slopes is None:
            # Add a batch dimension of 1; the block-diagonal mask separates
            # the individual prompts.
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
            out = xops.memory_efficient_attention_forward(
                query,
                key,
                value,
                attn_bias=attn_metadata.attn_bias[0],
                p=0.0,
                scale=self.scale,
                op=op)

            return out.view_as(query)

        # Attention with alibi slopes.
        # FIXME(woosuk): Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short prompts.
        output = torch.empty_like(query)
        start = 0
        for i, prompt_len in enumerate(attn_metadata.prompt_lens):
            end = start + prompt_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=attn_metadata.attn_bias[i],
                p=0.0,
                scale=self.scale,
                op=op)
            # TODO(woosuk): Unnecessary copy. Optimize.
            output[start:end].copy_(out.squeeze(0))
            start = end
        return output
Woosuk Kwon's avatar
Woosuk Kwon committed
366
367
368
369


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    prompt_lens: List[int],
) -> List[LowerTriangularMaskWithTensorBias]:
    """Build one causal ALiBi attention bias per prompt.

    Args:
        alibi_slopes: Per-head slopes; its first dimension is the number of
            query heads. The biases are created on its device.
        num_kv_heads: Number of key/value heads. When it differs from the
            number of query heads, the head dimension of each bias is split
            into (num_kv_heads, num_heads // num_kv_heads).
        dtype: Data type of the bias tensors.
        prompt_lens: Length of every prompt in the batch.

    Returns:
        A list with one `LowerTriangularMaskWithTensorBias` per entry of
        `prompt_lens`, in the same order.
    """
    # The head count does not depend on the prompt; compute it once.
    num_heads = alibi_slopes.shape[0]

    attn_biases = []
    for prompt_len in prompt_lens:
        bias = torch.arange(prompt_len, dtype=dtype)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(prompt_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        # Matrix of relative positions: element [i, j] is j - i.
        bias = bias[None, :] - bias[:, None]

        # Allocate the last dimension rounded up to a multiple of 8 and keep
        # only the first `prompt_len` columns, so the returned slice sits on
        # padded storage (presumably an xFormers alignment requirement --
        # TODO confirm).
        padded_len = (prompt_len + 7) // 8 * 8
        bias = torch.empty(
            1,  # batch size
            num_heads,
            prompt_len,
            padded_len,
            device=alibi_slopes.device,
            dtype=dtype,
        )[:, :, :, :prompt_len].copy_(bias)
        # Scale the relative positions by each head's slope.
        bias.mul_(alibi_slopes[:, None, None])
        if num_heads != num_kv_heads:
            bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
        attn_biases.append(LowerTriangularMaskWithTensorBias(bias))

    return attn_biases
Woosuk Kwon's avatar
Woosuk Kwon committed
402
403


404
def _check_use_naive_attention() -> bool:
    """Decide whether the naive PyTorch attention fallback must be used.

    Returns:
        False when not running on ROCm (HIP). On ROCm, True if the
        `flash_attn` package cannot be found, in which case a warning is
        logged; False otherwise.
    """
    if not is_hip():
        return False

    # `import importlib` alone does not guarantee that the `util` submodule
    # is loaded, so import it explicitly before calling `find_spec`.
    import importlib.util

    # For ROCm, check whether flash attention is installed or not.
    use_naive_attention = importlib.util.find_spec("flash_attn") is None
    if use_naive_attention:
        logger.warning("flash_attn is not installed. Using naive attention. "
                       "This will take significantly more GPU memory.")
    return use_naive_attention
414
415


416
def _naive_masked_attention(
Woosuk Kwon's avatar
Woosuk Kwon committed
417
    query: torch.Tensor,
418
419
420
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
421
    num_kv_heads: int,
422
    head_size: int,
Woosuk Kwon's avatar
Woosuk Kwon committed
423
424
    scale: float,
) -> torch.Tensor:
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
    query = query.view(-1, num_heads, head_size)
    key = key.view(-1, num_kv_heads, head_size)
    value = value.view(-1, num_kv_heads, head_size)
    seq_len, _, _ = query.shape
    attn_mask = torch.triu(torch.ones(seq_len,
                                      seq_len,
                                      dtype=query.dtype,
                                      device=query.device),
                           diagonal=1)
    attn_mask = attn_mask * torch.finfo(query.dtype).min

    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
    return out