# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import torch_xla.experimental.custom_kernel  # Required to register custom ops.

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer,
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.logger import init_logger

logger = init_logger(__name__)
18
19
20
21


class PallasAttentionBackend(AttentionBackend):
    """Backend descriptor for the Pallas attention kernels.

    Exposes the implementation, metadata and state classes that belong to
    this backend, the layout of a KV-cache tensor, and the block-copy helper.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "PALLAS"

    @staticmethod
    def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
        """Return the attention implementation class of this backend."""
        return PallasAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["PallasMetadata"]:
        """Return the attention metadata class of this backend."""
        return PallasMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class of this backend."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the shape of a single cache tensor.

        The layout is ``(num_kv_heads, num_blocks, block_size, head_size)``,
        i.e. the KV-head dimension comes first and the block dimension second.
        """
        return (num_kv_heads, num_blocks, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Unsupported on this backend.

        Raises:
            RuntimeError: always.
        """
        raise RuntimeError("swap_blocks is not used for the TPU backend.")

    @torch.compile(backend="openxla")
    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
    ) -> None:
        """Copy cache blocks in place inside every (key, value) cache pair.

        Args:
            kv_caches: One ``(key_cache, value_cache)`` pair per entry.
            src_to_dists: ``(src_indices, dst_indices)``; both index
                dimension 1 of each cache tensor (the block dimension in the
                layout returned by ``get_kv_cache_shape``).
        """
        src_indices, dst_indices = src_to_dists
        for k_cache, v_cache in kv_caches:
            # NOTE(review): presumably marks the buffer as donatable so XLA
            # can reuse it for the in-place update -- confirm against the
            # torch_xla documentation.
            torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
            k_cache[:, dst_indices] = k_cache[:, src_indices]
            torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
            v_cache[:, dst_indices] = v_cache[:, src_indices]
67
68
69
70
71
72
73


@dataclass
class PallasMetadata(AttentionMetadata):
    """Attention metadata consumed by the Pallas attention implementation.

    A batch is either all prefill or all decode; the two properties below
    assert that the counters of the other phase are zero.
    NOTE(review): ``num_prefills``, ``num_prefill_tokens`` and
    ``num_decode_tokens`` are not declared here -- presumably inherited from
    ``AttentionMetadata``; confirm against the base class.
    """

    # Currently, input sequences can only contain all prefills
    # or all decoding.
    # Must be set for decode (asserted in `decode_metadata`).
    block_tables: Optional[torch.Tensor] = None
    # Must be set for decode (asserted in `decode_metadata`).
    context_lens: Optional[torch.Tensor] = None
    effective_query_lens: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self) -> Optional["PallasMetadata"]:
        """Return ``self`` for a prefill batch, or ``None`` if there are no
        prefills."""
        if self.num_prefills == 0:
            return None

        # A prefill batch must not carry any decode tokens.
        assert self.num_decode_tokens == 0
        return self

    @property
    def decode_metadata(self) -> Optional["PallasMetadata"]:
        """Return ``self`` for a decode batch, or ``None`` if there are no
        decode tokens."""
        if self.num_decode_tokens == 0:
            return None

        # A decode batch must not carry any prefill work and needs both the
        # block tables and the context lengths.
        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.block_tables is not None
        assert self.context_lens is not None
        return self


class PallasAttentionBackendImpl(AttentionImpl):
    """Attention implementation that dispatches to the XLA/Pallas kernels.

    ``__init__`` validates the configuration and rejects every feature this
    implementation does not handle; ``forward`` writes the new keys/values to
    the KV cache (when one is present) and runs either a prefill or a decode
    kernel depending on the metadata.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        """Validate the configuration and store the attention parameters.

        Args:
            num_heads: Number of query heads.
            head_size: Size of one head; must be a multiple of 128.
            scale: Factor the query is multiplied by in ``forward``.
            num_kv_heads: Number of KV heads; ``None`` falls back to
                ``num_heads``. Must divide ``num_heads``.
            alibi_slopes: Must be ``None``.
            sliding_window: Must be ``None``.
            kv_cache_dtype: Must not be a quantized KV-cache dtype.
            blocksparse_params: Must be ``None``.
            logits_soft_cap: Forwarded to the paged-attention kernels.
            attn_type: Must be ``AttentionType.DECODER``.
            kv_sharing_target_layer_name: Must be ``None``.
            use_irope: Only triggers a one-time warning.

        Raises:
            NotImplementedError: for any unsupported option listed above, or
                when the TPU version is below 4.
        """
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported in V0.")
        if use_irope:
            logger.warning_once(
                "Using irope in Pallas is not supported yet, it will fall back "
                "to global attention for long context.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.logits_soft_cap = logits_soft_cap
        if head_size % 128 != 0:
            raise NotImplementedError(
                f"Head size must be a multiple of 128, found {head_size}.")
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes is not supported.")
        if sliding_window is not None:
            raise NotImplementedError("Sliding window is not supported.")
        if is_quantized_kv_cache(kv_cache_dtype):
            raise NotImplementedError("FP8 KV cache dtype is not supported.")
        if blocksparse_params is not None:
            raise NotImplementedError("Blocksparse is not supported.")

        if torch_xla.tpu.version() < 4:
            raise NotImplementedError("TPU version must be 4 or higher.")

        self.megacore_mode = None
        tpu_env = torch_xla.tpu.get_tpu_env()
        # The accelerator type is looked up under three possible keys; the
        # first non-empty value wins.
        tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
                    or tpu_env.get("TYPE", None)
                    or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
        assert tpu_type is not None
        tpu_type = tpu_type.lower()

        # A megacore mode is only selected when the type string contains
        # neither "lite" nor "v6".
        if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
            if self.num_kv_heads % 2 == 0:
                self.megacore_mode = "kv_head"
            else:
                # NOTE(woosuk): If the batch size is not a multiple of 2, the
                # megacore mode will be None.
                self.megacore_mode = "batch"

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        attn_metadata: PallasMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

        Args:
            layer: Attention layer; its ``_k_scale_float`` and
                ``_v_scale_float`` must both be 1.0.
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
            kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
                NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
                with shape [0] for profiling run.
            attn_metadata: Metadata for attention.
            output: Accepted for interface compatibility; any tensor passed
                in is ignored and a newly computed tensor is returned.
            output_scale: Must be ``None``.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        Raises:
            NotImplementedError: if ``output_scale`` is given.
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for PallasAttentionImpl")

        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        batch_size, seq_len, hidden_size = query.shape
        query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
        value = value.view(batch_size, seq_len, self.num_kv_heads,
                           self.head_size)

        # Unpack unconditionally so that both names are bound on every path;
        # the paged kernels below read them even when nothing was written.
        key_cache, value_cache = kv_cache
        if key_cache.numel() > 0:
            # A non-empty cache means this is not a profiling run: store the
            # new keys/values at the slots given by the metadata.
            slot_mapping = attn_metadata.slot_mapping
            write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)

        query = query * self.scale
        if attn_metadata.num_prefills > 0:
            if attn_metadata.block_tables is None:
                # Prefill without paged KV cache.
                assert seq_len % 16 == 0, (
                    "Pallas FlashAttention kernel requires seq_len to be a "
                    f"multiple of 16 but got {seq_len}")

                # Handle GQA/MQA.
                if self.num_kv_heads != self.num_heads:
                    key = key.repeat_interleave(self.num_queries_per_kv,
                                                dim=-2)
                    key = key.view(batch_size, seq_len, self.num_heads,
                                   self.head_size)
                    value = value.repeat_interleave(self.num_queries_per_kv,
                                                    dim=-2)
                    value = value.view(batch_size, seq_len, self.num_heads,
                                       self.head_size)
                # FlashAttention kernel requires the input shape to be
                # [batch_size, num_heads, seq_len, d_model]
                # while the input is [batch_size, seq_len, num_heads, d_model].
                # Permute the input to match the required format.
                output = torch.ops.xla.flash_attention(
                    query.permute(0, 2, 1, 3),
                    key.permute(0, 2, 1, 3),
                    value.permute(0, 2, 1, 3),
                    True,
                )
                output = output.permute(0, 2, 1, 3)
            else:
                # Prefill with paged KV cache.
                assert key_cache.numel() > 0, (
                    "Prefill with block tables requires a non-empty KV cache")
                # TODO(woosuk): Tune the below knobs.
                num_kv_pages_per_compute_block = 16
                num_queries_per_compute_block = 16
                assert seq_len % num_queries_per_compute_block == 0
                output = torch.ops.xla.multi_queries_paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    attn_metadata.effective_query_lens,
                    num_kv_pages_per_compute_block,
                    num_queries_per_compute_block,
                    use_kernel=True,
                    attn_logits_soft_cap=self.logits_soft_cap,
                )
        else:
            # Decoding run.
            assert kv_cache[0].numel() > 0
            query = query.squeeze(dim=1)
            pages_per_compute_block = 16  # TODO(woosuk): Tune this value.

            assert attn_metadata.block_tables is not None
            assert attn_metadata.context_lens is not None
            # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
            # block table in SMEM. Therefore, if the block table is too large,
            # the kernel compilation will fail. To avoid this, we split the
            # batch dimension into smaller chunks and run the kernel multiple
            # times.
            MAX_SMEM_USAGE = 512 * 1024
            # NOTE(review): the factor 4 is presumably the byte size of one
            # block-table entry (int32) -- confirm.
            size_per_seq = 4 * attn_metadata.block_tables.shape[1]
            max_num_seq = MAX_SMEM_USAGE // size_per_seq

            if batch_size <= max_num_seq:
                output = paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    pages_per_compute_block,
                    self.megacore_mode,
                    attn_logits_soft_cap=self.logits_soft_cap,
                )
            else:
                chunk_size = max_num_seq
                # Make sure the chunk size is a multiple of 2.
                chunk_size = chunk_size // 2 * 2
                num_chunks = (batch_size + chunk_size - 1) // chunk_size

                output = torch.empty_like(query)
                for chunk_idx in range(num_chunks):
                    chunk_start = chunk_idx * chunk_size
                    chunk_end = chunk_start + chunk_size
                    # NOTE(woosuk): We skip this line because it causes Dynamo
                    # compilation error. Instead, we rely on the slice operation
                    # to handle the out-of-bound case.
                    # chunk_end = min(chunk_end, batch_size)
                    chunk_output = paged_attention(
                        query[chunk_start:chunk_end],
                        key_cache,
                        value_cache,
                        attn_metadata.context_lens[chunk_start:chunk_end],
                        attn_metadata.block_tables[chunk_start:chunk_end],
                        pages_per_compute_block,
                        self.megacore_mode,
                        attn_logits_soft_cap=self.logits_soft_cap,
                    )
                    output[chunk_start:chunk_end] = chunk_output

        # Reshape the output tensor.
        return output.reshape(batch_size, seq_len, hidden_size)


def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Scatter new key/value rows into the key and value caches in place.

    The first three dimensions of every tensor are merged, so each tensor
    becomes a 2-D table of rows; row ``i`` of the flattened ``key`` /
    ``value`` is copied to row ``slot_mapping[i]`` of the flattened cache.

    Args:
        key: New keys; must have at least three dimensions.
        value: New values; same layout as ``key``.
        key_cache: Key cache that is updated in place.
        value_cache: Value cache that is updated in place.
        slot_mapping: Destination row index for every flattened source row.
    """
    # NOTE(review): presumably marks both caches as donor buffers so XLA can
    # reuse them for the in-place update -- confirm against torch_xla docs.
    torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
    torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)

    # The writes reach the caller's caches only because `flatten` hands back
    # a view here (assumes the caches are contiguous -- TODO confirm).
    flat_key_cache = key_cache.flatten(0, 2)
    flat_key_cache.index_copy_(0, slot_mapping, key.flatten(0, 2))

    flat_value_cache = value_cache.flatten(0, 2)
    flat_value_cache.index_copy_(0, slot_mapping, value.flatten(0, 2))
329
330
331
332
333
334
335
336
337
338


def paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    pages_per_compute_block: int,
    megacore_mode: Optional[str],
    *,
    attn_logits_soft_cap: Optional[float],
) -> torch.Tensor:
    """Call the XLA paged-attention kernel with a batch-safe megacore mode.

    Args:
        query: Query tensor; dimension 0 is taken as the batch size.
        key_cache: Key cache passed through to the kernel.
        value_cache: Value cache passed through to the kernel.
        context_lens: Passed through to the kernel.
        block_tables: Passed through to the kernel.
        pages_per_compute_block: Passed through to the kernel.
        megacore_mode: Requested megacore mode. ``"batch"`` is replaced by
            ``None`` when the batch size is odd; any other value is passed
            through unchanged.
        attn_logits_soft_cap: Passed through to the kernel (keyword-only).

    Returns:
        The tensor returned by ``torch.ops.xla.paged_attention``.
    """
    batch_size = query.shape[0]
    # The "batch" mode is only used with an even batch size.
    if megacore_mode == "batch" and batch_size % 2 != 0:
        megacore_mode = None

    return torch.ops.xla.paged_attention(
        query,
        key_cache,
        value_cache,
        context_lens,
        block_tables,
        pages_per_compute_block,
        megacore_mode=megacore_mode,
        attn_logits_soft_cap=attn_logits_soft_cap,
    )