unquant.py 11.9 KB
Newer Older
1
2
from __future__ import annotations

3
import importlib
4
from typing import TYPE_CHECKING, Callable, List, Optional
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from sglang.srt.custom_op import CustomOp
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.quantization.base_config import (
    FusedMoEMethodBase,
    LinearMethodBase,
    QuantizeMethodBase,
)
from sglang.srt.utils import (
    cpu_has_amx_support,
    get_bool_env_var,
    is_cpu,
    is_hip,
    set_weight_attrs,
    use_intel_amx_backend,
)

26
if TYPE_CHECKING:
27
    from sglang.srt.layers.moe.ep_moe.layer import EPMoE
28
29
    from sglang.srt.layers.moe.topk import TopKOutput

30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
# A bare `import importlib` does not guarantee that the `importlib.util`
# submodule is loaded and bound as an attribute; import it explicitly so the
# `find_spec` lookup below does not depend on another module having done so.
import importlib.util

# True when the optional `triton_kernels` package can be located.
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None


# Platform capability flags, evaluated once at import time.
_is_cpu_amx_available = cpu_has_amx_support()
_is_hip = is_hip()
_is_cpu = is_cpu()
# The aiter backend is opt-in (SGLANG_USE_AITER) and only used on HIP.
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

if _use_aiter:
    # Imported lazily: these names only exist in scope when aiter is enabled.
    from aiter import ActivationType
    from aiter.fused_moe import fused_moe
    from aiter.ops.shuffle import shuffle_weight


class UnquantizedEmbeddingMethod(QuantizeMethodBase):
    """Embedding method that stores and uses its weight without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the embedding weight and register it on ``layer``.

        The parameter is an uninitialized, non-trainable tensor of shape
        ``(sum(output_partition_sizes), input_size_per_partition)`` registered
        under the name ``"weight"``. ``input_size`` and ``output_size`` are
        accepted but not used by this method.

        Args:
            layer: Module that receives the ``weight`` parameter.
            input_size_per_partition: Size of the weight's second dimension.
            output_partition_sizes: Sizes whose sum is the weight's first
                dimension.
            input_size: Unused here.
            output_size: Unused here.
            params_dtype: Dtype of the allocated tensor.
            **extra_weight_attrs: Extra attributes attached to the parameter
                via ``set_weight_attrs``.
        """
        num_rows = sum(output_partition_sizes)
        storage = torch.empty(num_rows, input_size_per_partition, dtype=params_dtype)
        weight = Parameter(storage, requires_grad=False)

        # Tag which axis is input/output first, then register the parameter,
        # and finally attach the caller-supplied attributes.
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return ``F.linear(x, layer.weight, bias)``."""
        weight = layer.weight
        return F.linear(x, weight, bias)

    def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
        """Look up ``input_`` in ``layer.weight`` with ``F.embedding``."""
        table = layer.weight
        return F.embedding(input_, table)


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method that keeps the weight in its original (unquantized) form."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the linear weight and register it on ``layer``.

        The parameter is an uninitialized, non-trainable tensor of shape
        ``(sum(output_partition_sizes), input_size_per_partition)`` registered
        under the name ``"weight"``. ``input_size`` and ``output_size`` are
        accepted but not used by this method.
        """
        total_output_size = sum(output_partition_sizes)
        buffer = torch.empty(
            total_output_size, input_size_per_partition, dtype=params_dtype
        )
        weight = Parameter(buffer, requires_grad=False)

        # Axis tags go on before registration; caller attributes go on after.
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Run the AMX weight post-processing when on a CPU with AMX support."""
        amx_enabled = _is_cpu and _is_cpu_amx_available
        if not amx_enabled:
            return
        _amx_process_weight_after_loading(layer, ["weight"])

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the linear layer to ``x``.

        Uses the ``sgl_kernel`` packed-weight linear op when
        ``use_intel_amx_backend(layer)`` is true, and ``F.linear`` otherwise.
        """
        if not use_intel_amx_backend(layer):
            return F.linear(x, layer.weight, bias)

        is_vnni = True
        return torch.ops.sgl_kernel.weight_packed_linear(
            x, layer.weight, bias, is_vnni
        )


class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    """Fused mixture-of-experts method that keeps expert weights unquantized."""

    def __init__(self, use_triton_kernels: bool = False):
        """Store whether the Triton-kernel weight layout was requested."""
        super().__init__()
        self.use_triton_kernels = use_triton_kernels

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register the two expert weight tensors on ``layer``.

        ``w13_weight`` has shape ``(num_experts, 2 * intermediate_size,
        hidden_size)`` and ``w2_weight`` has shape ``(num_experts, hidden_size,
        intermediate_size)``. When ``self.use_triton_kernels`` is set, the two
        trailing dimensions of each tensor are swapped. Both parameters are
        uninitialized, non-trainable, and receive ``extra_weight_attrs``.
        """
        expert_weights = (
            # Fused gate_up_proj (column parallel)
            ("w13_weight", 2 * intermediate_size, hidden_size),
            # down_proj (row parallel)
            ("w2_weight", hidden_size, intermediate_size),
        )
        for param_name, dim_n, dim_k in expert_weights:
            if self.use_triton_kernels:
                dim_n, dim_k = dim_k, dim_n
            param = torch.nn.Parameter(
                torch.empty(num_experts, dim_n, dim_k, dtype=params_dtype),
                requires_grad=False,
            )
            layer.register_parameter(param_name, param)
            set_weight_attrs(param, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Post-process the loaded expert weights for the active backend.

        With aiter enabled, each weight is replaced by its ``shuffle_weight``
        result (layout ``(16, 16)``) and the CUDA cache is emptied after each
        replacement. On a CPU with AMX support, the weights are additionally
        passed through ``_amx_process_weight_after_loading``.
        """
        if _use_aiter:
            for param_name in ("w13_weight", "w2_weight"):
                shuffled = shuffle_weight(getattr(layer, param_name).data, (16, 16))
                setattr(
                    layer,
                    param_name,
                    torch.nn.Parameter(shuffled, requires_grad=False),
                )
                torch.cuda.empty_cache()

        # Pack the weights to get better performance on CPU
        if _is_cpu and _is_cpu_amx_available:
            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        topk_output: TopKOutput,
        *,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
        """Run the MoE computation for ``layer`` on ``x``.

        ``EPMoE`` layers are delegated to ``layer.run_moe`` (only ``x`` and
        ``topk_output`` are forwarded). Every other layer goes through
        ``self.forward`` with all keyword options passed along.
        """
        # Imported here rather than at module level (module-level import is
        # only done under TYPE_CHECKING).
        from sglang.srt.layers.moe.ep_moe.layer import EPMoE

        if not isinstance(layer, EPMoE):
            forward_options = dict(
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                inplace=inplace,
                no_combine=no_combine,
                routed_scaling_factor=routed_scaling_factor,
            )
            return self.forward(
                x=x,
                layer=layer,
                topk_output=topk_output,
                **forward_options,
            )

        return layer.run_moe(
            hidden_states=x,
            topk_output=topk_output,
        )

    def forward_cuda(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        topk_output: TopKOutput,
        *,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
        """GPU forward pass.

        Without aiter, the call is forwarded to ``fused_experts``. With aiter,
        ``topk_output`` is unpacked and passed to aiter's ``fused_moe``; in that
        path ``inplace`` and ``routed_scaling_factor`` are not forwarded and
        ``no_combine`` must be false.

        Raises:
            NotImplementedError: If ``self.use_triton_kernels`` is set.
        """
        if self.use_triton_kernels:
            # TODO(ch-wan): re-enable the Triton kernel
            # return triton_kernel_moe_forward(
            #     hidden_states=x,
            #     w1=layer.w13_weight,
            #     w2=layer.w2_weight,
            #     gating_output=router_logits,
            #     topk=top_k,
            #     renormalize=renormalize,
            # )
            raise NotImplementedError("The Triton kernel is temporarily disabled.")

        if not _use_aiter:
            from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
                fused_experts,
            )

            return fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_output=topk_output,
                inplace=inplace and not no_combine,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                no_combine=no_combine,
                routed_scaling_factor=routed_scaling_factor,
            )

        # aiter path
        assert not no_combine, "unsupported"
        topk_weights, topk_ids, _ = topk_output

        if apply_router_weight_on_input:
            assert (
                topk_weights.dim() == 2
            ), "`topk_weights` should be in shape (num_tokens, topk)"
            _, experts_per_token = topk_weights.shape
            assert (
                experts_per_token == 1
            ), "Only support topk=1 when `apply_router_weight_on_input` is True"
            # Scale the input by the router weights up front, then neutralize
            # the weights so they are not applied a second time.
            x = x * topk_weights.to(x.dtype)
            # topk_weights must be FP32 (float32)
            topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)

        # Any activation other than "silu" is mapped to Gelu.
        if activation == "silu":
            aiter_activation = ActivationType.Silu
        else:
            aiter_activation = ActivationType.Gelu

        return fused_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            activation=aiter_activation,
        )

    def forward_cpu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        topk_output: TopKOutput,
        *,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
        """CPU forward pass; only the ``"silu"`` activation is accepted.

        Uses the ``sgl_kernel`` fused CPU op when the Intel AMX backend is
        active for ``layer`` and router weights are not applied on the input;
        otherwise falls back to ``moe_forward_native``.
        """
        assert activation == "silu", f"activation = {activation} is not supported."

        if not use_intel_amx_backend(layer) or apply_router_weight_on_input:
            from sglang.srt.layers.moe.fused_moe_native import moe_forward_native

            native_options = dict(
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                inplace=inplace,
                no_combine=no_combine,
                routed_scaling_factor=routed_scaling_factor,
            )
            return moe_forward_native(layer, x, topk_output, **native_options)

        from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

        topk_weights, topk_ids, _ = topk_output
        x, topk_weights = apply_topk_weights_cpu(
            apply_router_weight_on_input, topk_weights, x
        )
        return torch.ops.sgl_kernel.fused_experts_cpu(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            False,  # inplace # See [Note] inplace should be False in fused_experts.
            False,  # use_int8_w8a8
            False,  # use_fp8_w8a16
            None,  # w1_scale
            None,  # w2_scale
            None,  # block_size
            None,  # a1_scale
            None,  # a2_scale
            True,  # is_vnni
        )

    def forward_npu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        topk_output: TopKOutput,
        *,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
        """NPU forward pass; always delegates to ``moe_forward_native``."""
        from sglang.srt.layers.moe.fused_moe_native import moe_forward_native

        native_options = dict(
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            inplace=inplace,
            no_combine=no_combine,
            routed_scaling_factor=routed_scaling_factor,
        )
        return moe_forward_native(layer, x, topk_output, **native_options)

    def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
        """TPU forward pass; not supported and always raises."""
        raise NotImplementedError("The TPU backend currently does not support MoE.")

    # The native forward is the same function as the CPU implementation.
    forward_native = forward_cpu