# _ipex_ops.py: vLLM op wrappers backed by Intel Extension for PyTorch
# (IPEX) and ``torch.xpu`` kernels.
from typing import List, Optional, Tuple

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)

# IPEX is treated as optional at import time: a missing package is only
# logged here.  The name `ipex` then stays unbound, so the first wrapper
# below that touches it fails at call time rather than at import time.
try:
    import intel_extension_for_pytorch as ipex
except ImportError as import_err:
    logger.warning("Import error msg: %s", import_err.msg)


class ipex_ops:
    """Static wrappers that route vLLM kernel calls to IPEX / ``torch.xpu``.

    Every method is a ``@staticmethod``. Several methods accept arguments
    they never read (``k_scale``/``v_scale``, ``tp_rank`` and the
    ``blocksparse_*`` options) so that their signatures stay call-compatible
    with the rest of vLLM.
    NOTE(review): the signatures appear to mirror vLLM's CUDA custom ops;
    confirm against that module when changing them.
    """

    @staticmethod
    def _reshape_activation_tensor(
            x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split a 2-D ``[num, 2 * d]`` tensor into two ``[num, d]`` halves.

        Returns ``(x[:, :d], x[:, d:])``. Only 2-D input is supported because
        the first dimension is read as ``num`` and the second as ``2 * d``.
        """
        num = x.size(0)
        d = x.size(1) // 2
        x = x.reshape(num, 2, d)
        x1, x2 = torch.chunk(x, chunks=2, dim=1)
        x1 = x1.reshape(num, d)
        x2 = x2.reshape(num, d)
        return x1, x2

    @staticmethod
    def _head_mapping(num_heads: int, num_kv_heads: int,
                      device: torch.device) -> torch.Tensor:
        """Map each query head to the KV head it reads from.

        Returns a 1-D int32 tensor of length ``num_heads`` on ``device``
        whose entry ``h`` is ``h // (num_heads // num_kv_heads)``. For
        example, 4 query heads sharing 2 KV heads give ``[0, 0, 1, 1]``.

        Raises:
            ValueError: if ``num_heads`` is not a multiple of
                ``num_kv_heads``. Floor division would otherwise produce a
                mapping shorter than ``num_heads``.
        """
        if num_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be a multiple of "
                f"num_kv_heads ({num_kv_heads})")
        num_queries_per_kv = num_heads // num_kv_heads
        return torch.arange(num_kv_heads, device=device,
                            dtype=torch.int32).repeat_interleave(
                                num_queries_per_kv)

    @staticmethod
    def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        """Fused SiLU-and-multiply via ``ipex.llm.functional.silu_and_mul``.

        The IPEX kernel writes its result into ``out``. Note that IPEX takes
        ``(x, out)`` while this wrapper takes ``(out, x)``.
        NOTE(review): presumably computes ``silu(x[..., :d]) * x[..., d:]``;
        confirm against the IPEX docs.
        """
        ipex.llm.functional.silu_and_mul(x, out)

    @staticmethod
    def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        """Fused GELU-and-multiply via ``ipex.llm.functional.gelu_and_mul``.

        The IPEX kernel writes its result into ``out``; argument order is
        swapped relative to this wrapper.
        """
        ipex.llm.functional.gelu_and_mul(x, out)

    @staticmethod
    def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        """Write ``gelu_tanh(x[..., :d]) * x[..., d:]`` into ``out``.

        ``d = x.shape[-1] // 2``, and ``out`` must have shape
        ``x.shape[:-1] + (d,)``.

        This is implemented with plain PyTorch ops rather than
        ``ipex.llm.functional.gelu_and_mul``. That call takes no
        approximation argument and is exactly what :meth:`gelu_and_mul`
        issues, so it could not produce the tanh-approximated result this
        op's name promises. The PyTorch version is not fused.
        """
        d = x.shape[-1] // 2
        # Slicing on the last dim (instead of _reshape_activation_tensor)
        # also supports inputs with more than two dimensions.
        gated = torch.nn.functional.gelu(x[..., :d], approximate="tanh")
        out.copy_(gated * x[..., d:])

    @staticmethod
    def gelu_fast(x: torch.Tensor) -> torch.Tensor:
        """Tanh-approximated ("fast") GELU.

        Equals ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``,
        i.e. ``F.gelu(x, approximate="tanh")``. Exact erf-based GELU differs
        from this approximation numerically.
        """
        return torch.nn.functional.gelu(x, approximate="tanh")

    @staticmethod
    def gelu_new(x: torch.Tensor) -> torch.Tensor:
        """GPT-2 style "new" GELU, i.e. the tanh approximation.

        Equals ``F.gelu(x, approximate="tanh")``. See :meth:`gelu_fast` for
        the formula.
        """
        return torch.nn.functional.gelu(x, approximate="tanh")

    @staticmethod
    def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
        """Quick GELU via ``ipex.llm.functional.gelu_quick`` into ``out``.

        Argument order is swapped relative to the IPEX call.
        """
        ipex.llm.functional.gelu_quick(x, out)

    @staticmethod
    def paged_attention_v1(
        out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        """Single-pass paged attention via ``torch.xpu.paged_attention_v1``.

        The result is written into ``out``, and the number of query heads is
        read from ``out.size(1)``. Only ``kv_cache_dtype == "auto"`` is
        supported (asserted). ``k_scale``, ``v_scale``, ``tp_rank`` and the
        ``blocksparse_*`` arguments are accepted but not used.

        Raises:
            ValueError: if the query head count is not a multiple of
                ``num_kv_heads``.
        """
        assert kv_cache_dtype == "auto"
        head_mapping = ipex_ops._head_mapping(out.size(1), num_kv_heads,
                                              query.device)
        # todo: ipex will refactor namespace
        torch.xpu.paged_attention_v1(  # type: ignore
            out,
            query.contiguous(),
            # NOTE(review): key_cache is reinterpreted with value_cache's
            # shape; presumably the XPU kernel expects identical layouts.
            key_cache.view_as(value_cache),
            value_cache,
            head_mapping,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )

    @staticmethod
    def paged_attention_v2(
        out: torch.Tensor,
        exp_sum: torch.Tensor,
        max_logits: torch.Tensor,
        tmp_out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        """Partitioned paged attention via ``torch.xpu.paged_attention_v2``.

        ``exp_sum``, ``max_logits`` and ``tmp_out`` are scratch buffers
        handed to the kernel, and the result is written into ``out``. Only
        ``kv_cache_dtype == "auto"`` is supported (asserted).
        ``k_scale``, ``v_scale``, ``tp_rank`` and the ``blocksparse_*``
        arguments are accepted but not used.

        Raises:
            ValueError: if the query head count is not a multiple of
                ``num_kv_heads``.
        """
        assert kv_cache_dtype == "auto"
        head_mapping = ipex_ops._head_mapping(out.size(1), num_kv_heads,
                                              query.device)
        # todo: ipex will refactor namespace
        # NOTE(review): unlike v1, this binding is called with block_tables
        # and context_lens *before* scale; confirm against the IPEX API.
        torch.xpu.paged_attention_v2(  # type: ignore
            out,
            exp_sum,
            max_logits,
            tmp_out,
            query.contiguous(),
            key_cache.view_as(value_cache),
            value_cache,
            head_mapping,
            block_tables,
            context_lens,
            scale,
            block_size,
            max_context_len,
            alibi_slopes,
        )

    @staticmethod
    def rotary_embedding(
        positions: torch.Tensor,  # [batch_size, seq_len]
        query: torch.Tensor,  # [batch_size, seq_len, num_heads*head_size]
        key: torch.Tensor,  # [batch_size, seq_len, num_kv_heads*head_size]
        head_size: int,
        cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
        is_neox: bool,
    ) -> None:
        """Apply rotary embeddings to ``query``/``key`` via IPEX.

        ``rot_dim`` is taken from ``cos_sin_cache.size(1)``. No per-token
        cache offsets are passed.
        NOTE(review): presumably ``query``/``key`` are updated in place,
        since nothing is returned; confirm with IPEX.
        """
        rot_dim = cos_sin_cache.size(1)
        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
                                                     head_size, cos_sin_cache,
                                                     is_neox, rot_dim)

    @staticmethod
    def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                                 key: torch.Tensor, head_size: int,
                                 cos_sin_cache: torch.Tensor, is_neox: bool,
                                 rot_dim: int,
                                 cos_sin_cache_offsets: torch.Tensor) -> None:
        """Rotary embeddings with explicit ``rot_dim`` and cache offsets.

        Forwards every argument unchanged to
        ``ipex.llm.functional.rotary_embedding_batched``.
        """
        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
                                                     head_size, cos_sin_cache,
                                                     is_neox, rot_dim,
                                                     cos_sin_cache_offsets)

    @staticmethod
    def rms_norm(input: torch.Tensor, weight: torch.Tensor,
                 epsilon: float) -> torch.Tensor:
        """Return ``ipex.llm.functional.rms_norm(input, weight, epsilon)``."""
        return ipex.llm.functional.rms_norm(input, weight, epsilon)

    @staticmethod
    def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, epsilon: float) -> None:
        """Residual add followed by RMS norm, writing the result into ``input``.

        Calls ``add_rms_norm(residual, input, weight, None, epsilon, True)``
        (no bias) and copies the returned tensor into ``input``.
        NOTE(review): the trailing ``True`` presumably makes IPEX store
        ``input + residual`` back into ``residual``; confirm.
        """
        tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
                                               epsilon, True)
        input.copy_(tmp)

    @staticmethod
    def varlen_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        seqlen_q: torch.Tensor,
        seqlen_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        pdropout: float,
        softmax_scale: float,
        zero_tensors: bool,
        is_causal: bool,
        return_softmax: bool,
        gen_: torch.Generator,
    ) -> None:
        """Variable-length attention via ``ipex.llm.functional``.

        ``query``/``key``/``value`` are made contiguous, and
        ``seqlen_q``/``seqlen_k`` are cast to int32 with ``.int()``. All
        other arguments are forwarded unchanged, and ``out`` is handed to
        the kernel as the output buffer.
        """
        ipex.llm.functional.varlen_attention(query.contiguous(),
                                             key.contiguous(),
                                             value.contiguous(), out,
                                             seqlen_q.int(), seqlen_k.int(),
                                             max_seqlen_q, max_seqlen_k,
                                             pdropout, softmax_scale,
                                             zero_tensors, is_causal,
                                             return_softmax, gen_)

    @staticmethod
    def reshape_and_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
    ) -> None:
        """Store ``key``/``value`` into the paged caches at ``slot_mapping``.

        Delegates to ``ipex.llm.modules.PagedAttention.reshape_and_cache``.
        Only ``kv_cache_dtype == "auto"`` is supported (asserted), and
        ``k_scale``/``v_scale`` are accepted but not used.
        """
        assert kv_cache_dtype == "auto"
        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slot_mapping)

    @staticmethod
    def copy_blocks(key_caches: List[torch.Tensor],
                    value_caches: List[torch.Tensor],
                    block_mapping: torch.Tensor) -> None:
        """Copy cache blocks per ``block_mapping`` via ``torch.xpu.copy_blocks``."""
        torch.xpu.copy_blocks(  # type: ignore
            key_caches,
            value_caches,
            block_mapping,
        )

    @staticmethod
    def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                    block_mapping: torch.Tensor) -> None:
        """Move blocks from ``src`` to ``dst`` via ``torch.xpu.swap_blocks``."""
        torch.xpu.swap_blocks(src, dst, block_mapping)  # type: ignore