# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""vLLM op wrappers backed by Intel Extension for PyTorch (IPEX)."""

from typing import Optional

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)

# IPEX is optional at import time: when it is missing we only log the
# failure. ``ipex`` then stays unbound, so any ``ipex_ops`` method that
# touches it raises ``NameError`` when it is called.
try:
    import intel_extension_for_pytorch as ipex
except ImportError as e:
    logger.warning("Import error msg: %s", e.msg)


class ipex_ops:
    """Static wrappers that dispatch vLLM kernel calls to IPEX.

    Several signatures accept parameters the IPEX kernels never read
    (``k_scale``/``v_scale``, ``tp_rank``, ``blocksparse_*``, the FA3-style
    extras of :meth:`flash_attn_varlen_func`). They are kept so the wrappers
    stay API compatible with the CUDA ops. Methods that touch ``ipex``
    require ``intel_extension_for_pytorch`` to have been imported
    successfully when the module was loaded.
    """

    @staticmethod
    def _reshape_activation_tensor(
            x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Split a 2-D ``[num, 2 * d]`` tensor into two ``[num, d]`` halves.

        Returns ``(x[:, :d], x[:, d:])``. The second dimension must be even,
        otherwise the ``reshape(num, 2, d)`` below fails.
        """
        num = x.size(0)
        d = x.size(1) // 2
        x = x.reshape(num, 2, d)
        x1, x2 = torch.chunk(x, chunks=2, dim=1)
        x1 = x1.reshape(num, d)
        x2 = x2.reshape(num, d)
        return x1, x2

    @staticmethod
    def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        """Fused SiLU-and-multiply of ``x`` written into ``out`` (IPEX).

        Note the argument order: vLLM passes ``(out, x)`` while the IPEX
        function takes ``(x, out)``.
        """
        ipex.llm.functional.silu_and_mul(x, out)

    @staticmethod
    def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        """Fused GELU-and-multiply of ``x`` written into ``out`` (IPEX).

        The IPEX call receives no approximation argument here.
        """
        ipex.llm.functional.gelu_and_mul(x, out)

    @staticmethod
    def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        """Write ``gelu_tanh(x[..., :d]) * x[..., d:]`` into ``out``.

        ``d`` is half of the last dimension of ``x``. The fused IPEX
        ``gelu_and_mul`` call used by :meth:`gelu_and_mul` is given no way to
        select the tanh approximation. Routing through it would silently
        compute the same result as :meth:`gelu_and_mul`, so the tanh variant
        is computed with plain torch ops instead.
        """
        d = x.shape[-1] // 2
        out.copy_(
            torch.nn.functional.gelu(x[..., :d], approximate="tanh") *
            x[..., d:])

    @staticmethod
    def gelu_fast(x: torch.Tensor) -> torch.Tensor:
        """Return the tanh-approximated GELU of ``x``.

        "Fast" GELU is ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 *
        x**3)))``, which is exactly what ``approximate="tanh"`` computes. The
        default exact (erf) GELU differs from it by up to ~1e-3.
        """
        return torch.nn.functional.gelu(x, approximate="tanh")

    @staticmethod
    def gelu_new(x: torch.Tensor) -> torch.Tensor:
        """Return the tanh-approximated ("new", GPT-2 style) GELU of ``x``.

        It uses the same formula as :meth:`gelu_fast`.
        """
        return torch.nn.functional.gelu(x, approximate="tanh")

    @staticmethod
    def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
        """Quick GELU of ``x`` written into ``out`` (IPEX, args swapped)."""
        ipex.llm.functional.gelu_quick(x, out)

    @staticmethod
    def _single_query_paged_attention(
        out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
    ) -> None:
        """Shared body of ``paged_attention_v1``/``_v2``.

        Both versions map onto the same IPEX single-query kernel. Only the
        ``"auto"`` KV-cache dtype is supported.
        """
        assert kv_cache_dtype == "auto"
        # Number of query heads sharing each KV head (grouped-query attention).
        num_heads = out.size(1)
        num_queries_per_tokens = num_heads // num_kv_heads
        ipex.llm.modules.PagedAttention.single_query_kv_attention(
            out,
            query.contiguous(),
            # NOTE(review): key_cache is reinterpreted with value_cache's
            # shape, presumably because the kernel expects both caches in the
            # same layout. TODO confirm.
            key_cache.view_as(value_cache),
            value_cache,
            num_queries_per_tokens,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )

    @staticmethod
    def paged_attention_v1(
        out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        """Decode-time paged attention written into ``out``.

        ``k_scale``, ``v_scale``, ``tp_rank`` and ``blocksparse_*`` are
        accepted for API compatibility and ignored.
        """
        ipex_ops._single_query_paged_attention(
            out, query, key_cache, value_cache, num_kv_heads, scale,
            block_tables, context_lens, block_size, max_context_len,
            alibi_slopes, kv_cache_dtype)

    @staticmethod
    def paged_attention_v2(
        out: torch.Tensor,
        exp_sum: torch.Tensor,
        max_logits: torch.Tensor,
        tmp_out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        """Same as :meth:`paged_attention_v1`; the IPEX kernel is shared.

        The v2 scratch buffers (``exp_sum``, ``max_logits``, ``tmp_out``)
        and the scale/blocksparse arguments are accepted and ignored.
        """
        ipex_ops._single_query_paged_attention(
            out, query, key_cache, value_cache, num_kv_heads, scale,
            block_tables, context_lens, block_size, max_context_len,
            alibi_slopes, kv_cache_dtype)

    @staticmethod
    def rotary_embedding(
        positions: torch.Tensor,  # [batch_size, seq_len]
        query: torch.Tensor,  # [batch_size, seq_len, num_heads*head_size]
        key: torch.Tensor,  # [batch_size, seq_len, num_kv_heads*head_size]
        head_size: int,
        cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
        is_neox: bool,
    ) -> None:
        """Apply rotary embeddings to ``query``/``key`` via IPEX.

        The rotary dimension is taken from ``cos_sin_cache.size(1)``.
        """
        rot_dim = cos_sin_cache.size(1)
        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
                                                     head_size, cos_sin_cache,
                                                     is_neox, rot_dim)

    @staticmethod
    def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                                 key: torch.Tensor, head_size: int,
                                 cos_sin_cache: torch.Tensor, is_neox: bool,
                                 rot_dim: int,
                                 cos_sin_cache_offsets: torch.Tensor) -> None:
        """Rotary embeddings with per-token ``cos_sin_cache_offsets`` (IPEX)."""
        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
                                                     head_size, cos_sin_cache,
                                                     is_neox, rot_dim,
                                                     cos_sin_cache_offsets)

    @staticmethod
    def rms_norm(input: torch.Tensor, weight: torch.Tensor,
                 epsilon: float) -> torch.Tensor:
        """Return the IPEX RMSNorm of ``input`` with ``weight``/``epsilon``."""
        return ipex.llm.functional.rms_norm(input, weight, epsilon)

    @staticmethod
    def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, epsilon: float) -> None:
        """Fused residual-add + RMSNorm; the normalized result lands in
        ``input``.

        NOTE(review): the trailing ``True`` passed to IPEX presumably makes
        it update ``residual`` in place as well. TODO confirm.
        """
        tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
                                               epsilon, True)
        input.copy_(tmp)

    @staticmethod
    def varlen_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        seqlen_q: torch.Tensor,
        seqlen_k: torch.Tensor,
        alibi_slopes: Optional[torch.Tensor],
        max_seqlen_q: int,
        max_seqlen_k: int,
        pdropout: float,
        softmax_scale: float,
        zero_tensors: bool,
        is_causal: bool,
        return_softmax: bool,
        gen_: torch.Generator,
        window_size_left: float,
        window_size_right: float,
        logits_soft_cap: float,
    ) -> None:
        """Variable-length attention written into ``out``.

        On CPU builds (``ipex.__version__`` ends with ``"cpu"``), ALiBi,
        sliding windows and logit soft-capping are unsupported. A non-zero
        ``logits_soft_cap`` raises ``ValueError``, and the other two are
        asserted off. XPU builds forward all of them to the kernel.
        """
        if ipex.__version__.endswith("cpu"):
            if logits_soft_cap != 0.0:
                raise ValueError("IPEX CPU does not support logits_soft_cap")
            assert alibi_slopes is None
            assert window_size_left < 0 and window_size_right < 0
            ipex.llm.functional.varlen_attention(query.contiguous(),
                                                 key.contiguous(),
                                                 value.contiguous(), out,
                                                 seqlen_q.int(),
                                                 seqlen_k.int(), max_seqlen_q,
                                                 max_seqlen_k, pdropout,
                                                 softmax_scale, zero_tensors,
                                                 is_causal, return_softmax,
                                                 gen_)
        else:  # XPU build
            ipex.llm.functional.varlen_attention(
                query.contiguous(), key.contiguous(), value.contiguous(), out,
                seqlen_q.int(), seqlen_k.int(), alibi_slopes, max_seqlen_q,
                max_seqlen_k, pdropout, softmax_scale, zero_tensors, is_causal,
                return_softmax, gen_, window_size_left, window_size_right,
                logits_soft_cap)

    @staticmethod
    def reshape_and_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
    ) -> None:
        """Scatter ``key``/``value`` into the paged caches at
        ``slot_mapping``.

        Only the ``"auto"`` cache dtype is supported. ``k_scale``/``v_scale``
        are ignored.
        """
        assert kv_cache_dtype == "auto"
        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slot_mapping)

    @staticmethod
    def reshape_and_cache_flash(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: Optional[torch.Tensor] = None,
        v_scale: Optional[torch.Tensor] = None,
        k_scale_float: float = 1.0,
        v_scale_float: float = 1.0,
    ) -> None:
        """Flash-layout variant of :meth:`reshape_and_cache`.

        Only ``"auto"`` is supported. All scale arguments are ignored.
        """
        assert kv_cache_dtype == "auto"
        # TODO: support FP8 kv cache.
        ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
            key, value, key_cache, value_cache, slot_mapping)

    @staticmethod
    def flash_attn_varlen_func(
        out: torch.Tensor,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        seqused_k: torch.Tensor,  # we don't support this in ipex kernel
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float,
        causal: bool,
        block_table: torch.Tensor,
        alibi_slopes: Optional[torch.Tensor],
        window_size: Optional[list[int]] = None,
        softcap: Optional[float] = 0.0,
        cu_seqlens_k: Optional[torch.Tensor] = None,
        # The following parameters are not used in ipex kernel currently,
        # we keep API compatible to CUDA's.
        scheduler_metadata=None,
        fa_version: int = 2,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        num_splits=0,
    ):
        """Variable-length paged flash attention via IPEX.

        The result is written into ``out``, and the IPEX call's return value
        is returned. ``window_size=None`` maps to ``(-1, -1)`` (an unbounded
        window). Otherwise ``window_size`` must hold exactly two entries
        (left, right). The KV scales are fixed at 1.0.
        """
        if cu_seqlens_k is None:
            # The IPEX kernel consumes cumulative key offsets rather than the
            # per-sequence ``seqused_k`` lengths, so derive them as
            # ``[0, k_0, k_0 + k_1, ...]`` in int32.
            cu_seqlens_k = torch.cumsum(seqused_k, dim=0)
            cu_seqlens_k = torch.cat([
                torch.tensor([0], device=seqused_k.device, dtype=torch.int32),
                cu_seqlens_k
            ]).to(torch.int32)

        real_window_size: tuple[int, int]
        if window_size is None:
            real_window_size = (-1, -1)
        else:
            assert len(window_size) == 2
            real_window_size = (window_size[0], window_size[1])
        return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
            out,
            q.contiguous(),
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            softmax_scale,
            causal,
            block_table,
            alibi_slopes,
            softcap=softcap,
            window_size_left=real_window_size[0],
            window_size_right=real_window_size[1],
            k_scale=1.0,
            v_scale=1.0,
        )

    @staticmethod
    def get_scheduler_metadata(
            batch_size,
            max_seqlen_q,
            max_seqlen_k,
            num_heads_q,
            num_heads_kv,
            headdim,
            cache_seqlens: torch.Tensor,
            qkv_dtype=torch.bfloat16,
            headdim_v=None,
            cu_seqlens_q: Optional[torch.Tensor] = None,
            cu_seqlens_k_new: Optional[torch.Tensor] = None,
            cache_leftpad: Optional[torch.Tensor] = None,
            page_size: Optional[int] = None,
            max_seqlen_k_new=0,
            causal=False,
            window_size=(-1, -1),  # -1 means infinite context window
            has_softcap=False,
            num_splits=0,  # Can be tuned for speed
            pack_gqa=None,  # Can be tuned for speed
            sm_margin=0,  # Can be tuned if some SMs are used for communication
    ) -> None:
        """Unsupported on IPEX: logs a one-time warning and returns None."""
        logger.warning_once(
            "get_scheduler_metadata is not implemented for ipex_ops, "
            "returning None.")
        return None

    @staticmethod
    def copy_blocks(key_caches: list[torch.Tensor],
                    value_caches: list[torch.Tensor],
                    block_mapping: torch.Tensor) -> None:
        """Copy KV-cache blocks according to ``block_mapping``.

        This relies on ``torch.xpu.copy_blocks``, which is presumably
        provided by the IPEX XPU build (hence the ``type: ignore``).
        """
        torch.xpu.copy_blocks(  # type: ignore
            key_caches,
            value_caches,
            block_mapping,
        )

    @staticmethod
    def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                    block_mapping: torch.Tensor) -> None:
        """Swap cache blocks between ``src`` and ``dst`` per
        ``block_mapping``.

        This relies on ``torch.xpu.swap_blocks`` (presumably registered by
        IPEX).
        """
        torch.xpu.swap_blocks(src, dst, block_mapping)  # type: ignore