# pallas.py -- Pallas attention backend for vLLM on TPU.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import torch_xla.experimental.custom_kernel  # Required to register custom ops.

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState


class PallasAttentionBackend(AttentionBackend):
    """Backend descriptor for Pallas attention on TPU.

    Exposes the implementation, metadata and state classes that go with this
    backend, the layout of the KV cache, and the block-copy primitive.
    """

    @staticmethod
    def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
        """Return the attention implementation class of this backend."""
        return PallasAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["PallasMetadata"]:
        """Return the attention metadata class of this backend."""
        return PallasMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class of this backend."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the cache shape.

        The layout is ``(num_kv_heads, num_blocks, block_size, head_size)``,
        i.e. the block index lives on dimension 1.
        """
        return (num_kv_heads, num_blocks, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Unsupported on this backend.

        Raises:
            RuntimeError: always.
        """
        raise RuntimeError("swap_blocks is not used for the TPU backend.")

    # `staticmethod` must be the outermost decorator: this way torch.compile
    # wraps a plain function (a staticmethod object is only callable on
    # Python >= 3.10) and the attribute stored on the class is still a
    # static method, so it behaves the same through the class or an instance.
    @staticmethod
    @torch.compile(backend="openxla")
    def copy_blocks(
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
    ) -> None:
        """Copy cache blocks in place inside every (key, value) cache pair.

        Args:
            kv_caches: one ``(k_cache, v_cache)`` pair per entry; blocks are
                indexed along dimension 1 of each tensor.
            src_to_dists: ``(src_indices, dst_indices)``; block
                ``src_indices[i]`` is copied onto block ``dst_indices[i]``.
        """
        src_indices, dst_indices = src_to_dists
        for k_cache, v_cache in kv_caches:
            # Flag the cache as a buffer donor before the in-place update
            # (presumably so XLA can reuse its buffer -- TODO confirm).
            torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
            k_cache[:, dst_indices] = k_cache[:, src_indices]
            torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
            v_cache[:, dst_indices] = v_cache[:, src_indices]


@dataclass
class PallasMetadata(AttentionMetadata):
    """Attention metadata for the Pallas backend.

    A batch is either entirely prefill or entirely decode; the two
    properties below return ``self`` for the matching phase and ``None``
    for the other one, asserting that the fields are consistent.
    """

    # Currently, input sequences can only contain all prefills
    # or all decoding.
    # Both fields are required for decoding and must be None for prefill
    # (enforced by the assertions in the properties below).
    block_tables: Optional[torch.Tensor] = None
    context_lens: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self) -> Optional["PallasMetadata"]:
        """Return ``self`` for a prefill batch, ``None`` otherwise.

        Raises:
            AssertionError: if a prefill batch carries decode tokens,
                block tables or context lengths.
        """
        if self.num_prefills == 0:
            return None

        assert self.num_decode_tokens == 0
        assert self.block_tables is None
        assert self.context_lens is None
        return self

    @property
    def decode_metadata(self) -> Optional["PallasMetadata"]:
        """Return ``self`` for a decode batch, ``None`` otherwise.

        Raises:
            AssertionError: if a decode batch carries prefills or prefill
                tokens, or lacks block tables or context lengths.
        """
        if self.num_decode_tokens == 0:
            return None

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.block_tables is not None
        assert self.context_lens is not None
        return self


class PallasAttentionBackendImpl(AttentionImpl):
    """Attention implementation built on the XLA Pallas kernels.

    Prefill batches go through ``torch.ops.xla.flash_attention`` and decode
    batches through ``torch.ops.xla.paged_attention``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        """Validate the configuration and select the megacore mode.

        Args:
            num_heads: number of query heads.
            head_size: size of each head; must be a multiple of 128.
            scale: factor the query is multiplied by before attention.
            num_kv_heads: number of key/value heads; ``None`` falls back to
                ``num_heads``. Must divide ``num_heads``.
            alibi_slopes: must be ``None`` (unsupported).
            sliding_window: must be ``None`` (unsupported).
            kv_cache_dtype: must be ``"auto"``.
            blocksparse_params: must be ``None`` (unsupported).
            logits_soft_cap: must be ``None`` (unsupported).

        Raises:
            NotImplementedError: for any unsupported option, or when the TPU
                version is lower than 4.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        if head_size % 128 != 0:
            raise NotImplementedError("Head size must be a multiple of 128.")
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes is not supported.")
        if sliding_window is not None:
            raise NotImplementedError("Sliding window is not supported.")
        if kv_cache_dtype != "auto":
            raise NotImplementedError("FP8 KV cache dtype is not supported.")
        if blocksparse_params is not None:
            raise NotImplementedError("Blocksparse is not supported.")
        if logits_soft_cap is not None:
            raise NotImplementedError(
                "Attention logits soft-capping is not supported.")

        if torch_xla.tpu.version() < 4:
            raise NotImplementedError("TPU version must be 4 or higher.")

        # A megacore mode is only selected when the TPU type does not
        # contain "lite"; otherwise it stays None.
        self.megacore_mode = None
        tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower()
        if "lite" not in tpu_type:
            if self.num_kv_heads % 2 == 0:
                self.megacore_mode = "kv_head"
            else:
                # NOTE(woosuk): If the batch size is not a multiple of 2, the
                # megacore mode will be None.
                self.megacore_mode = "batch"

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
        attn_metadata: PallasMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache: ``(key_cache, value_cache)``, each of shape
                [num_kv_heads, num_blocks, block_size, head_size]. Both
                entries may be ``None`` for a prefill batch, in which case
                nothing is written to the cache.
            attn_metadata: Metadata for attention.
            k_scale: must be 1.0.
            v_scale: must be 1.0.
            attn_type: only ``AttentionType.DECODER`` is supported.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        Raises:
            NotImplementedError: if ``attn_type`` is not ``DECODER``.
            AssertionError: if a scale differs from 1.0, if a prefill
                ``seq_len`` is not a multiple of 16, or if a decode batch is
                run without a KV cache.
        """
        assert k_scale == 1.0 and v_scale == 1.0
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")
        batch_size, seq_len, hidden_size = query.shape
        query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
        value = value.view(batch_size, seq_len, self.num_kv_heads,
                           self.head_size)

        # Always bind the cache names so that the decode branch below can
        # check them explicitly instead of failing on an unbound local.
        key_cache: Optional[torch.Tensor] = None
        value_cache: Optional[torch.Tensor] = None
        if kv_cache[0] is not None:
            slot_mapping = attn_metadata.slot_mapping
            key_cache, value_cache = kv_cache
            write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)

        query = query * self.scale
        if attn_metadata.num_prefills > 0:
            assert seq_len % 16 == 0, (
                "Pallas FlashAttention kernel requires seq_len to be a "
                f"multiple of 16 but got {seq_len}")

            # Handle GQA/MQA.
            if self.num_kv_heads != self.num_heads:
                key = key.repeat_interleave(self.num_queries_per_kv, dim=-2)
                key = key.view(batch_size, seq_len, self.num_heads,
                               self.head_size)
                value = value.repeat_interleave(self.num_queries_per_kv,
                                                dim=-2)
                value = value.view(batch_size, seq_len, self.num_heads,
                                   self.head_size)
            # FlashAttention requires [batch_size, num_heads, seq_len, d_model]
            # while the input is [batch_size, seq_len, num_heads, d_model].
            # Permute the input to match the required format.
            output = torch.ops.xla.flash_attention(
                query.permute(0, 2, 1, 3),
                key.permute(0, 2, 1, 3),
                value.permute(0, 2, 1, 3),
                True,
            )
            output = output.permute(0, 2, 1, 3)
        else:
            # Decoding run.
            # `kv_cache` itself is a tuple and never None; what matters is
            # that the cache tensors it holds are present.
            assert key_cache is not None and value_cache is not None, (
                "Decoding requires a KV cache but kv_cache holds None.")

            pages_per_compute_block = 16  # TODO(woosuk): Tune this value.
            if self.megacore_mode == "batch" and batch_size % 2 != 0:
                megacore_mode = None
            else:
                megacore_mode = self.megacore_mode

            # NOTE(woosuk): A temporary workaround to avoid the error:
            # "xla::paged_attention() Expected a value of type 'str' for
            # argument 'megacore_mode' but instead found type 'NoneType'."
            if megacore_mode is not None:
                output = torch.ops.xla.paged_attention(
                    query.squeeze(dim=1),
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    pages_per_compute_block,
                    megacore_mode=megacore_mode,
                )
            else:
                output = torch.ops.xla.paged_attention(
                    query.squeeze(dim=1),
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    pages_per_compute_block,
                )

        # Reshape the output tensor.
        return output.reshape(batch_size, seq_len, hidden_size)


def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Write new key/value entries into the caches at ``slot_mapping``.

    The first three dimensions of every tensor are flattened, then the rows
    of ``key`` / ``value`` are copied into the rows of the corresponding
    cache selected by ``slot_mapping``.

    Args:
        key: new key entries.
        value: new value entries.
        key_cache: key cache, updated in place.
        value_cache: value cache, updated in place.
        slot_mapping: row indices into the flattened caches.
    """
    for cache in (key_cache, value_cache):
        torch.ops.xla.dynamo_set_buffer_donor_(cache, True)

    new_rows = (key.flatten(0, 2), value.flatten(0, 2))
    # NOTE(review): the in-place copy only reaches the caller's cache if
    # flatten() returns a view here -- TODO confirm for the cache layout.
    cache_rows = (key_cache.flatten(0, 2), value_cache.flatten(0, 2))
    for flat_cache, rows in zip(cache_rows, new_rows):
        flat_cache.index_copy_(0, slot_mapping, rows)