# torch_sdpa.py
""" Attention layer with torch scaled_dot_product_attention
    and PagedAttention."""
from dataclasses import dataclass
4
from typing import Any, Dict, List, Optional, Tuple, Type
5
6
7
8
9

import torch
from torch.nn.functional import scaled_dot_product_attention

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
10
                                              AttentionMetadata, AttentionType)
11
from vllm.attention.backends.utils import CommonAttentionState
12
13
14
15
16
17
18
19
20
21
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.utils import is_cpu

if is_cpu():
    try:
        from vllm.attention.ops.ipex_attn import PagedAttention
    except ImportError:
        from vllm.attention.ops.paged_attn import PagedAttention
else:
    from vllm.attention.ops.paged_attn import PagedAttention
22
23
24
25


class TorchSDPABackend(AttentionBackend):
    """Backend descriptor for the torch SDPA attention implementation.

    Exposes the implementation, metadata and state classes used by this
    backend, and forwards every KV-cache layout / block operation to
    ``PagedAttention``.
    """

    @staticmethod
    def get_name() -> str:
        """Registry name of this backend."""
        return "torch-sdpa"

    @staticmethod
    def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
        """Class that performs the attention computation."""
        return TorchSDPABackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Class holding per-batch attention metadata."""
        return TorchSDPAMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Class holding backend state (the shared common implementation)."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Shape of the KV cache tensor; the layout is owned by PagedAttention."""
        cache_shape = PagedAttention.get_kv_cache_shape(
            num_blocks, block_size, num_kv_heads, head_size)
        return cache_shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Move cache blocks between two caches as described by ``src_to_dst``."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy cache blocks within ``kv_caches`` as described by ``src_to_dists``."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Per-batch attention metadata for TorchSDPABackend.

    Chunked prefill is not supported, so a batch is either made up entirely
    of prompts or entirely of decode steps.
    """
    # True when every sequence in the batch is a prompt; False when every
    # sequence is decoding (mixed batches are not supported).
    is_prompt: bool
    slot_mapping: torch.Tensor
    seq_lens: Optional[List[int]]

    def __post_init__(self):
        # Per-prompt attention masks. Left unset here and filled in by the
        # first attention call that sees this metadata, so later calls can
        # reuse them. One entry per prompt because ALiBi masks are
        # prompt-specific (originally attributed to an xformers API
        # limitation); entries are None when no explicit mask is needed.
        # Assigned here instead of declared as a field so that it is kept
        # out of the generated __init__ and __repr__.
        self.attn_bias: Optional[List[Optional[torch.Tensor]]] = None

    @property
    def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
        """Return ``self`` for a pure prefill batch, otherwise ``None``."""
        # Chunked prefill is unsupported: any decode token makes this a
        # decode batch.
        if self.num_decode_tokens != 0:
            return None
        assert self.num_prefills > 0
        return self

    @property
    def decode_metadata(self) -> Optional["TorchSDPAMetadata"]:
        """Return ``self`` for a pure decode batch, otherwise ``None``."""
        if not (self.num_prefills > 0):
            return self
        # Chunked prefill is unsupported: a prefill batch must not also
        # carry decode tokens.
        assert self.num_decode_tokens == 0
        return None


class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
    """Attention implementation for TorchSDPABackend.

    Prompt batches run ``scaled_dot_product_attention`` separately for each
    prompt. Decode batches run ``PagedAttention.forward_decode`` over the
    paged KV cache.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        """Validate the layer configuration and store its attention settings.

        Raises:
            ValueError: if block-sparse attention or a logits soft cap is
                requested, or if ``head_size`` is not supported by
                PagedAttention.
            NotImplementedError: if ``kv_cache_dtype`` is not ``"auto"``.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "Torch SDPA does not support block-sparse attention.")
        if logits_soft_cap is not None:
            raise ValueError("Torch SDPA does not support logits soft cap.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        # Each KV head is shared by this many query heads (GQA/MQA).
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        # ALiBi and sliding-window attention use explicit additive masks;
        # otherwise SDPA's built-in causal masking is used.
        self.need_mask = (self.alibi_slopes is not None
                          or self.sliding_window is not None)

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")
        if kv_cache_dtype != "auto":
            raise NotImplementedError(
                "Torch SDPA backend does not support FP8 KV cache. "
                "Please use xFormers backend instead.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: TorchSDPAMetadata,  # type: ignore
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with torch SDPA and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
            k_scale, v_scale: must both be 1.0 (scaled KV cache unsupported).
            attn_type: only ``AttentionType.DECODER`` is supported.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            NotImplementedError: for encoder or encoder/decoder attention.
            RuntimeError: for prefix-cached prompts, or for a decode batch
                without a KV cache.
        """
        assert k_scale == 1.0 and v_scale == 1.0
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TorchSDPABackendImpl")

        num_tokens = query.shape[0]
        # Split the packed hidden dimension into heads.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache,
                                                attn_metadata.slot_mapping,
                                                self.kv_cache_dtype, k_scale,
                                                v_scale)

        if attn_metadata.is_prompt:
            assert attn_metadata.seq_lens is not None
            if (kv_cache is not None
                    and attn_metadata.block_tables.numel() != 0):
                # prefix-enabled attention
                raise RuntimeError(
                    "Torch SDPA backend doesn't support prefix decoding.")
            output = self._forward_prefill(query, key, value, num_tokens,
                                           attn_metadata)
        else:
            # Decoding run: key_cache/value_cache only exist when a cache
            # was supplied above.
            if kv_cache is None:
                raise RuntimeError(
                    "Torch SDPA backend requires a KV cache for decoding.")
            output = PagedAttention.forward_decode(
                query,
                key_cache,
                value_cache,
                attn_metadata.block_tables,
                attn_metadata.seq_lens_tensor,
                attn_metadata.max_decode_seq_len,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                k_scale,
                v_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

    def _prefill_attn_masks(
        self,
        attn_metadata: TorchSDPAMetadata,
        dtype: torch.dtype,
    ) -> List[Optional[torch.Tensor]]:
        """Return per-prompt masks, building and caching them on first use."""
        if attn_metadata.attn_bias is None:
            if self.alibi_slopes is not None:
                masks = _make_alibi_bias(
                    self.alibi_slopes, dtype,
                    attn_metadata.seq_lens)  # type: ignore
            elif self.sliding_window is not None:
                masks = _make_sliding_window_bias(
                    attn_metadata.seq_lens, self.sliding_window,
                    dtype)  # type: ignore
            else:
                masks = [None] * len(attn_metadata.seq_lens)
            attn_metadata.attn_bias = masks
        return attn_metadata.attn_bias

    def _forward_prefill(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        num_tokens: int,
        attn_metadata: TorchSDPAMetadata,
    ) -> torch.Tensor:
        """Run SDPA independently over each prompt packed in the batch.

        ``query`` is [num_tokens, num_heads, head_size]; ``key``/``value`` are
        [num_tokens, num_kv_heads, head_size]. Returns a tensor of shape
        [num_tokens, num_heads, head_size].
        """
        if self.num_kv_heads != self.num_heads:
            # Expand KV heads so key/value carry one head per query head.
            key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
            value = value.repeat_interleave(self.num_queries_per_kv, dim=1)

        masks = self._prefill_attn_masks(attn_metadata, query.dtype)

        # [num_tokens, heads, head_size] -> [heads, num_tokens, head_size]
        query = query.movedim(0, query.dim() - 2)
        key = key.movedim(0, key.dim() - 2)
        value = value.movedim(0, value.dim() - 2)

        # Allocate on the inputs' device rather than the default device.
        output = torch.empty((num_tokens, self.num_heads, self.head_size),
                             dtype=query.dtype,
                             device=query.device)
        start = 0
        for seq_len, mask in zip(attn_metadata.seq_lens, masks):
            end = start + seq_len
            sub_out = scaled_dot_product_attention(
                query[None, :, start:end, :],
                key[None, :, start:end, :],
                value[None, :, start:end, :],
                attn_mask=mask,
                dropout_p=0.0,
                # The explicit ALiBi / sliding-window masks already block
                # future positions, so causal mode is only used without them.
                is_causal=not self.need_mask,
                scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0)
            output[start:end, :, :] = sub_out
            start = end
        return output


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    dtype: torch.dtype,
264
    seq_lens: List[int],
265
) -> List[torch.Tensor]:
266
    attn_biases: List[torch.Tensor] = []
267
268
    for seq_len in seq_lens:
        bias = torch.arange(seq_len, dtype=dtype)
269
        # NOTE(zhuohan): HF uses
270
        #     `bias = bias[None, :].repeat(seq_len, 1)`
271
272
273
274
275
276
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        bias = bias[None, :] - bias[:, None]

        num_heads = alibi_slopes.shape[0]
277
        bias = bias[None, :].repeat((num_heads, 1, 1))
278
        bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
279
        inf_mask = torch.empty(
280
            (1, seq_len, seq_len),
281
282
283
284
285
286
287
            dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
        attn_biases.append((bias + inf_mask).to(dtype))

    return attn_biases


def _make_sliding_window_bias(
288
    seq_lens: List[int],
289
290
291
    window_size: Optional[int],
    dtype: torch.dtype,
) -> List[torch.Tensor]:
292
    attn_biases: List[torch.Tensor] = []
293
    for seq_len in seq_lens:
294
        tensor = torch.full(
295
            (1, seq_len, seq_len),
296
297
298
299
300
301
302
303
304
305
306
            dtype=dtype,
            fill_value=1,
        )
        shift = 0
        mask = torch.tril(tensor, diagonal=shift).to(dtype)  # type: ignore
        if window_size is not None:
            mask = torch.triu(mask, diagonal=shift - window_size + 1)
        mask = torch.log(mask)
        attn_biases.append(mask.to(dtype))

    return attn_biases