mxfp4.py 17.6 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

6
from vllm.config import get_current_vllm_config
7
from vllm.logger import init_logger
8
from vllm.model_executor.layers.attention import Attention
9
10
11
12
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    FusedMoEConfig,
    FusedMoEMethodBase,
13
    MoEActivation,
14
)
15
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
16
from vllm.model_executor.layers.fused_moe.config import (
17
    FusedMoEQuantConfig,
18
)
19
20
21
22
23
24
25
26
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
    TRITON_BACKENDS,
    Mxfp4MoeBackend,
    convert_to_mxfp4_moe_kernel_format,
    make_mxfp4_moe_kernel,
    make_mxfp4_moe_quant_config,
    mxfp4_round_up_hidden_size_and_intermediate_size,
    select_mxfp4_moe_backend,
27
28
)
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
29
30
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
31
32
33
34
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
35
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
36
from vllm.platforms import current_platform
37

38
39
40
logger = init_logger(__name__)


41
class Mxfp4Config(QuantizationConfig):
    """Quantization config for checkpoints whose MoE experts are MXFP4.

    Only ``FusedMoE`` layers receive an MXFP4 method. Linear layers always
    run through ``UnquantizedLinearMethod``, and attention layers get no
    quantization method at all.
    """

    def __init__(self, ignored_layers: list[str] | None = None):
        super().__init__()
        # Layers to leave unquantized, matched against the layer prefix by
        # is_layer_skipped(). Only consulted for linear layers below.
        self.ignored_layers = ignored_layers

    @classmethod
    def from_config(cls, config):
        # NOTE(review): ``config`` is not inspected, so ``ignored_layers``
        # is always None for configs built this way.
        return cls()

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum device capability accepted by this config."""
        return 80

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Registry name of this quantization method."""
        return "mxfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Activation dtypes supported by the MXFP4 MoE kernels."""
        return [torch.bfloat16]

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """No extra config files are required."""
        return []

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Pick the quantization method for ``layer``.

        Args:
            layer: The module being constructed.
            prefix: Fully-qualified name of ``layer`` in the model.

        Returns:
            ``UnquantizedLinearMethod`` for linear layers, an MXFP4 MoE method
            (XPU-specific on XPU platforms) for ``FusedMoE`` layers, and
            ``None`` for everything else, including attention.
        """
        if isinstance(layer, LinearBase):
            explicitly_skipped = bool(self.ignored_layers) and is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            )
            # Either way the layer is unquantized; only emit the
            # "not implemented" note when it was not deliberately ignored.
            if not explicitly_skipped:
                logger.debug_once(
                    "MXFP4 linear layer is not implemented - falling back to "
                    "UnquantizedLinearMethod.",
                    scope="local",
                )
            return UnquantizedLinearMethod()

        if isinstance(layer, FusedMoE):
            method_cls = (
                XpuMxfp4MoEMethod if current_platform.is_xpu() else Mxfp4MoEMethod
            )
            return method_cls(layer.moe_config)

        if isinstance(layer, Attention):
            logger.debug_once(
                "MXFP4 attention layer is not implemented. "
                "Skipping quantization for this layer.",
                scope="local",
            )
        return None

    def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
        """Every layer under this config counts as MXFP4-quantized."""
        return True

99
100

class Mxfp4MoEMethod(FusedMoEMethodBase):
    """MXFP4 MoE quantization method.

    Lifecycle, as driven by the owning ``FusedMoE`` layer:

    1. ``__init__`` selects an MXFP4 backend and pads the shared
       ``FusedMoEConfig`` hidden/intermediate dims for that backend.
    2. ``create_weights`` registers zero-filled uint8 weight and block-scale
       parameters (plus optional bf16 biases) sized from the padded dims.
    3. ``process_weights_after_loading`` converts the loaded tensors into the
       backend's kernel format and builds ``self.moe_kernel``.
    4. ``apply`` / ``apply_monolithic`` run the forward pass through it.
    """

    def __init__(self, moe: FusedMoEConfig):
        super().__init__(moe)
        self.weight_dtype = "mxfp4"
        self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe)

        self.max_capture_size = (
            get_current_vllm_config().compilation_config.max_cudagraph_capture_size
        )

        # Handed to convert_to_mxfp4_moe_kernel_format(); keyed by tensor
        # shape (presumably to reuse permutations across layers - verify).
        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
        # Built in _setup_kernel(); None until weights have been processed.
        self.moe_kernel: mk.FusedMoEKernel | None = None

        # Round up dims once based on backend. This mutates the shared
        # FusedMoEConfig in-place so that create_weights() and all
        # downstream code see the padded dimensions. This must happen
        # before create_weights() is called.
        self.moe.hidden_dim, self.moe.intermediate_size_per_partition = (
            mxfp4_round_up_hidden_size_and_intermediate_size(
                self.mxfp4_backend,
                self.moe.hidden_dim,
                self.moe.intermediate_size_per_partition,
            )
        )

        # Triton backends keep their converted weight scales ("precision
        # configs") here instead of in the layer's scale parameters.
        self.w13_precision_config = None
        self.w2_precision_config = None

    @property
    def skip_forward_padding(self) -> bool:
        # SM100_FI_MXFP4_MXFP8_TRTLLM supports padding with mxfp8 quant
        # so can skip the padding in the forward before applying the moe method
        return self.mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8

    @staticmethod
    def _register_zeros_param(
        layer: torch.nn.Module,
        name: str,
        shape: tuple[int, ...],
        dtype: torch.dtype,
        extra_weight_attrs: dict,
    ) -> None:
        """Register a frozen, zero-filled parameter ``name`` on ``layer``."""
        param = torch.nn.Parameter(
            torch.zeros(shape, dtype=dtype),
            requires_grad=False,
        )
        layer.register_parameter(name, param)
        set_weight_attrs(param, extra_weight_attrs)

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the MXFP4 expert parameters on ``layer``.

        The ``hidden_size`` and ``intermediate_size_per_partition`` arguments
        are superseded by the backend-padded values that ``__init__`` stored
        on ``self.moe``; those padded values size every tensor below.

        With E = ``num_experts``, H = padded hidden size and I = padded
        intermediate size per partition, the registered parameters are:

        - ``w13_weight``       uint8 ``(E, 2*I, H // 2)``  fused gate/up proj
        - ``w13_weight_scale`` uint8 ``(E, 2*I, H // 32)`` one per 32-block
        - ``w2_weight``        uint8 ``(E, H, I // 2)``    down proj
        - ``w2_weight_scale``  uint8 ``(E, H, I // 32)``
        - ``w13_bias`` bf16 ``(E, 2*I)`` and ``w2_bias`` bf16 ``(E, H)``,
          only when ``self.moe.has_bias``.

        ``extra_weight_attrs`` are attached to every parameter through
        ``set_weight_attrs``.
        """
        self.num_experts = num_experts
        weight_dtype = torch.uint8
        scale_dtype = torch.uint8
        mxfp4_block = 32

        # Use pre-rounded sizes from config
        self.intermediate_size = intermediate_size = (
            self.moe.intermediate_size_per_partition
        )
        self.hidden_size = hidden_size = self.moe.hidden_dim

        # Expose padded dimensions on the layer for LoRA and Marlin code
        # that reads layer.hidden_size / layer.intermediate_size_per_partition.
        layer.params_dtype = params_dtype
        layer.num_experts = num_experts
        layer.hidden_size = hidden_size
        layer.intermediate_size_per_partition = intermediate_size

        register = self._register_zeros_param

        # Fused gate_up_proj (column parallel)
        register(
            layer,
            "w13_weight",
            (num_experts, 2 * intermediate_size, hidden_size // 2),
            weight_dtype,
            extra_weight_attrs,
        )
        register(
            layer,
            "w13_weight_scale",
            (num_experts, 2 * intermediate_size, hidden_size // mxfp4_block),
            scale_dtype,
            extra_weight_attrs,
        )

        # down_proj (row parallel)
        register(
            layer,
            "w2_weight",
            (num_experts, hidden_size, intermediate_size // 2),
            weight_dtype,
            extra_weight_attrs,
        )
        register(
            layer,
            "w2_weight_scale",
            (num_experts, hidden_size, intermediate_size // mxfp4_block),
            scale_dtype,
            extra_weight_attrs,
        )

        if self.moe.has_bias:
            register(
                layer,
                "w13_bias",
                (num_experts, 2 * intermediate_size),
                torch.bfloat16,
                extra_weight_attrs,
            )
            register(
                layer,
                "w2_bias",
                (num_experts, hidden_size),
                torch.bfloat16,
                extra_weight_attrs,
            )

    def _setup_kernel(
        self,
        layer: FusedMoE,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_bias: torch.Tensor | None = None,
        w2_bias: torch.Tensor | None = None,
    ) -> None:
        """Validate loaded weights, convert them, and build the MoE kernel.

        Args:
            layer: The owning ``FusedMoE`` layer; its weight attributes are
                replaced with the converted tensors.
            w13, w2: Packed expert weights as registered in create_weights().
            w13_scale, w2_scale: Matching per-32-element block scales.
            w13_bias, w2_bias: Optional biases; both are replaced on the
                layer only when both are present.

        Side effects:
            Sets ``self.moe_quant_config``. Sets ``self.w13_precision_config``
            and ``self.w2_precision_config`` for Triton backends. Sets
            ``self.moe_kernel`` when a quant config and an experts class are
            both available.
        """
        num_experts = self.num_experts
        intermediate_size = self.intermediate_size
        hidden_size = self.hidden_size
        sf_block_size = 32

        # Shape assertions against the padded sizes recorded in
        # create_weights().
        assert (
            w13.dim() == 3
            and w13.shape[0] == num_experts
            and w13.shape[1] == intermediate_size * 2
            and w13.shape[2] == hidden_size // 2
        )
        assert (
            w13_scale.dim() == 3
            and w13_scale.shape[0] == num_experts
            and w13_scale.shape[1] == intermediate_size * 2
            and w13_scale.shape[2] == hidden_size // sf_block_size
        )
        assert (
            w2.dim() == 3
            and w2.shape[0] == num_experts
            and w2.shape[1] == hidden_size
            and w2.shape[2] == intermediate_size // 2
        )
        assert (
            w2_scale.dim() == 3
            and w2_scale.shape[0] == num_experts
            and w2_scale.shape[1] == hidden_size
            and w2_scale.shape[2] == intermediate_size // sf_block_size
        )
        if w13_bias is not None:
            assert (
                w13_bias.dim() == 2
                and w13_bias.shape[0] == num_experts
                and w13_bias.shape[1] == intermediate_size * 2
            )
        if w2_bias is not None:
            assert (
                w2_bias.dim() == 2
                and w2_bias.shape[0] == num_experts
                and w2_bias.shape[1] == hidden_size
            )

        # Convert weights to kernel format
        w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = (
            convert_to_mxfp4_moe_kernel_format(
                mxfp4_backend=self.mxfp4_backend,
                layer=layer,
                w13_weight=w13,
                w2_weight=w2,
                w13_weight_scale=w13_scale,
                w2_weight_scale=w2_scale,
                w13_bias=w13_bias,
                w2_bias=w2_bias,
                _cache_permute_indices=self._cache_permute_indices,
            )
        )

        if self.mxfp4_backend not in TRITON_BACKENDS:
            replace_parameter(layer, "w13_weight", w13)
            replace_parameter(layer, "w2_weight", w2)
            replace_parameter(layer, "w13_weight_scale", w13_scale)
            replace_parameter(layer, "w2_weight_scale", w2_scale)
        else:
            # For TRITON backends, weights are wrapped tensors from
            # triton_kernels that don't support .detach(), so they cannot go
            # through replace_parameter(). nn.Module.__setattr__ refuses to
            # bind a non-Parameter to a name still registered as a parameter,
            # so unregister it first if the conversion has not already done
            # so. If the name is already gone, this step does nothing.
            for name, value in (("w13_weight", w13), ("w2_weight", w2)):
                if name in layer._parameters:
                    delattr(layer, name)
                setattr(layer, name, value)
            # Triton scales are consumed via the precision configs, not via
            # the layer's scale parameters.
            self.w13_precision_config = w13_scale
            self.w2_precision_config = w2_scale

        if w13_bias is not None and w2_bias is not None:
            replace_parameter(layer, "w13_bias", w13_bias)
            replace_parameter(layer, "w2_bias", w2_bias)

        # Build quant config
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)

        # Build kernel (modular or monolithic)
        if self.moe_quant_config is not None and self.experts_cls is not None:
            self.moe_kernel = make_mxfp4_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                mxfp4_backend=self.mxfp4_backend,
                experts_cls=self.experts_cls,
                routing_tables=layer._maybe_init_expert_routing_tables(),
                shared_experts=layer.shared_experts,
            )

    def process_weights_after_loading(self, layer):
        """Convert the loaded checkpoint tensors and build the kernel.

        This is a no-op when no MXFP4 backend was selected
        (``Mxfp4MoeBackend.NONE``).
        """
        if self.mxfp4_backend == Mxfp4MoeBackend.NONE:
            return

        self._setup_kernel(
            layer,
            layer.w13_weight,
            layer.w2_weight,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            getattr(layer, "w13_bias", None),
            getattr(layer, "w2_bias", None),
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config from the layer's scales/biases.

        Triton backends use the precision configs produced during
        conversion. They must already be set when this is called. All other
        backends use the layer's scale parameters.
        """
        if self.mxfp4_backend in TRITON_BACKENDS:
            assert self.w13_precision_config is not None
            assert self.w2_precision_config is not None
            w1_scale = self.w13_precision_config
            w2_scale = self.w2_precision_config
        else:
            w1_scale = layer.w13_weight_scale
            w2_scale = layer.w2_weight_scale

        return make_mxfp4_moe_quant_config(
            mxfp4_backend=self.mxfp4_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            w1_bias=getattr(layer, "w13_bias", None),
            w2_bias=getattr(layer, "w2_bias", None),
        )

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEExpertsModular:
        """Unsupported: kernels are built in _setup_kernel() instead.

        Raises:
            ValueError: Always.
        """
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel "
            "initialization logic. This function should not be called."
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Modular forward: routing has already produced ``topk_*``."""
        assert not self.is_monolithic
        assert self.moe_kernel is not None
        return self.moe_kernel.apply(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            expert_map=layer.expert_map,
            shared_experts_input=shared_experts_input,
        )

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Monolithic forward: the kernel performs routing from logits."""
        assert self.is_monolithic
        assert self.moe_kernel is not None
        return self.moe_kernel.apply_monolithic(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            router_logits=router_logits,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )
423
424


425
class XpuMxfp4MoEMethod(Mxfp4MoEMethod):
    """MXFP4 MoE method for Intel XPU.

    Reuses the parent's parameter allocation. It skips the parent's
    kernel-format conversion and runs routing plus expert computation in
    ``apply_monolithic``, delegating to ``vllm_xpu_kernels``' ``xpu_fused_moe``.
    """

    def __init__(self, moe_config: FusedMoEConfig):
        super().__init__(moe_config)
        self.moe_config = moe_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register parameters exactly as the parent does.

        Also records the caller-supplied (unpadded) ``hidden_size`` as
        ``self.original_hidden_size``.
        """
        super().create_weights(
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
            **extra_weight_attrs,
        )
        self.original_hidden_size = hidden_size

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """No-op. The loaded parameters are passed unchanged to
        ``xpu_fused_moe`` in ``apply_monolithic``."""
        pass

    @property
    def is_monolithic(self) -> bool:
        return True

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor:
        """Route tokens from ``router_logits`` and run the XPU fused MoE.

        Args:
            layer: Owning MoE layer providing weights and routing options.
            x: Hidden states; must be 2-D (tokens, hidden).
            router_logits: Router output for ``x``.

        Raises:
            AssertionError: If the layer's activation is not SWIGLUOAI.
        """
        assert layer.activation == MoEActivation.SWIGLUOAI, (
            "Only swiglu_oai activation is supported for "
            f"XPU MXFP4 MoE, not {layer.activation}."
        )
        from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe

        # Unpacking also enforces that x is 2-D on both routing paths.
        num_tokens, _ = x.size()

        if layer.use_grouped_topk:
            routing_weights, selected_experts = torch.ops._moe_C.fused_grouped_topk(
                x,
                router_logits,
                layer.top_k,
                layer.renormalize,
                n_expert_group=layer.num_expert_group,
                n_topk_group=layer.topk_group,
                scoring_func=layer.scoring_func,
                routed_scaling_factor=layer.routed_scaling_factor,
                bias=layer.e_score_correction_bias,
            )
        else:
            # topk_softmax fills caller-provided output buffers, so allocate
            # them only on this path (grouped top-k returns fresh tensors).
            routing_weights = torch.empty(
                num_tokens, layer.top_k, dtype=torch.float32, device=x.device
            )
            selected_experts = torch.empty(
                num_tokens, layer.top_k, dtype=torch.int32, device=x.device
            )
            token_expert_indices = torch.empty(
                num_tokens, layer.top_k, dtype=torch.int32, device=x.device
            )
            torch.ops._moe_C.topk_softmax(
                routing_weights,
                selected_experts,
                token_expert_indices,
                router_logits,
                layer.renormalize,
                layer.e_score_correction_bias,
            )

        return xpu_fused_moe(
            hidden_states=x,
            w13=layer.w13_weight,
            w13_bias=layer.w13_bias if self.moe.has_bias else None,
            w13_scales=layer.w13_weight_scale,
            w2=layer.w2_weight,
            w2_bias=layer.w2_bias if self.moe.has_bias else None,
            w2_scales=layer.w2_weight_scale,
            topk_weights=routing_weights,
            topk_ids=selected_experts,
            n_experts_per_token=layer.top_k,
            activation=layer.activation.value,
            num_experts=layer.local_num_experts,
            is_mxfp4=True,
        )