_ipex_ops.py 13.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Optional, Union
5
6
7
8

import torch

from vllm.logger import init_logger
9
from vllm.platforms import current_platform
10
11
12
13
14
15

logger = init_logger(__name__)

# IPEX is optional at import time: if it is missing, the failure is only
# logged at debug level and ``ipex`` stays unbound, so the ``ipex_ops``
# methods that reference it fail only when they are actually called.
try:
    import intel_extension_for_pytorch as ipex
except ImportError as e:
    logger.debug("Import error msg: %s", e.msg)


class ipex_ops:
    """vLLM custom-op adapters backed by Intel Extension for PyTorch (IPEX).

    The class is only a namespace: every method is a ``staticmethod``. Most
    methods forward to ``ipex`` (imported at module level), reordering
    arguments to IPEX's convention and ignoring parameters that exist only
    for signature compatibility with the CUDA ops.
    """

    @staticmethod
    def _reshape_activation_tensor(
        x: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Split a 2-D ``[num, 2 * d]`` tensor into two ``[num, d]`` halves.

        The first result holds columns ``[0, d)`` and the second holds
        ``[d, 2 * d)``. ``reshape`` raises if ``x.size(1)`` is odd.
        """
        num = x.size(0)
        d = x.size(1) // 2
        x = x.reshape(num, 2, d)
        x1, x2 = torch.chunk(x, chunks=2, dim=1)
        x1 = x1.reshape(num, d)
        x2 = x2.reshape(num, d)
        return x1, x2

    @staticmethod
    def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        """Fused SiLU-and-multiply of ``x`` via IPEX, written into ``out``."""
        # IPEX takes (input, output): the reverse of vLLM's (out, x) order.
        ipex.llm.functional.silu_and_mul(x, out)

    @staticmethod
    def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        """Fused GELU-and-multiply of ``x`` via IPEX, written into ``out``."""
        ipex.llm.functional.gelu_and_mul(x, out)

    @staticmethod
    def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        """Write ``gelu_tanh(x[..., :d]) * x[..., d:]`` into ``out``.

        ``d = x.shape[-1] // 2`` and ``gelu_tanh`` is the tanh approximation
        of GELU. ``out`` must have the shape of one half of ``x``.
        """
        # IPEX's ``gelu_and_mul`` is called with no approximation mode, so
        # routing this op to it made it identical to ``gelu_and_mul``.
        # Compute the tanh variant explicitly with torch ops instead.
        d = x.shape[-1] // 2
        torch.mul(
            torch.nn.functional.gelu(x[..., :d], approximate="tanh"),
            x[..., d:],
            out=out,
        )

    @staticmethod
    def gelu_fast(x: torch.Tensor) -> torch.Tensor:
        """Return the tanh-approximated GELU of ``x`` (the "fast" GELU)."""
        return torch.nn.functional.gelu(x, approximate="tanh")

    @staticmethod
    def gelu_new(x: torch.Tensor) -> torch.Tensor:
        """Return the tanh-approximated GELU of ``x`` (the "new" GELU)."""
        return torch.nn.functional.gelu(x, approximate="tanh")

    @staticmethod
    def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
        """Quick GELU of ``x`` via IPEX, written into ``out``."""
        ipex.llm.functional.gelu_quick(x, out)

    @staticmethod
    def paged_attention_v1(
        out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        """Paged attention via IPEX ``single_query_kv_attention`` into ``out``.

        Only ``kv_cache_dtype == "auto"`` is accepted (asserted), so
        ``k_scale``/``v_scale`` are unused. ``tp_rank`` and the
        ``blocksparse_*`` arguments are accepted for compatibility and
        ignored.
        """
        assert kv_cache_dtype == "auto"
        num_heads = out.size(1)
        # Number of query heads served by each KV head.
        num_queries_per_tokens = num_heads // num_kv_heads
        ipex.llm.modules.PagedAttention.single_query_kv_attention(
            out,
            query.contiguous(),
            # Reinterpret the key cache with the value cache's shape
            # (presumably the layout IPEX expects for both caches).
            key_cache.view_as(value_cache),
            value_cache,
            num_queries_per_tokens,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )

    @staticmethod
    def paged_attention_v2(
        out: torch.Tensor,
        exp_sum: torch.Tensor,
        max_logits: torch.Tensor,
        tmp_out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        """Same computation as :meth:`paged_attention_v1`.

        The v2-specific buffers ``exp_sum``, ``max_logits`` and ``tmp_out``
        (presumably partition scratch space in the CUDA op) are unused here,
        as are ``tp_rank`` and the ``blocksparse_*`` arguments.
        """
        ipex_ops.paged_attention_v1(
            out,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    @staticmethod
    def rotary_embedding(
        positions: torch.Tensor,  # [batch_size, seq_len]
        query: torch.Tensor,  # [batch_size, seq_len, num_heads*head_size]
        key: torch.Tensor,  # [batch_size, seq_len, num_kv_heads*head_size]
        head_size: int,
        cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
        is_neox: bool,
    ) -> None:
        """Apply rotary position embedding to ``query``/``key`` via IPEX."""
        # The rotary dimension is the cache's second dimension.
        rot_dim = cos_sin_cache.size(1)
        ipex.llm.functional.rotary_embedding_batched(
            positions, query, key, head_size, cos_sin_cache, is_neox, rot_dim
        )

    @staticmethod
    def rms_norm(
        input: torch.Tensor, weight: torch.Tensor, epsilon: float
    ) -> torch.Tensor:
        """Return the IPEX RMS norm of ``input`` scaled by ``weight``."""
        return ipex.llm.functional.rms_norm(input, weight, epsilon)

    @staticmethod
    def fused_add_rms_norm(
        input: torch.Tensor,
        residual: torch.Tensor,
        weight: torch.Tensor,
        epsilon: float,
    ) -> None:
        """Residual-add followed by RMS norm; result copied back into ``input``.

        NOTE(review): the trailing ``True`` passed to IPEX looks like an
        "add back" flag that updates ``residual`` in place — confirm against
        the IPEX docs.
        """
        tmp = ipex.llm.functional.add_rms_norm(
            residual, input, weight, None, epsilon, True
        )
        input.copy_(tmp)

    @staticmethod
    def varlen_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        seqlen_q: torch.Tensor,
        seqlen_k: torch.Tensor,
        alibi_slopes: Optional[torch.Tensor],
        max_seqlen_q: int,
        max_seqlen_k: int,
        pdropout: float,
        softmax_scale: float,
        zero_tensors: bool,
        is_causal: bool,
        return_softmax: bool,
        gen_: torch.Generator,
        window_size_left: float,
        window_size_right: float,
        logits_soft_cap: float,
    ) -> None:
        """Variable-length attention via IPEX, written into ``out``.

        CPU builds of IPEX (version string ending in ``"cpu"``) take a
        shorter argument list. A nonzero ``logits_soft_cap`` raises
        ``ValueError``, and ALiBi and sliding windows are asserted off.
        Other builds (XPU) receive every argument.

        Raises:
            ValueError: on a CPU build when ``logits_soft_cap != 0.0``.
        """
        if ipex.__version__.endswith("cpu"):
            if logits_soft_cap != 0.0:
                raise ValueError("IPEX CPU does not support logits_soft_cap")
            assert alibi_slopes is None
            assert window_size_left < 0 and window_size_right < 0
            ipex.llm.functional.varlen_attention(
                query.contiguous(),
                key.contiguous(),
                value.contiguous(),
                out,
                seqlen_q.int(),
                seqlen_k.int(),
                max_seqlen_q,
                max_seqlen_k,
                pdropout,
                softmax_scale,
                zero_tensors,
                is_causal,
                return_softmax,
                gen_,
            )
        else:  # XPU build
            ipex.llm.functional.varlen_attention(
                query.contiguous(),
                key.contiguous(),
                value.contiguous(),
                out,
                seqlen_q.int(),
                seqlen_k.int(),
                alibi_slopes,
                max_seqlen_q,
                max_seqlen_k,
                pdropout,
                softmax_scale,
                zero_tensors,
                is_causal,
                return_softmax,
                gen_,
                window_size_left,
                window_size_right,
                logits_soft_cap,
            )

    @staticmethod
    def reshape_and_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
    ) -> None:
        """Store ``key``/``value`` into the paged caches at ``slot_mapping``.

        Only ``kv_cache_dtype == "auto"`` is accepted, so the scales are unused.
        """
        assert kv_cache_dtype == "auto"
        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slot_mapping
        )

    @staticmethod
    def reshape_and_cache_flash(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: Optional[torch.Tensor] = None,
        v_scale: Optional[torch.Tensor] = None,
        k_scale_float: float = 1.0,
        v_scale_float: float = 1.0,
    ) -> None:
        """Flash-layout variant of :meth:`reshape_and_cache` via IPEX.

        The tensor scales ``k_scale``/``v_scale`` are ignored; IPEX receives
        the float scales ``k_scale_float``/``v_scale_float`` instead.
        """
        ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            kv_cache_dtype,
            k_scale_float,
            v_scale_float,
        )

    @staticmethod
    def flash_attn_varlen_func(
        out: torch.Tensor,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        seqused_k: torch.Tensor,  # we don't support this in ipex kernel
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float,
        causal: bool,
        block_table: torch.Tensor,
        alibi_slopes: Optional[torch.Tensor],
        window_size: Optional[list[int]] = None,
        softcap: Optional[float] = 0.0,
        cu_seqlens_k: Optional[torch.Tensor] = None,
        # The following parameters are not used in ipex kernel currently,
        # we keep API compatible to CUDA's.
        scheduler_metadata=None,
        fa_version: int = 2,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        num_splits=0,
        s_aux: Optional[torch.Tensor] = None,
    ):
        """Variable-length paged flash attention via IPEX.

        ``seqused_k`` is only used to build ``cu_seqlens_k`` (an int32 prefix
        sum with a leading 0) when that argument is not supplied.
        ``window_size=None`` means no window, passed to IPEX as ``(-1, -1)``.
        The K/V scales are fixed at 1.0. Returns whatever the IPEX kernel
        returns.
        """
        if cu_seqlens_k is None:
            # Prefix-sum the per-sequence K lengths into int32 offsets
            # starting at 0.
            cu_seqlens_k = torch.cumsum(seqused_k, dim=0)
            cu_seqlens_k = torch.cat(
                [
                    torch.tensor([0], device=seqused_k.device, dtype=torch.int32),
                    cu_seqlens_k,
                ]
            ).to(torch.int32)

        real_window_size: tuple[int, int]
        if window_size is None:
            real_window_size = (-1, -1)
        else:
            assert len(window_size) == 2
            real_window_size = (window_size[0], window_size[1])
        return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
            out,
            q.contiguous(),
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            softmax_scale,
            causal,
            block_table,
            alibi_slopes,
            softcap=softcap,
            window_size_left=real_window_size[0],
            window_size_right=real_window_size[1],
            k_scale=1.0,
            v_scale=1.0,
        )

    @staticmethod
    def get_scheduler_metadata(
        batch_size,
        max_seqlen_q,
        max_seqlen_k,
        num_heads_q,
        num_heads_kv,
        headdim,
        cache_seqlens: torch.Tensor,
        qkv_dtype=torch.bfloat16,
        headdim_v=None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_k_new: Optional[torch.Tensor] = None,
        cache_leftpad: Optional[torch.Tensor] = None,
        page_size: Optional[int] = None,
        max_seqlen_k_new=0,
        causal=False,
        window_size=(-1, -1),  # -1 means infinite context window
        has_softcap=False,
        num_splits=0,  # Can be tuned for speed
        pack_gqa=None,  # Can be tuned for speed
        sm_margin=0,  # Can be tuned if some SMs are used for communication
    ) -> None:
        """Unsupported on IPEX: logs a one-time warning and returns ``None``."""
        logger.warning_once(
            "get_scheduler_metadata is not implemented for ipex_ops, returning None."
        )
        return None

    @staticmethod
    def copy_blocks(
        key_caches: list[torch.Tensor],
        value_caches: list[torch.Tensor],
        block_mapping: torch.Tensor,
    ) -> None:
        """Copy cache blocks per ``block_mapping`` using ``torch.xpu``."""
        torch.xpu.copy_blocks(  # type: ignore
            key_caches,
            value_caches,
            block_mapping,
        )

    @staticmethod
    def swap_blocks(
        src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
    ) -> None:
        """Swap cache blocks from ``src`` to ``dst`` using ``torch.xpu``."""
        torch.xpu.swap_blocks(src, dst, block_mapping)  # type: ignore

    @staticmethod
    def scaled_fp8_quant(
        input: torch.Tensor,
        scale: Optional[torch.Tensor] = None,
        num_token_padding: Optional[int] = None,
        scale_ub: Optional[torch.Tensor] = None,
        use_per_token_if_dynamic: bool = False,
        output: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Quantize input tensor to FP8 and return quantized tensor and scale.

        The general interface supports static scaling (``scale`` given) and
        dynamic scaling (``scale`` omitted). On XPU only dynamic per-tensor
        quantization is implemented, and the unsupported modes are rejected
        by assertions. The output may optionally be padded along the token
        dimension for downstream kernels that benefit from padding.

        Args:
            input: The 2-D ``[num_tokens, hidden]`` tensor to quantize.
            scale: Must be ``None`` on XPU (static scaling unsupported).
            num_token_padding: If truthy, pad the first dimension of a newly
                allocated output to at least this value. Must be falsy when
                ``output`` is supplied.
            scale_ub: Upper bound for the per-token dynamic scale; unused
                because per-token quantization is unsupported on XPU.
            use_per_token_if_dynamic: Must be ``False`` on XPU.
            output: Optional preallocated FP8 output tensor, whose dtype
                must equal the platform's FP8 dtype.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The FP8 output tensor and a
                1-element float32 scale tensor filled by the IPEX kernel.
        """
        # This code assumes batch_dim and num_tokens are flattened
        assert input.ndim == 2
        shape: Union[tuple[int, int], torch.Size] = input.shape
        out_dtype: torch.dtype = current_platform.fp8_dtype()
        if num_token_padding:
            shape = (max(num_token_padding, input.shape[0]), shape[1])
        if output is None:
            output = torch.empty(shape, device=input.device, dtype=out_dtype)
        else:
            # Consistent with the allocation path above: 0 means "no padding".
            assert not num_token_padding, (
                "padding not supported if output passed in"
            )
            assert output.dtype == out_dtype
        assert scale is None, "only dynamic fp8 quantization supported on XPU"
        assert not use_per_token_if_dynamic, (
            "per token dynamic fp8 quantization not supported on XPU"
        )
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        torch.ops.torch_ipex.dynamic_scaled_fp8_quant(output, input, scale)

        return output, scale