mxfp4.py 14.2 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

6
from vllm.config import get_current_vllm_config
7
from vllm.logger import init_logger
8
from vllm.model_executor.layers.attention import Attention
9
10
11
12
13
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    FusedMoEConfig,
    FusedMoEMethodBase,
)
14
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
15
from vllm.model_executor.layers.fused_moe.config import (
16
    FusedMoEParallelConfig,
17
    FusedMoEQuantConfig,
18
)
19
20
21
22
23
24
25
26
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
    TRITON_BACKENDS,
    Mxfp4MoeBackend,
    convert_to_mxfp4_moe_kernel_format,
    make_mxfp4_moe_kernel,
    make_mxfp4_moe_quant_config,
    mxfp4_round_up_hidden_size_and_intermediate_size,
    select_mxfp4_moe_backend,
27
28
)
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
29
30
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
31
32
33
34
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
35
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
36

37
38
39
logger = init_logger(__name__)


40
class Mxfp4Config(QuantizationConfig):
    """Quantization config registered under the name ``"mxfp4"``.

    Only ``FusedMoE`` layers receive an MXFP4 quant method from this config.
    Linear layers fall back to ``UnquantizedLinearMethod`` and all other
    layer types get no quant method.
    """

    def __init__(self, ignored_layers: list[str] | None = None):
        super().__init__()
        # Forwarded to ``is_layer_skipped`` for linear layers; may be None.
        self.ignored_layers = ignored_layers

    @classmethod
    def from_config(cls, config):
        """Return a default-constructed config; ``config`` is not read."""
        return cls()

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required by this method (80)."""
        # NOTE(review): presumably a CUDA compute capability encoded as
        # major * 10 + minor - confirm against the base class contract.
        return 80

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the identifier of this quantization method."""
        return "mxfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the supported activation dtypes (bfloat16 only)."""
        return [torch.bfloat16]

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the config filenames for this method (none)."""
        return []

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Choose the quantization method for ``layer``.

        Args:
            layer: The module being constructed.
            prefix: Name of the layer, used for the ignored-layer lookup.

        Returns:
            ``UnquantizedLinearMethod`` for ``LinearBase`` layers,
            ``Mxfp4MoEMethod`` for ``FusedMoE`` layers, otherwise ``None``.
        """
        if isinstance(layer, LinearBase):
            explicitly_skipped = self.ignored_layers and is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            )
            # Both outcomes yield an unquantized linear layer; the fallback
            # is only logged when the layer was not explicitly ignored.
            if not explicitly_skipped:
                logger.debug_once(
                    "MXFP4 linear layer is not implemented - falling back to "
                    "UnquantizedLinearMethod.",
                    scope="local",
                )
            return UnquantizedLinearMethod()

        if isinstance(layer, FusedMoE):
            return Mxfp4MoEMethod(layer.moe_config)

        if isinstance(layer, Attention):
            logger.debug_once(
                "MXFP4 attention layer is not implemented. "
                "Skipping quantization for this layer.",
                scope="local",
            )

        return None

    def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
        """MXFP4 config always uses MXFP4 quantization."""
        return True

95
96

class Mxfp4MoEMethod(FusedMoEMethodBase):
97
98
    """MXFP4 MoE quantization method."""

99
    def __init__(self, moe: FusedMoEConfig):
        """Select the MXFP4 MoE backend and initialise per-method state.

        Args:
            moe: MoE configuration, forwarded to the base class and to
                ``select_mxfp4_moe_backend``.
        """
        super().__init__(moe)
        self.weight_dtype = "mxfp4"

        # Backend selection yields the backend and an experts class.
        backend, experts_cls = select_mxfp4_moe_backend(moe)
        self.mxfp4_backend = backend
        self.experts_cls = experts_cls

        compilation_config = get_current_vllm_config().compilation_config
        self.max_capture_size = compilation_config.max_cudagraph_capture_size

        # Cache handed to ``convert_to_mxfp4_moe_kernel_format`` during
        # kernel setup.
        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
        # Stays None until ``_setup_kernel`` builds a kernel.
        self.moe_kernel: mk.FusedMoEKernel | None = None

        # Used for triton kernel precision configs
        self.w13_precision_config = None
        self.w2_precision_config = None

115
116
117
118
    @property
    def skip_forward_padding(self) -> bool:
        """Whether padding in the forward pass can be skipped.

        True only when the selected backend is
        ``Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8``.
        """
        # SM100_FI_MXFP4_MXFP8_TRTLLM supports padding with mxfp8 quant,
        # so the padding done in the forward before applying the moe method
        # is unnecessary for it.
        selected_backend = self.mxfp4_backend
        return selected_backend == Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8
120

121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
    def maybe_roundup_sizes(
        self,
        hidden_size: int,
        intermediate_size_per_partition: int,
        act_dtype: torch.dtype,
        moe_parallel_config: FusedMoEParallelConfig,
    ) -> tuple[int, int]:
        """Round up the hidden and intermediate sizes for the active backend.

        The base-class rounding is applied first, then the backend-specific
        MXFP4 rounding is applied to its result.

        Returns:
            The result of ``mxfp4_round_up_hidden_size_and_intermediate_size``
            (annotated as a ``(hidden_size, intermediate_size)`` pair).
        """
        base_hidden, base_intermediate = super().maybe_roundup_sizes(
            hidden_size=hidden_size,
            intermediate_size_per_partition=intermediate_size_per_partition,
            act_dtype=act_dtype,
            moe_parallel_config=moe_parallel_config,
        )
        return mxfp4_round_up_hidden_size_and_intermediate_size(
            self.mxfp4_backend, base_hidden, base_intermediate
        )

138
139
140
141
142
143
144
145
146
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register zero-initialised MXFP4 MoE parameters on ``layer``.

        Registers ``w13_weight``, ``w13_weight_scale``, ``w2_weight`` and
        ``w2_weight_scale`` as uint8 parameters, plus bfloat16 ``w13_bias``
        and ``w2_bias`` when ``self.moe.has_bias`` is set. Every parameter is
        created with ``requires_grad=False`` and tagged with
        ``extra_weight_attrs``.

        Also records ``num_experts``, ``hidden_size`` and the per-partition
        intermediate size on ``self``, and ``params_dtype`` / ``num_experts``
        on ``layer``.
        """
        packed_dtype = torch.uint8  # used for both weights and scales
        mxfp4_block = 32  # elements covered by one scale entry
        gate_up_rows = 2 * intermediate_size_per_partition

        self.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.num_experts = num_experts
        self.intermediate_size = intermediate_size_per_partition
        self.hidden_size = hidden_size

        def _register_zeros(
            name: str, shape: tuple[int, ...], dtype: torch.dtype
        ) -> None:
            # Allocate, attach to the layer, then tag with the loader attrs.
            param = torch.nn.Parameter(
                torch.zeros(shape, dtype=dtype),
                requires_grad=False,
            )
            layer.register_parameter(name, param)
            set_weight_attrs(param, extra_weight_attrs)

        # Weight tensors halve their last dim (presumably two 4-bit values
        # packed per byte - confirm); scale tensors hold one entry per
        # ``mxfp4_block`` elements.
        quantized_specs = (
            # Fused gate_up_proj (column parallel)
            ("w13_weight", (num_experts, gate_up_rows, hidden_size // 2)),
            (
                "w13_weight_scale",
                (num_experts, gate_up_rows, hidden_size // mxfp4_block),
            ),
            # down_proj (row parallel)
            (
                "w2_weight",
                (num_experts, hidden_size, intermediate_size_per_partition // 2),
            ),
            (
                "w2_weight_scale",
                (
                    num_experts,
                    hidden_size,
                    intermediate_size_per_partition // mxfp4_block,
                ),
            ),
        )
        for param_name, param_shape in quantized_specs:
            _register_zeros(param_name, param_shape, packed_dtype)

        if self.moe.has_bias:
            _register_zeros("w13_bias", (num_experts, gate_up_rows), torch.bfloat16)
            _register_zeros("w2_bias", (num_experts, hidden_size), torch.bfloat16)
229

230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
    def _setup_kernel(
        self,
        layer: FusedMoE,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_bias: torch.Tensor | None = None,
        w2_bias: torch.Tensor | None = None,
    ) -> None:
        """Validate loaded weights, convert them, and build the MoE kernel.

        Steps:
          1. Check every tensor against the shapes allocated in
             ``create_weights``.
          2. Convert the tensors to the selected backend's kernel format.
          3. Install the converted tensors on ``layer`` (or on ``self`` for
             the Triton precision configs).
          4. Build the quant config and, when possible, the MoE kernel.

        Args:
            layer: The MoE layer that owns the parameters.
            w13: Fused gate/up weights.
            w2: Down-projection weights.
            w13_scale: Scales for ``w13``.
            w2_scale: Scales for ``w2``.
            w13_bias: Optional bias for ``w13``.
            w2_bias: Optional bias for ``w2``.

        Raises:
            AssertionError: If any tensor does not have the expected shape.
        """
        num_experts = self.num_experts
        intermediate_size = self.intermediate_size
        hidden_size = self.hidden_size
        sf_block_size = 32

        def _assert_shape(
            name: str, tensor: torch.Tensor, expected: tuple[int, ...]
        ) -> None:
            # Comparing the full shape tuple checks rank and every dimension.
            actual = tuple(tensor.shape)
            assert actual == expected, (
                f"{name} has shape {actual}, expected {expected}"
            )

        # Shape assertions
        _assert_shape(
            "w13", w13, (num_experts, intermediate_size * 2, hidden_size // 2)
        )
        _assert_shape(
            "w13_scale",
            w13_scale,
            (num_experts, intermediate_size * 2, hidden_size // sf_block_size),
        )
        _assert_shape(
            "w2", w2, (num_experts, hidden_size, intermediate_size // 2)
        )
        # The expert dimension is checked here as well, matching the other
        # tensors (it was previously left unchecked for w2_scale).
        _assert_shape(
            "w2_scale",
            w2_scale,
            (num_experts, hidden_size, intermediate_size // sf_block_size),
        )
        if w13_bias is not None:
            _assert_shape(
                "w13_bias", w13_bias, (num_experts, intermediate_size * 2)
            )
        if w2_bias is not None:
            _assert_shape("w2_bias", w2_bias, (num_experts, hidden_size))

        # Convert weights to kernel format
        w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = (
            convert_to_mxfp4_moe_kernel_format(
                mxfp4_backend=self.mxfp4_backend,
                layer=layer,
                w13_weight=w13,
                w2_weight=w2,
                w13_weight_scale=w13_scale,
                w2_weight_scale=w2_scale,
                w13_bias=w13_bias,
                w2_bias=w2_bias,
                _cache_permute_indices=self._cache_permute_indices,
            )
        )

        # For TRITON backends, weights are wrapped tensors from triton_kernels
        # that don't support .detach(). Manually assign parameters.
        if self.mxfp4_backend not in TRITON_BACKENDS:
            replace_parameter(layer, "w13_weight", w13)
            replace_parameter(layer, "w2_weight", w2)
            replace_parameter(layer, "w13_weight_scale", w13_scale)
            replace_parameter(layer, "w2_weight_scale", w2_scale)
        else:
            layer.w13_weight = w13
            layer.w2_weight = w2
            # The converted scales are kept on the method, not the layer;
            # get_fused_moe_quant_config reads them from here.
            self.w13_precision_config = w13_scale
            self.w2_precision_config = w2_scale

        # Biases are only replaced when both are present.
        if w13_bias is not None and w2_bias is not None:
            replace_parameter(layer, "w13_bias", w13_bias)
            replace_parameter(layer, "w2_bias", w2_bias)

        # Build quant config
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)

        # Build kernel (modular or monolithic)
        if self.moe_quant_config is not None and self.experts_cls is not None:
            self.moe_kernel = make_mxfp4_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                mxfp4_backend=self.mxfp4_backend,
                experts_cls=self.experts_cls,
                routing_tables=layer._maybe_init_expert_routing_tables(),
                shared_experts=layer.shared_experts,
            )

328
329
330
331
332
333
334
    def process_weights_after_loading(self, layer):
        """Set up the MoE kernel once the checkpoint weights are loaded.

        Does nothing beyond reading the layer's tensors when the selected
        backend is ``Mxfp4MoeBackend.NONE``.
        """
        # Collected before the backend check, so a layer missing a weight or
        # scale attribute fails here regardless of the backend.
        kernel_inputs = (
            layer.w13_weight,
            layer.w2_weight,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            getattr(layer, "w13_bias", None),
            getattr(layer, "w2_bias", None),
        )

        if self.mxfp4_backend != Mxfp4MoeBackend.NONE:
            self._setup_kernel(layer, *kernel_inputs)
340

341
    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config for ``layer``.

        Scales come from the layer's ``w13_weight_scale`` / ``w2_weight_scale``
        unless a Triton backend is active, in which case the precision
        configs stored on ``self`` are used. Biases are optional.
        """
        scale_w13 = layer.w13_weight_scale
        scale_w2 = layer.w2_weight_scale
        bias_w13 = getattr(layer, "w13_bias", None)
        bias_w2 = getattr(layer, "w2_bias", None)

        if self.mxfp4_backend in TRITON_BACKENDS:
            # Triton backends keep their converted scales on the method;
            # they must have been populated by _setup_kernel already.
            scale_w13 = self.w13_precision_config
            scale_w2 = self.w2_precision_config
            assert scale_w13 is not None
            assert scale_w2 is not None

        return make_mxfp4_moe_quant_config(
            mxfp4_backend=self.mxfp4_backend,
            w1_scale=scale_w13,
            w2_scale=scale_w2,
            w1_bias=bias_w13,
            w2_bias=bias_w2,
        )
362

363
364
    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEExpertsModular:
        """Unsupported entry point; always raises.

        Raises:
            ValueError: Always, because this method builds its kernel via
                the new modular kernel initialization logic instead.
        """
        message = (
            f"{self.__class__.__name__} uses the new modular kernel "
            "initialization logic. This function should not be called."
        )
        raise ValueError(message)

373
374
    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the MoE kernel with precomputed routing.

        Requires a non-monolithic method whose kernel has already been
        built; both conditions are asserted.

        Args:
            layer: MoE layer supplying weights and routing settings.
            x: Input hidden states.
            topk_weights: Routing weights of the selected experts.
            topk_ids: Indices of the selected experts.
            shared_experts_input: Optional input for the shared experts.

        Returns:
            Whatever ``self.moe_kernel.apply`` returns.
        """
        assert not self.is_monolithic
        kernel = self.moe_kernel
        assert kernel is not None

        return kernel.apply(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            expert_map=layer.expert_map,
            shared_experts_input=shared_experts_input,
        )

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the monolithic MoE kernel directly from router logits.

        Requires a monolithic method whose kernel has already been built;
        both conditions are asserted.

        Args:
            layer: MoE layer supplying weights and routing settings.
            x: Input hidden states.
            router_logits: Router output passed through to the kernel.

        Returns:
            Whatever ``self.moe_kernel.apply_monolithic`` returns.
        """
        assert self.is_monolithic
        kernel = self.moe_kernel
        assert kernel is not None

        return kernel.apply_monolithic(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            router_logits=router_logits,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )