# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import torch_xla.experimental.custom_kernel  # Required to register custom ops.

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer,
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.logger import init_logger

# Module-level logger for this backend.
logger = init_logger(__name__)


class PallasAttentionBackend(AttentionBackend):
    """Backend descriptor for the Pallas (TPU) attention implementation.

    Exposes the implementation, metadata and state classes used with this
    backend, the layout of the KV cache tensors, and the block-level cache
    operations.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "PALLAS"

    @staticmethod
    def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
        """Return the attention implementation class for this backend."""
        return PallasAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["PallasMetadata"]:
        """Return the attention metadata class for this backend."""
        return PallasMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class for this backend."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the shape of a single (key or value) cache tensor.

        The layout is (num_kv_heads, num_blocks, block_size, head_size).
        """
        return (num_kv_heads, num_blocks, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Unsupported on this backend.

        Raises:
            RuntimeError: Always.
        """
        raise RuntimeError("swap_blocks is not used for the TPU backend.")

    # NOTE(review): torch.compile is applied on top of the staticmethod
    # object, so the class attribute is the compiled wrapper rather than a
    # staticmethod. Calls through the class work; confirm before calling
    # this through an instance.
    @torch.compile(backend="openxla")
    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
    ) -> None:
        """Copy cache blocks in place for every (key, value) cache pair.

        Args:
            kv_caches: One (key_cache, value_cache) pair per entry.
            src_to_dists: A (src_indices, dst_indices) pair; entries along
                dim 1 at ``src_indices`` are copied to ``dst_indices``.
                Dim 1 is the block axis if the caches follow
                ``get_kv_cache_shape``.
        """
        src_indices, dst_indices = src_to_dists
        for k_cache, v_cache in kv_caches:
            # Flag each cache as a buffer donor before the in-place update
            # (presumably so XLA can reuse the buffer - confirm against the
            # torch_xla documentation).
            torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
            k_cache[:, dst_indices] = k_cache[:, src_indices]
            torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
            v_cache[:, dst_indices] = v_cache[:, src_indices]


@dataclass
class PallasMetadata(AttentionMetadata):
    """Attention metadata for the Pallas backend.

    A batch is either all prefill or all decode; the two properties below
    assert this and return ``self`` for the matching phase.
    """

    # Currently, input sequences can only contain all prefills
    # or all decoding.
    block_tables: Optional[torch.Tensor] = None
    context_lens: Optional[torch.Tensor] = None
    effective_query_lens: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self) -> Optional["PallasMetadata"]:
        """Return ``self`` for a prefill batch, or ``None`` if no prefills."""
        if self.num_prefills == 0:
            return None

        # A prefill batch must not carry any decode tokens.
        assert self.num_decode_tokens == 0
        return self

    @property
    def decode_metadata(self) -> Optional["PallasMetadata"]:
        """Return ``self`` for a decode batch, or ``None`` if no decodes."""
        if self.num_decode_tokens == 0:
            return None

        # A decode batch must not carry prefills and needs the paged-cache
        # lookup tensors.
        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.block_tables is not None
        assert self.context_lens is not None
        return self


class PallasAttentionBackendImpl(AttentionImpl):
    """Attention implementation that dispatches to Pallas kernels via XLA ops.

    ``forward`` handles a batch that is either all prefill or all decode.
    Prefill uses the flash-attention kernel (no block tables) or the
    multi-query paged-attention kernel (with block tables); decode uses the
    paged-attention kernel.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        use_irope: bool = False,
    ) -> None:
        """Validate the configuration and select the megacore mode.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head; must be a multiple of 128.
            scale: Factor the query is multiplied by before attention.
            num_kv_heads: Number of key/value heads; ``None`` falls back to
                ``num_heads``. Must divide ``num_heads``.
            alibi_slopes: Must be ``None`` (not supported).
            sliding_window: Must be ``None`` (not supported).
            kv_cache_dtype: KV cache dtype; quantized dtypes are rejected.
            blocksparse_params: Must be ``None`` (not supported).
            logits_soft_cap: Optional soft cap forwarded to the paged
                attention kernels.
            attn_type: Only ``AttentionType.DECODER`` is supported.
            use_irope: If set, a warning is logged once; it is otherwise
                ignored.

        Raises:
            NotImplementedError: For any unsupported option listed above, or
                when the TPU version is lower than 4.
        """
        if use_irope:
            logger.warning_once(
                "Using irope in Pallas is not supported yet, it will fall back "
                "to global attention for long context.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.logits_soft_cap = logits_soft_cap
        if head_size % 128 != 0:
            raise NotImplementedError(
                f"Head size must be a multiple of 128, found {head_size}.")
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes is not supported.")
        if sliding_window is not None:
            raise NotImplementedError("Sliding window is not supported.")
        if is_quantized_kv_cache(kv_cache_dtype):
            raise NotImplementedError("FP8 KV cache dtype is not supported.")
        if blocksparse_params is not None:
            raise NotImplementedError("Blocksparse is not supported.")

        if torch_xla.tpu.version() < 4:
            raise NotImplementedError("TPU version must be 4 or higher.")

        self.megacore_mode = None
        tpu_env = torch_xla.tpu.get_tpu_env()
        # The accelerator type may be published under any of these keys.
        tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
                    or tpu_env.get("TYPE", None)
                    or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
        assert tpu_type is not None
        tpu_type = tpu_type.lower()

        # Megacore is only enabled for TPU types that are neither "lite" nor
        # "v6" (presumably the types without megacore support - confirm).
        if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
            if self.num_kv_heads % 2 == 0:
                self.megacore_mode = "kv_head"
            else:
                # NOTE(woosuk): If the batch size is not a multiple of 2, the
                # megacore mode will be None.
                self.megacore_mode = "batch"

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        attn_metadata: PallasMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

        Args:
            layer: Attention layer; its ``_k_scale_float`` and
                ``_v_scale_float`` must both be 1.0.
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
            kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
                NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
                with shape [0] for profiling run.
            attn_metadata: Metadata for attention.
            output: Not read; the name is rebound to the kernel result.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        batch_size, seq_len, hidden_size = query.shape
        query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
        value = value.view(batch_size, seq_len, self.num_kv_heads,
                           self.head_size)

        # Skip the cache write when the cache is empty (profiling run).
        if kv_cache[0].numel() > 0:
            slot_mapping = attn_metadata.slot_mapping
            key_cache, value_cache = kv_cache
            write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)

        query = query * self.scale
        if attn_metadata.num_prefills > 0:
            if attn_metadata.block_tables is None:
                # Prefill without paged KV cache.
                assert seq_len % 16 == 0, (
                    "Pallas FlashAttention kernel requires seq_len to be a "
                    f"multiple of 16 but got {seq_len}")

                # Handle GQA/MQA.
                if self.num_kv_heads != self.num_heads:
                    key = key.repeat_interleave(self.num_queries_per_kv,
                                                dim=-2)
                    key = key.view(batch_size, seq_len, self.num_heads,
                                   self.head_size)
                    value = value.repeat_interleave(self.num_queries_per_kv,
                                                    dim=-2)
                    value = value.view(batch_size, seq_len, self.num_heads,
                                       self.head_size)
                # FlashAttention kernel requires the input shape to be
                # [batch_size, num_heads, seq_len, d_model]
                # while the input is [batch_size, seq_len, num_heads, d_model].
                # Permute the input to match the required format.
                output = torch.ops.xla.flash_attention(
                    query.permute(0, 2, 1, 3),
                    key.permute(0, 2, 1, 3),
                    value.permute(0, 2, 1, 3),
                    True,  # presumably the causal flag - confirm.
                )
                output = output.permute(0, 2, 1, 3)
            else:
                # Prefill with paged KV cache.
                # key_cache/value_cache are only bound when the cache is
                # non-empty, so fail with a clear message instead of an
                # UnboundLocalError (mirrors the check in the decode branch).
                assert kv_cache[0].numel() > 0, (
                    "Prefill with block tables requires a non-empty KV cache.")
                # TODO(woosuk): Tune the below knobs.
                num_kv_pages_per_compute_block = 16
                num_queries_per_compute_block = 16
                assert seq_len % num_queries_per_compute_block == 0
                output = torch.ops.xla.multi_queries_paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    attn_metadata.effective_query_lens,
                    num_kv_pages_per_compute_block,
                    num_queries_per_compute_block,
                    use_kernel=True,
                    attn_logits_soft_cap=self.logits_soft_cap,
                )
        else:
            # Decoding run.
            assert kv_cache[0].numel() > 0
            # Drop the seq_len axis (presumably of size 1 during decoding).
            query = query.squeeze(dim=1)
            pages_per_compute_block = 16  # TODO(woosuk): Tune this value.

            assert attn_metadata.block_tables is not None
            assert attn_metadata.context_lens is not None
            # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
            # block table in SMEM. Therefore, if the block table is too large,
            # the kernel compilation will fail. To avoid this, we split the
            # batch dimension into smaller chunks and run the kernel multiple
            # times.
            MAX_SMEM_USAGE = 512 * 1024
            # 4 bytes per block-table entry (presumably int32 - confirm).
            size_per_seq = 4 * attn_metadata.block_tables.shape[1]
            max_num_seq = MAX_SMEM_USAGE // size_per_seq

            if batch_size <= max_num_seq:
                output = paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    pages_per_compute_block,
                    self.megacore_mode,
                    attn_logits_soft_cap=self.logits_soft_cap,
                )
            else:
                chunk_size = max_num_seq
                # Make sure the chunk size is a multiple of 2.
                chunk_size = chunk_size // 2 * 2
                num_chunks = (batch_size + chunk_size - 1) // chunk_size

                output = torch.empty_like(query)
                for chunk_idx in range(num_chunks):
                    chunk_start = chunk_idx * chunk_size
                    chunk_end = chunk_start + chunk_size
                    # NOTE(woosuk): We skip this line because it causes Dynamo
                    # compilation error. Instead, we rely on the slice operation
                    # to handle the out-of-bound case.
                    # chunk_end = min(chunk_end, batch_size)
                    chunk_output = paged_attention(
                        query[chunk_start:chunk_end],
                        key_cache,
                        value_cache,
                        attn_metadata.context_lens[chunk_start:chunk_end],
                        attn_metadata.block_tables[chunk_start:chunk_end],
                        pages_per_compute_block,
                        self.megacore_mode,
                        attn_logits_soft_cap=self.logits_soft_cap,
                    )
                    output[chunk_start:chunk_end] = chunk_output

        # Reshape the output tensor.
        return output.reshape(batch_size, seq_len, hidden_size)


def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Write new key/value rows into the KV caches in place.

    The first three dimensions of every tensor are flattened, then the rows
    of ``key`` / ``value`` are copied into the flattened caches at the row
    indices given by ``slot_mapping``.

    Args:
        key: New keys; the caller passes
            [batch_size, seq_len, num_kv_heads, head_size].
        value: New values, same layout as ``key``.
        key_cache: Key cache to update.
        value_cache: Value cache to update.
        slot_mapping: Destination row indices into the flattened caches.
    """
    # Flag both caches as buffer donors before the in-place update
    # (presumably so XLA can reuse the buffers - confirm against torch_xla).
    torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
    torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)

    key = key.flatten(0, 2)
    value = value.flatten(0, 2)
    # NOTE(review): the update only reaches the caller's tensors if flatten
    # returns a view of the cache rather than a copy - confirm this holds
    # for the cache tensors used here.
    key_cache = key_cache.flatten(0, 2)
    value_cache = value_cache.flatten(0, 2)
    key_cache.index_copy_(0, slot_mapping, key)
    value_cache.index_copy_(0, slot_mapping, value)


def paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    pages_per_compute_block: int,
    megacore_mode: Optional[str],
    *,
    attn_logits_soft_cap: Optional[float],
) -> torch.Tensor:
    """Run the XLA paged-attention kernel with a batch-safe megacore mode.

    Args:
        query: Query tensor; dim 0 is taken as the batch size.
        key_cache: Key cache passed through to the kernel.
        value_cache: Value cache passed through to the kernel.
        context_lens: Per-sequence context lengths passed to the kernel.
        block_tables: Block tables passed to the kernel.
        pages_per_compute_block: Kernel tuning knob, passed through.
        megacore_mode: Requested megacore mode. ``"batch"`` is downgraded to
            ``None`` when the batch size is odd.
        attn_logits_soft_cap: Optional soft cap passed to the kernel.

    Returns:
        The output of ``torch.ops.xla.paged_attention``.
    """
    batch_size = query.shape[0]
    # "batch" megacore mode is only used with an even batch size.
    if megacore_mode == "batch" and batch_size % 2 != 0:
        megacore_mode = None

    return torch.ops.xla.paged_attention(
        query,
        key_cache,
        value_cache,
        context_lens,
        block_tables,
        pages_per_compute_block,
        megacore_mode=megacore_mode,
        attn_logits_soft_cap=attn_logits_soft_cap,
    )