"tests/planner/scaling/disagg_planner_throughput.yaml" did not exist on "ee3a8e42838108a92d813b1a21037ad69e4d1530"
torch_sdpa.py 10.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
""" Attention layer with torch scaled_dot_product_attention
    and PagedAttention."""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type

import torch
from torch.nn.functional import scaled_dot_product_attention

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
10
11
                                              AttentionMetadata,
                                              AttentionMetadataPerStage)
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)


class TorchSDPABackend(AttentionBackend):
    """Backend descriptor for the torch scaled_dot_product_attention path.

    Every KV-cache operation (cache shape, block swap, block copy) is
    forwarded unchanged to ``PagedAttention``.
    """

    @staticmethod
    def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
        """Return the class that implements attention for this backend."""
        impl_cls = TorchSDPABackendImpl
        return impl_cls

    @staticmethod
    def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata":
        """Build a ``TorchSDPAMetadata``, passing every argument through."""
        metadata = TorchSDPAMetadata(*args, **kwargs)
        return metadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape as computed by PagedAttention."""
        cache_shape = PagedAttention.get_kv_cache_shape(
            num_blocks, block_size, num_kv_heads, head_size)
        return cache_shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Swap cache blocks between two caches via PagedAttention."""
        PagedAttention.swap_blocks(
            src_kv_cache,
            dst_kv_cache,
            src_to_dst,
        )

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        """Copy cache blocks inside each cache via PagedAttention."""
        PagedAttention.copy_blocks(
            kv_caches,
            src_to_dists,
        )


@dataclass
class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
    """Metadata for TorchSDPABackend.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    # Length of each prompt; prefill attention is computed one prompt
    # at a time using these lengths.
    prompt_lens: Optional[List[int]]
    prompt_lens_tensor: Optional[torch.Tensor]

    max_subquery_len: Optional[int] = None
    max_prompt_len: Optional[int] = None
    subquery_start_loc: Optional[torch.Tensor] = None
    seq_start_loc: Optional[torch.Tensor] = None
    use_cuda_graph: bool = False

    def __post_init__(self):
        # Set lazily during the execution of the first attention op and
        # reused afterwards. It holds one entry per prompt because the
        # ALiBi / sliding-window biases depend on each prompt's length
        # (entries are None when no explicit mask is needed).
        # Assigned here rather than declared as a field, so it does not
        # appear in the generated __repr__ and __init__.
        self.attn_bias: Optional[List[torch.Tensor]] = None


class TorchSDPABackendImpl(AttentionImpl):
    """Attention with torch SDPA for prefill tokens and PagedAttention for
    decode tokens.

    Prefill with non-empty block tables (prefix caching) is not supported.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        """Configure the attention layer.

        Args:
            num_heads: number of query heads.
            head_size: size of each attention head.
            scale: softmax scale passed to SDPA and PagedAttention.
            num_kv_heads: number of key/value heads; defaults to
                ``num_heads``. Must divide ``num_heads``.
            alibi_slopes: optional ALiBi slopes, one per query head.
            sliding_window: optional sliding-window size for prefill masks.

        Raises:
            ValueError: if ``head_size`` is not supported by PagedAttention.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            assert len(alibi_slopes) == num_heads
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # The ALiBi and sliding-window biases already mask future positions,
        # so SDPA must not add its own causal mask when one of them is used.
        self.need_mask = (self.alibi_slopes is not None
                          or self.sliding_window is not None)

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata[TorchSDPAMetadata],
        kv_scale: float,
    ) -> torch.Tensor:
        """Forward pass with torch SDPA and PagedAttention.

        Prefill tokens come first in the batch, followed by decode tokens.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
            kv_scale: scale forwarded to the paged-cache write and decode ops.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            RuntimeError: on prefix-cached prefill, or on decode without a
                KV cache.
        """
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        key_cache: Optional[torch.Tensor] = None
        value_cache: Optional[torch.Tensor] = None
        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache,
                                                attn_metadata.slot_mapping,
                                                attn_metadata.kv_cache_dtype,
                                                kv_scale)

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            if (kv_cache is not None
                    and prefill_meta.block_tables.numel() != 0):
                # prefix-enabled attention
                raise RuntimeError(
                    "Torch SDPA backend doesn't support prefix decoding.")
            out = self._prefill_attention(query, key, value, prefill_meta)
            assert out.shape == output[:num_prefill_tokens].shape
            output[:num_prefill_tokens] = out

        if decode_meta := attn_metadata.decode_metadata:
            if key_cache is None or value_cache is None:
                raise RuntimeError(
                    "Torch SDPA backend requires a KV cache for decoding.")
            # Decoding run.
            out = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.context_lens,
                decode_meta.max_context_len,
                attn_metadata.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_scale,
            )
            assert out.shape == output[num_prefill_tokens:].shape
            # Write the decode results back; without this assignment the
            # decode rows of `output` would be returned uninitialized.
            output[num_prefill_tokens:] = out

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

    def _prefill_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        prefill_meta: "TorchSDPAMetadata",
    ) -> torch.Tensor:
        """Run SDPA separately for each prompt of the prefill tokens.

        Args:
            query: [num_prefill_tokens, num_heads, head_size]
            key: [num_prefill_tokens, num_kv_heads, head_size]
            value: [num_prefill_tokens, num_kv_heads, head_size]
            prefill_meta: prefill metadata. ``prompt_lens`` splits the
                tokens into prompts. ``attn_bias`` is filled on first use
                and reused on later calls with the same metadata.
        Returns:
            [num_prefill_tokens, num_heads, head_size]
        """
        num_tokens = query.shape[0]
        if self.num_kv_heads != self.num_heads:
            # Repeat each KV head so every query head has a matching one.
            key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
            value = value.repeat_interleave(self.num_queries_per_kv, dim=1)

        if prefill_meta.attn_bias is None:
            if self.alibi_slopes is not None:
                att_masks = _make_alibi_bias(
                    self.alibi_slopes, query.dtype,
                    prefill_meta.prompt_lens)  # type: ignore
            elif self.sliding_window is not None:
                att_masks = _make_sliding_window_bias(
                    prefill_meta.prompt_lens, self.sliding_window,
                    query.dtype)  # type: ignore
            else:
                att_masks = [None] * len(prefill_meta.prompt_lens)
            prefill_meta.attn_bias = att_masks

        # SDPA attends over the second-to-last dim, so move tokens there:
        # [num_heads, num_tokens, head_size].
        query = query.movedim(0, query.dim() - 2)
        key = key.movedim(0, key.dim() - 2)
        value = value.movedim(0, value.dim() - 2)

        out = torch.empty((num_tokens, self.num_heads, self.head_size),
                          dtype=query.dtype,
                          device=query.device)
        start = 0
        for prompt_len, mask in zip(prefill_meta.prompt_lens,
                                    prefill_meta.attn_bias):
            end = start + prompt_len
            sub_out = scaled_dot_product_attention(
                query[:, start:end, :],
                key[:, start:end, :],
                value[:, start:end, :],
                attn_mask=mask,
                dropout_p=0.0,
                is_causal=not self.need_mask,
                scale=self.scale).movedim(query.dim() - 2, 0)
            out[start:end, :, :] = sub_out
            start = end
        return out


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    dtype: torch.dtype,
    prompt_lens: List[int],
) -> List[torch.Tensor]:
    attn_biases = []
    for prompt_len in prompt_lens:
        bias = torch.arange(prompt_len, dtype=dtype)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(prompt_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        bias = bias[None, :] - bias[:, None]

        num_heads = alibi_slopes.shape[0]
        bias = bias[None, :].expand(num_heads, prompt_len, prompt_len)
        bias.mul_(alibi_slopes[:, None, None])
        inf_mask = torch.empty(
            (1, prompt_len, prompt_len),
            dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
        attn_biases.append((bias + inf_mask).to(dtype))

    return attn_biases


def _make_sliding_window_bias(
    prompt_lens: List[int],
    window_size: Optional[int],
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    attn_biases = []
    for prompt_len in prompt_lens:
        tensor = torch.full(
            (1, prompt_len, prompt_len),
            dtype=dtype,
            fill_value=1,
        )
        shift = 0
        mask = torch.tril(tensor, diagonal=shift).to(dtype)  # type: ignore
        if window_size is not None:
            mask = torch.triu(mask, diagonal=shift - window_size + 1)
        mask = torch.log(mask)
        attn_biases.append(mask.to(dtype))

    return attn_biases