"vllm/vscode:/vscode.git/clone" did not exist on "69bff9bc893475fbd64c47633cb8ece46cd462c7"
mxfp4.py 14.5 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

6
from vllm.config import get_current_vllm_config
7
from vllm.logger import init_logger
8
from vllm.model_executor.layers.attention import Attention
9
10
11
12
13
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    FusedMoEConfig,
    FusedMoEMethodBase,
)
14
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
15
from vllm.model_executor.layers.fused_moe.config import (
16
    FusedMoEQuantConfig,
17
)
18
19
20
21
22
23
24
25
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
    TRITON_BACKENDS,
    Mxfp4MoeBackend,
    convert_to_mxfp4_moe_kernel_format,
    make_mxfp4_moe_kernel,
    make_mxfp4_moe_quant_config,
    mxfp4_round_up_hidden_size_and_intermediate_size,
    select_mxfp4_moe_backend,
26
27
)
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
28
29
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
30
31
32
33
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
34
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
35

36
37
38
# Module-level logger; Mxfp4Config.get_quant_method uses it for one-time
# debug messages about layer types that MXFP4 does not quantize.
logger = init_logger(__name__)


39
class Mxfp4Config(QuantizationConfig):
    """Quantization config for MXFP4 checkpoints.

    Only fused-MoE layers are quantized (via ``Mxfp4MoEMethod``). Linear
    layers always run unquantized, and attention layers get no quant method.
    """

    def __init__(self, ignored_layers: list[str] | None = None):
        """Create the config.

        Args:
            ignored_layers: Layer-name patterns handed to ``is_layer_skipped``
                when deciding how to treat linear layers. ``None`` or empty
                means no layer is considered skipped.
        """
        super().__init__()
        self.ignored_layers = ignored_layers

    @classmethod
    def from_config(cls, config):
        """Build a config from a checkpoint's quantization config.

        ``config`` is not inspected, so ``ignored_layers`` is always ``None``
        on the returned instance.
        """
        # NOTE(review): ignored_layers is never populated from ``config``.
        # Linear layers are unquantized either way, so today this only changes
        # whether the "not implemented" debug message is logged. Confirm this
        # is intentional.
        return cls()

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (80).

        Presumably compute capability 8.0 encoded as ``major * 10 + minor``;
        the encoding is defined by ``QuantizationConfig``.
        """
        return 80

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "mxfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes supported (bfloat16 only)."""
        return [torch.bfloat16]

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return extra config files to load (none for MXFP4)."""
        return []

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Return the quantization method to use for ``layer``.

        - ``LinearBase``: always ``UnquantizedLinearMethod`` because MXFP4
          linear is not implemented. A one-time debug message is logged
          unless the layer matches ``ignored_layers``.
        - ``FusedMoE``: an ``Mxfp4MoEMethod`` built from the layer's MoE config.
        - ``Attention``: ``None``, after logging a one-time debug message.
        - Anything else: ``None``.

        Args:
            layer: The module being configured.
            prefix: Fully qualified module name, matched against
                ``ignored_layers``.
        """
        if isinstance(layer, LinearBase):
            if self.ignored_layers and is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                # Explicitly ignored: unquantized, with no log noise.
                return UnquantizedLinearMethod()
            logger.debug_once(
                "MXFP4 linear layer is not implemented - falling back to "
                "UnquantizedLinearMethod.",
                scope="local",
            )
            return UnquantizedLinearMethod()

        if isinstance(layer, FusedMoE):
            return Mxfp4MoEMethod(layer.moe_config)

        if isinstance(layer, Attention):
            logger.debug_once(
                "MXFP4 attention layer is not implemented. "
                "Skipping quantization for this layer.",
                scope="local",
            )

        return None

    def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
        """MXFP4 config always uses MXFP4 quantization."""
        return True

94
95

class Mxfp4MoEMethod(FusedMoEMethodBase):
    """MXFP4 MoE quantization method.

    ``create_weights`` allocates ``uint8`` expert weights whose last dimension
    is halved (packed storage, presumably two 4-bit values per byte). It also
    allocates ``uint8`` block scales, one per ``_MXFP4_BLOCK_SIZE`` elements
    of the reduction dimension. ``process_weights_after_loading`` converts the
    loaded tensors to the layout expected by the selected backend and builds
    the fused-MoE kernel used by ``apply`` / ``apply_monolithic``.
    """

    # Number of elements that share one scale factor. It is used both when
    # allocating the scale tensors and when validating their shapes, so the
    # two cannot drift apart.
    _MXFP4_BLOCK_SIZE = 32

    def __init__(self, moe: FusedMoEConfig):
        """Select the backend and pad the MoE dimensions it requires.

        Args:
            moe: Shared MoE config. Its ``hidden_dim`` and
                ``intermediate_size_per_partition`` are rounded up in place.
        """
        super().__init__(moe)
        self.weight_dtype = "mxfp4"
        self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe)

        self.max_capture_size = (
            get_current_vllm_config().compilation_config.max_cudagraph_capture_size
        )

        # Passed to convert_to_mxfp4_moe_kernel_format, keyed by tensor size.
        # Presumably this lets permutation indices be reused across layers
        # that share a shape.
        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
        self.moe_kernel: mk.FusedMoEKernel | None = None

        # Round up dims once based on backend. This mutates the shared
        # FusedMoEConfig in-place so that create_weights() and all
        # downstream code see the padded dimensions. This must happen
        # before create_weights() is called.
        self.moe.hidden_dim, self.moe.intermediate_size_per_partition = (
            mxfp4_round_up_hidden_size_and_intermediate_size(
                self.mxfp4_backend,
                self.moe.hidden_dim,
                self.moe.intermediate_size_per_partition,
            )
        )

        # Set in _setup_kernel for TRITON backends. Those backends keep the
        # converted scales here instead of on the layer.
        self.w13_precision_config = None
        self.w2_precision_config = None

    @property
    def skip_forward_padding(self) -> bool:
        """Whether the forward pass may skip padding before the MoE method.

        Only the FlashInfer TRT-LLM MXFP4 x MXFP8 backend qualifies, because
        it supports padding as part of its MXFP8 activation quantization.
        """
        return self.mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8

    @staticmethod
    def _register_zeros_param(
        layer: torch.nn.Module,
        name: str,
        shape: tuple[int, ...],
        dtype: torch.dtype,
        extra_weight_attrs: dict,
    ) -> None:
        """Register a zero-filled, non-trainable parameter ``name`` on ``layer``."""
        param = torch.nn.Parameter(
            torch.zeros(shape, dtype=dtype),
            requires_grad=False,
        )
        layer.register_parameter(name, param)
        set_weight_attrs(param, extra_weight_attrs)

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate MXFP4 expert weights, scales and (optionally) biases.

        The ``hidden_size`` and ``intermediate_size_per_partition`` arguments
        are superseded by the backend-padded sizes stored on ``self.moe`` in
        ``__init__``. The padded sizes are what gets allocated and exposed on
        ``layer``.

        Args:
            layer: Module to register the parameters on.
            num_experts: Number of experts held by this layer.
            hidden_size: Unpadded hidden size (ignored; see above).
            intermediate_size_per_partition: Unpadded size (ignored).
            params_dtype: Recorded on ``layer.params_dtype``.
            **extra_weight_attrs: Attributes attached to every parameter via
                ``set_weight_attrs``.
        """
        self.num_experts = num_experts
        weight_dtype = torch.uint8
        scale_dtype = torch.uint8
        block = self._MXFP4_BLOCK_SIZE

        # Use pre-rounded sizes from config.
        self.intermediate_size = padded_intermediate = (
            self.moe.intermediate_size_per_partition
        )
        self.hidden_size = hidden_size = self.moe.hidden_dim

        # Expose padded dimensions on the layer for LoRA and Marlin code
        # that reads layer.hidden_size / layer.intermediate_size_per_partition.
        layer.params_dtype = params_dtype
        layer.num_experts = num_experts
        layer.hidden_size = hidden_size
        layer.intermediate_size_per_partition = padded_intermediate

        register = self._register_zeros_param

        # Fused gate_up_proj (column parallel): packed weights + block scales.
        register(
            layer,
            "w13_weight",
            (num_experts, 2 * padded_intermediate, hidden_size // 2),
            weight_dtype,
            extra_weight_attrs,
        )
        register(
            layer,
            "w13_weight_scale",
            (num_experts, 2 * padded_intermediate, hidden_size // block),
            scale_dtype,
            extra_weight_attrs,
        )

        # down_proj (row parallel): packed weights + block scales.
        register(
            layer,
            "w2_weight",
            (num_experts, hidden_size, padded_intermediate // 2),
            weight_dtype,
            extra_weight_attrs,
        )
        register(
            layer,
            "w2_weight_scale",
            (num_experts, hidden_size, padded_intermediate // block),
            scale_dtype,
            extra_weight_attrs,
        )

        if self.moe.has_bias:
            register(
                layer,
                "w13_bias",
                (num_experts, 2 * padded_intermediate),
                torch.bfloat16,
                extra_weight_attrs,
            )
            register(
                layer,
                "w2_bias",
                (num_experts, hidden_size),
                torch.bfloat16,
                extra_weight_attrs,
            )

    @staticmethod
    def _assert_shape(
        name: str, tensor: torch.Tensor, expected: tuple[int, ...]
    ) -> None:
        """Assert that ``tensor`` has exactly the ``expected`` shape."""
        actual = tuple(tensor.shape)
        assert actual == expected, (
            f"{name}: expected shape {expected}, got {actual}"
        )

    def _setup_kernel(
        self,
        layer: FusedMoE,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_bias: torch.Tensor | None = None,
        w2_bias: torch.Tensor | None = None,
    ) -> None:
        """Validate, convert and install weights, then build the MoE kernel.

        The shapes must match those allocated by ``create_weights``. The
        tensors are converted to the backend's kernel format, written back to
        ``layer`` (or to the precision configs for TRITON backends), and used
        to build ``self.moe_quant_config`` and ``self.moe_kernel``.
        """
        num_experts = self.num_experts
        intermediate_size = self.intermediate_size
        hidden_size = self.hidden_size
        block = self._MXFP4_BLOCK_SIZE

        # Shape assertions. They mirror the allocations in create_weights,
        # including the expert dimension for every tensor.
        self._assert_shape(
            "w13", w13, (num_experts, 2 * intermediate_size, hidden_size // 2)
        )
        self._assert_shape(
            "w13_scale",
            w13_scale,
            (num_experts, 2 * intermediate_size, hidden_size // block),
        )
        self._assert_shape(
            "w2", w2, (num_experts, hidden_size, intermediate_size // 2)
        )
        self._assert_shape(
            "w2_scale",
            w2_scale,
            (num_experts, hidden_size, intermediate_size // block),
        )
        if w13_bias is not None:
            self._assert_shape(
                "w13_bias", w13_bias, (num_experts, 2 * intermediate_size)
            )
        if w2_bias is not None:
            self._assert_shape("w2_bias", w2_bias, (num_experts, hidden_size))

        # Convert weights to kernel format.
        w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = (
            convert_to_mxfp4_moe_kernel_format(
                mxfp4_backend=self.mxfp4_backend,
                layer=layer,
                w13_weight=w13,
                w2_weight=w2,
                w13_weight_scale=w13_scale,
                w2_weight_scale=w2_scale,
                w13_bias=w13_bias,
                w2_bias=w2_bias,
                _cache_permute_indices=self._cache_permute_indices,
            )
        )

        # For TRITON backends, weights are wrapped tensors from triton_kernels
        # that don't support .detach(). Manually assign parameters, and keep
        # the converted scales as precision configs instead of on the layer.
        if self.mxfp4_backend not in TRITON_BACKENDS:
            replace_parameter(layer, "w13_weight", w13)
            replace_parameter(layer, "w2_weight", w2)
            replace_parameter(layer, "w13_weight_scale", w13_scale)
            replace_parameter(layer, "w2_weight_scale", w2_scale)
        else:
            layer.w13_weight = w13
            layer.w2_weight = w2
            self.w13_precision_config = w13_scale
            self.w2_precision_config = w2_scale

        if w13_bias is not None and w2_bias is not None:
            replace_parameter(layer, "w13_bias", w13_bias)
            replace_parameter(layer, "w2_bias", w2_bias)

        # Build quant config.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)

        # Build kernel (modular or monolithic).
        if self.moe_quant_config is not None and self.experts_cls is not None:
            self.moe_kernel = make_mxfp4_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                mxfp4_backend=self.mxfp4_backend,
                experts_cls=self.experts_cls,
                routing_tables=layer._maybe_init_expert_routing_tables(),
                shared_experts=layer.shared_experts,
            )

    def process_weights_after_loading(self, layer):
        """Convert loaded weights and build the kernel.

        This is a no-op when no MXFP4 backend was selected.
        """
        # Check the backend first so that a NONE backend never touches layer
        # attributes.
        if self.mxfp4_backend == Mxfp4MoeBackend.NONE:
            return

        self._setup_kernel(
            layer,
            layer.w13_weight,
            layer.w2_weight,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            getattr(layer, "w13_bias", None),
            getattr(layer, "w2_bias", None),
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config from the layer's scales and biases.

        TRITON backends use the precision configs stored by ``_setup_kernel``
        in place of the layer's scale parameters.
        """
        w1_bias = getattr(layer, "w13_bias", None)
        w2_bias = getattr(layer, "w2_bias", None)

        if self.mxfp4_backend in TRITON_BACKENDS:
            assert self.w13_precision_config is not None
            assert self.w2_precision_config is not None
            w1_scale = self.w13_precision_config
            w2_scale = self.w2_precision_config
        else:
            w1_scale = layer.w13_weight_scale
            w2_scale = layer.w2_weight_scale

        return make_mxfp4_moe_quant_config(
            mxfp4_backend=self.mxfp4_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            w1_bias=w1_bias,
            w2_bias=w2_bias,
        )

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEExpertsModular:
        """Not supported: the kernel is built in ``_setup_kernel`` instead.

        Raises:
            ValueError: Always.
        """
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel "
            "initialization logic. This function should not be called."
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the modular (pre-routed) MoE kernel on ``x``.

        The routing is given by ``topk_weights`` / ``topk_ids``. This requires
        ``process_weights_after_loading`` to have built ``self.moe_kernel``.
        """
        assert not self.is_monolithic
        assert self.moe_kernel is not None
        return self.moe_kernel.apply(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            expert_map=layer.expert_map,
            shared_experts_input=shared_experts_input,
        )

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the monolithic MoE kernel, which routes from ``router_logits``.

        This requires ``process_weights_after_loading`` to have built
        ``self.moe_kernel``.
        """
        assert self.is_monolithic
        assert self.moe_kernel is not None
        return self.moe_kernel.apply_monolithic(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            router_logits=router_logits,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )