torch_sdpa.py 9.64 KB
Newer Older
1
2
3
""" Attention layer with torch scaled_dot_product_attention
    and PagedAttention."""
from dataclasses import dataclass
4
from typing import List, Optional, Tuple, Type
5
6
7
8
9

import torch
from torch.nn.functional import scaled_dot_product_attention

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
10
11
                                              AttentionMetadata,
                                              AttentionMetadataPerStage)
12
13
14
15
16
17
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)


class TorchSDPABackend(AttentionBackend):
    """Attention backend using torch ``scaled_dot_product_attention``.

    Prompt attention is computed by ``TorchSDPABackendImpl`` with torch SDPA;
    KV-cache layout and block swap/copy operations are delegated to
    ``PagedAttention``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "torch-sdpa"

    @staticmethod
    def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
        """Return the attention implementation class for this backend."""
        return TorchSDPABackendImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata":
        """Construct a ``TorchSDPAMetadata`` from the given arguments."""
        return TorchSDPAMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV cache shape.

        Delegated to ``PagedAttention`` so the cache layout matches what the
        paged-attention kernels used for decoding expect.
        """
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Move cache blocks between two KV caches per ``src_to_dst``.

        Delegated to ``PagedAttention.swap_blocks``.
        """
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy cache blocks within ``kv_caches`` per ``src_to_dists``.

        Delegated to ``PagedAttention.copy_blocks``.
        """
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
                        AttentionMetadataPerStage):
    """Metadata for TorchSDPABackend.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    # Per-token cache slot indices; passed to
    # PagedAttention.write_to_paged_cache when new keys/values are stored.
    slot_mapping: torch.Tensor
    # Length of each sequence in the batch; must be set for prompt runs.
    seq_lens: Optional[List[int]]

    def __post_init__(self):
        # Per-sequence attention masks for prompt runs. Set during the
        # execution of the first attention op and reused afterwards. It holds
        # one entry per prompt: an additive bias tensor when ALiBi or a
        # sliding window is used, otherwise None (SDPA's built-in causal mask
        # is used instead).
        # Because it is not a dataclass field, it does not appear in the
        # generated __init__ or __repr__.
        self.attn_bias: Optional[List[Optional[torch.Tensor]]] = None


class TorchSDPABackendImpl(AttentionImpl):
    """Attention implementation for ``TorchSDPABackend``.

    Prompt runs compute attention independently for each sequence with
    ``torch.nn.functional.scaled_dot_product_attention``. Decode runs are
    delegated to ``PagedAttention.forward_decode``. Prompt runs against a
    KV cache with a non-empty block table (prefix caching) are rejected.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
    ) -> None:
        """Store and validate the attention configuration.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Softmax scale used by SDPA and the decode kernel.
            num_kv_heads: Number of key/value heads; defaults to
                ``num_heads``. Must divide ``num_heads``.
            alibi_slopes: Optional per-head ALiBi slopes.
            sliding_window: Optional sliding-window size for prompt masks.
            kv_cache_dtype: KV cache dtype; only ``"auto"`` is supported.

        Raises:
            ValueError: If ``head_size`` is not supported by PagedAttention.
            NotImplementedError: If ``kv_cache_dtype`` is not ``"auto"``.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        # Slopes are kept as a float32 tensor; they are used both to build
        # prompt biases and by the decode kernel.
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        # Grouped-query attention: each KV head serves this many query heads.
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        # An explicit additive mask is only needed for ALiBi or sliding
        # windows; otherwise SDPA's built-in causal masking is used.
        self.need_mask = (self.alibi_slopes is not None
                          or self.sliding_window is not None)

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")
        if kv_cache_dtype != "auto":
            raise NotImplementedError(
                "Torch SDPA backend does not support FP8 KV cache. "
                "Please use xFormers backend instead.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: TorchSDPAMetadata,  # type: ignore
        kv_scale: float = 1.0,
    ) -> torch.Tensor:
        """Forward pass with torch SDPA and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
            kv_scale: KV cache scale; only 1.0 is supported.
        Returns:
            shape = [num_tokens, num_heads * head_size]

        Raises:
            RuntimeError: For a prompt run with a KV cache and a non-empty
                block table (prefix caching is not supported).
            ValueError: For a decode run without a KV cache.
        """
        assert kv_scale == 1.0
        num_tokens, _ = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache,
                                                attn_metadata.slot_mapping,
                                                self.kv_cache_dtype, kv_scale)

        if attn_metadata.is_prompt:
            assert attn_metadata.seq_lens is not None
            if (kv_cache is not None
                    and attn_metadata.block_tables.numel() != 0):
                # prefix-enabled attention
                raise RuntimeError(
                    "Torch SDPA backend doesn't support prefix decoding.")
            output = self._forward_prompt(query, key, value, attn_metadata,
                                          num_tokens)
        else:
            # Decoding run: attention reads keys/values from the paged cache,
            # so the cache must have been provided (and split above).
            if kv_cache is None:
                raise ValueError(
                    "Torch SDPA backend requires a KV cache for decoding.")
            output = PagedAttention.forward_decode(
                query,
                key_cache,
                value_cache,
                attn_metadata.block_tables,
                attn_metadata.seq_lens_tensor,
                attn_metadata.max_seq_len,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

    def _forward_prompt(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: TorchSDPAMetadata,  # type: ignore
        num_tokens: int,
    ) -> torch.Tensor:
        """Run SDPA separately over each prompt packed in the batch.

        ``query`` is [num_tokens, num_heads, head_size] and ``key``/``value``
        are [num_tokens, num_kv_heads, head_size], with all prompts
        concatenated along the token dimension in ``attn_metadata.seq_lens``
        order. Returns a [num_tokens, num_heads, head_size] tensor.
        """
        if self.num_kv_heads != self.num_heads:
            # Expand KV heads so every query head has a matching KV head.
            key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
            value = value.repeat_interleave(self.num_queries_per_kv, dim=1)

        # The masks depend only on seq_lens and dtype, so they are built once
        # and cached on the metadata for later attention ops.
        if attn_metadata.attn_bias is None:
            seq_lens = attn_metadata.seq_lens
            if self.alibi_slopes is not None:
                att_masks = _make_alibi_bias(
                    self.alibi_slopes, query.dtype,
                    seq_lens)  # type: ignore
            elif self.sliding_window is not None:
                att_masks = _make_sliding_window_bias(
                    seq_lens, self.sliding_window,
                    query.dtype)  # type: ignore
            else:
                att_masks = [None] * len(seq_lens)
            attn_metadata.attn_bias = att_masks

        # SDPA works on [..., seq, head_size]: put heads before tokens.
        query = query.movedim(0, 1)
        key = key.movedim(0, 1)
        value = value.movedim(0, 1)

        output = torch.empty((num_tokens, self.num_heads, self.head_size),
                             dtype=query.dtype,
                             device=query.device)
        start = 0
        for seq_len, mask in zip(attn_metadata.seq_lens,
                                 attn_metadata.attn_bias):
            end = start + seq_len
            sub_out = scaled_dot_product_attention(
                query[:, start:end, :],
                key[:, start:end, :],
                value[:, start:end, :],
                attn_mask=mask,
                dropout_p=0.0,
                # With no explicit mask, let SDPA apply causal masking.
                is_causal=not self.need_mask,
                scale=self.scale)
            # Back to token-major layout.
            output[start:end, :, :] = sub_out.movedim(1, 0)
            start = end
        return output


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    dtype: torch.dtype,
218
    seq_lens: List[int],
219
220
) -> List[torch.Tensor]:
    attn_biases = []
221
222
    for seq_len in seq_lens:
        bias = torch.arange(seq_len, dtype=dtype)
223
        # NOTE(zhuohan): HF uses
224
        #     `bias = bias[None, :].repeat(seq_len, 1)`
225
226
227
228
229
230
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        bias = bias[None, :] - bias[:, None]

        num_heads = alibi_slopes.shape[0]
231
        bias = bias[None, :].repeat((num_heads, 1, 1))
232
233
        bias.mul_(alibi_slopes[:, None, None])
        inf_mask = torch.empty(
234
            (1, seq_len, seq_len),
235
236
237
238
239
240
241
            dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
        attn_biases.append((bias + inf_mask).to(dtype))

    return attn_biases


def _make_sliding_window_bias(
242
    seq_lens: List[int],
243
244
245
246
    window_size: Optional[int],
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    attn_biases = []
247
    for seq_len in seq_lens:
248
        tensor = torch.full(
249
            (1, seq_len, seq_len),
250
251
252
253
254
255
256
257
258
259
260
            dtype=dtype,
            fill_value=1,
        )
        shift = 0
        mask = torch.tril(tensor, diagonal=shift).to(dtype)  # type: ignore
        if window_size is not None:
            mask = torch.triu(mask, diagonal=shift - window_size + 1)
        mask = torch.log(mask)
        attn_biases.append(mask.to(dtype))

    return attn_biases