# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TYPE_CHECKING, Any

import torch
from torch.nn import Module
from torch.utils._python_dispatch import TorchDispatchMode

import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
    RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    Fp8MoeBackend,
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
    init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    apply_fi_trtllm_fp8_per_tensor_moe,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
    maybe_post_process_fp8_weight_block,
    process_fp8_input_tensor_strategy_moe,
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
    process_fp8_weight_tensor_strategy_moe,
    validate_fp8_block_shape,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
    is_layer_skipped,
    kFp8Dynamic128Sym,
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
    kFp8Static128BlockSym,
    kFp8StaticTensorSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    normalize_e4m3fn_to_e4m3fnuz,
)
from vllm.model_executor.model_loader.weight_utils import initialize_single_dummy_weight
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
    is_deep_gemm_supported,
)

if TYPE_CHECKING:
    # Only needed for the annotation on `Fp8Config.apply_vllm_mapper`.
    from vllm.model_executor.models.utils import WeightsMapper

# Activation quantization schemes accepted by `Fp8Config`.
ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)

108

109
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Args:
        is_checkpoint_fp8_serialized: Whether the checkpoint already stores
            FP8 weights. `get_quant_method` uses it to choose between the
            offline methods (`Fp8LinearMethod`, `Fp8MoEMethod`) and the
            online ones (`Fp8OnlineLinearMethod`, `Fp8OnlineMoEMethod`).
        activation_scheme: Must be one of `ACTIVATION_SCHEMES`.
        ignored_layers: Layer names handed to `is_layer_skipped`; `None` is
            stored as an empty list.
        weight_block_size: Optional two-element block shape enabling
            block-wise weight quantization.

    Raises:
        ValueError: If `activation_scheme` is unknown, or if
            `weight_block_size` is given without an fp8-serialized
            checkpoint, without exactly two dimensions, or with a
            non-dynamic activation scheme.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions"
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now, but got "
                    f"{activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the name of this quantization method (`"fp8"`)."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this config declares as supported."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability declared by this config."""
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the extra config filenames to look for (none)."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite `ignored_layers` through `hf_to_vllm_mapper.apply_list`."""
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an `Fp8Config` from a quantization config dict.

        `quant_method` and `activation_scheme` are read with
        `get_from_keys`; `ignored_layers` and `weight_block_size` are
        optional. When `ignored_layers` is missing or empty, the
        `modules_to_not_convert` key is used instead.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        if not ignored_layers:
            ignored_layers = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None
            )
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Pick the quantization method for `layer` on XPU.

        Skipped linear / MoE layers get their unquantized method; other
        linear and MoE layers get the XPU FP8 methods built from a copy of
        this config; `Attention` layers get `Fp8KVCacheMethod`. Any other
        layer type yields `None`.
        """
        # Imported lazily so the XPU-specific module is only loaded when
        # this path is taken.
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        fp8_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(fp8_config)
        elif isinstance(layer, FusedMoE):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedFusedMoEMethod(layer.moe_config)

            return XPUFp8MoEMethod(fp8_config, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Pick the quantization method for `layer`.

        Delegates to `get_xpu_quant_method` on XPU. Otherwise skipped
        linear / MoE layers get their unquantized method, and the offline
        or online FP8 method is chosen from `is_checkpoint_fp8_serialized`.
        `Attention` layers get `Fp8KVCacheMethod`; any other layer type
        yields `None`.
        """
        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)
        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
            if not self.is_checkpoint_fp8_serialized:
                online_method = Fp8OnlineLinearMethod(self)
                online_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
                return online_method
            else:
                offline_method = Fp8LinearMethod(self)
                offline_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
                return offline_method
        elif isinstance(layer, FusedMoE):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            if self.is_checkpoint_fp8_serialized:
                moe_quant_method = Fp8MoEMethod(self, layer)
            else:
                moe_quant_method = Fp8OnlineMoEMethod(self, layer)
            return moe_quant_method
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

276

277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
class CopyNumelCounter(TorchDispatchMode):
    """
    Dispatch mode that accumulates how many elements were written through
    `copy_` while it is active. Counting at the dispatch level keeps the
    tally correct even when the destination is a transformed view of the
    underlying weight (for example the result of `narrow`).
    """

    def __init__(self):
        super().__init__()
        # Running total of destination elements seen by `aten.copy_`.
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Always run the intercepted op first so its result is preserved.
        result = func(*args, **({} if kwargs is None else kwargs))
        if func == torch.ops.aten.copy_.default:
            # args[0] is the destination tensor of the in-place copy.
            destination = args[0]
            self.copied_numel += destination.numel()
        return result


297
298
299
300
301
302
303
304
305
306
def _copy_missing_attrs(old: torch.Tensor, new: torch.Tensor) -> None:
    """Carry over to `new` every attribute that `old` has and `new` lacks.

    Attribute names are compared via `dir()`; the collected values are
    applied in one call to `set_weight_attrs`.
    """
    already_present = set(dir(new))
    missing = {
        name: getattr(old, name) for name in dir(old) if name not in already_present
    }
    set_weight_attrs(new, missing)


307
308
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Limitations:
    1. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        # `marlin_input_dtype` is overwritten by `Fp8Config.get_quant_method`
        # after construction.
        self.marlin_input_dtype = None
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        # Batch-invariant mode also turns marlin off, overriding the
        # capability / env decision above.
        if vllm_is_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"

        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if self.act_q_static:
                activation_quant_key = kFp8StaticTensorSym
            elif cutlass_fp8_supported():
                activation_quant_key = kFp8DynamicTokenSym
            else:
                activation_quant_key = kFp8DynamicTensorSym

            self.fp8_linear = init_fp8_linear_kernel(
                activation_quant_key=activation_quant_key,
                weight_quant_key=kFp8StaticTensorSym,
                out_dtype=torch.get_default_dtype(),
                module_name=self.__class__.__name__,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 weight and scale parameters on `layer`.

        Registers `weight`, then `weight_scale` (per-tensor) or
        `weight_scale_inv` (block-wise), and `input_scale` when the
        activation scheme is static. Partition sizes, the original dtype
        and the block size are also recorded on `layer`; other methods of
        this class read `logical_widths`, `input_size_per_partition` and
        `output_size_per_partition` back.
        """
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        weight = create_fp8_weight_parameter(
            output_size_per_partition, input_size_per_partition, weight_loader
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        if not self.block_quant:
            scale = create_fp8_scale_parameter(
                PerTensorScaleParameter,
                output_partition_sizes,
                input_size_per_partition,
                None,
                weight_loader,
            )
            layer.register_parameter("weight_scale", scale)
        else:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            scale = create_fp8_scale_parameter(
                BlockQuantScaleParameter,
                output_partition_sizes,
                input_size_per_partition,
                self.weight_block_size,
                weight_loader,
            )
            # The weight_scale_inv name is intentional for deepseekv3
            layer.register_parameter("weight_scale_inv", scale)

        # INPUT ACTIVATION SCALE
        if self.act_q_static:
            scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
            set_weight_attrs(scale, {"scale_type": "input_scale"})
            layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded parameters into the layout used by `apply`.

        Block-wise weights go through `process_fp8_weight_block_strategy`.
        Per-tensor weights are requantized across logical shards (unless
        marlin is used) and stored transposed. `layer.input_scale` is then
        set or cleared, and the marlin / block-wise post-processing hooks
        run.
        """
        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale_inv = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )

            # Update layer with new values
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)

        else:
            # If checkpoint is fp8 per-tensor, handle that there are N scales for N
            # shards in a fused module
            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
                    weight,
                    weight_scale,
                    layer.logical_widths,
                    getattr(layer, "input_scale", None),
                )
                if self.act_q_static:
                    assert input_scale is not None
                    # Collapse the input scale(s) to their maximum.
                    input_scale = input_scale.max()
            weight = weight.t()

            # Update layer with new values.
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale", weight_scale.data)

        if input_scale is not None:
            replace_parameter(layer, "input_scale", input_scale)
        else:
            # Make sure the attribute exists (as None) so later reads and
            # the `del` in the marlin branch below do not fail.
            layer.input_scale = None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear layer on `x`.

        The path is chosen in this order: batch-invariant mode (block-wise
        op, or a BF16 dequantized GEMM for per-tensor weights), marlin,
        block-wise op, then the kernel created by `init_fp8_linear_kernel`.
        """
        # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
        # we will use BF16 dequant when DeepGEMM is not supported.
        if vllm_is_batch_invariant():
            if self.block_quant:
                assert self.weight_block_size is not None
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale_inv,
                    input_scale=layer.input_scale,
                    bias=bias,
                )
            else:
                # per-tensor/channel: dequant to BF16 and run GEMM
                weight_fp8 = layer.weight.to(torch.bfloat16)
                weight_scale = layer.weight_scale.to(torch.bfloat16)
                if weight_scale.numel() == 1:
                    # Per-tensor: simple scalar multiplication
                    weight_bf16 = weight_fp8 * weight_scale
                else:
                    # Multiple scales (fused modules like QKV)
                    # Try to infer correct broadcasting
                    # weight is [K, N], scale could be [num_logical_weights]
                    # Need to figure out how to broadcast - for now just try
                    # direct multiplication
                    if (
                        weight_scale.dim() == 1
                        and weight_scale.shape[0] == weight_fp8.shape[0]
                    ):
                        # Per-row scaling
                        weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
                    else:
                        # Fallback
                        weight_bf16 = weight_fp8 * weight_scale
                # The stored weight was transposed after loading, so
                # transpose it back for `F.linear`.
                return torch.nn.functional.linear(x, weight_bf16.t(), bias)

        if self.use_marlin:
            if self.block_quant:
                weight_scale = layer.weight_scale_inv
            else:
                weight_scale = layer.weight_scale

            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                input_dtype=self.marlin_input_dtype,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply_weights(layer, x, bias)
562
563


564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
class Fp8OnlineLinearMethod(Fp8LinearMethod):
    """Online version of Fp8LinearMethod, loads the fp16/bf16 checkpoint
    and quantizes the weights during loading."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register a `meta`-device weight that is materialized lazily.

        The weight is created on the `meta` device in `params_dtype`. The
        wrapping loader allocates it on the real device when the first
        chunk arrives, counts the copied elements, and calls
        `process_weights_after_loading` once the count equals the weight's
        element count.
        """
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # WEIGHT
        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # track how many elements we have updated
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

                # when the first `loaded_weight` is about to be
                # loaded to `param`, materialize `param` just-in-time
                weight = ModelWeightParameter(
                    data=torch.empty_like(layer.weight, device=layer._load_device),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=patched_weight_loader,
                )
                _copy_missing_attrs(layer.weight, weight)
                layer.register_parameter("weight", weight)
                del layer._load_device

            # refresh the reference to `param` to reflect just-in-time
            # materialization
            param = layer.weight

            # load the current weight chunk
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # if we have loaded all of the elements, call
            # process_weights_after_loading
            target_loaded_numel = layer.weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Prevent the usual `process_weights_after_loading` call from doing
                # anything
                layer._already_called_process_weights_after_loading = True

                # Note that we keep `layer._loaded_numel` around just in case
                # there is logic added to vllm in the future which calls a
                # weight loader twice - we do not want to re-initialize in
                # that case.

            return res

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                # materialized just-in-time in `patched_weight_loader`
                device="meta",
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=patched_weight_loader,
        )
        # stash the correct device for `patched_weight_loader`
        layer._load_device = torch.get_default_device()
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded weight to FP8 and store it transposed.

        Does nothing when the layer was already processed by the patched
        weight loader. A weight still on the `meta` device (dummy load
        format) is materialized and randomly initialized first.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # deferred initialization of randomly initialized weights for the
        # `--load_format dummy` feature
        if layer.weight.device == torch.device("meta"):
            weight = ModelWeightParameter(
                data=torch.empty_like(layer.weight, device=layer._load_device),
                input_dim=1,
                output_dim=0,
                weight_loader=layer.weight.weight_loader,
            )
            _copy_missing_attrs(layer.weight, weight)
            layer.register_parameter("weight", weight)
            initialize_single_dummy_weight(layer.weight)

        # TODO(future): support block_quant in online quant path
        assert not self.block_quant

        layer.input_scale = None
        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
        weight = qweight.t()

        # Update layer with new values.
        replace_parameter(layer, "weight", weight.data)
        replace_parameter(layer, "weight_scale", weight_scale.data)

        if self.use_marlin:
            size_k_first = True
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin; `layer.input_scale` was
            # already set to None above.


683
684
685
686
687
688
689
690
691
692
693
694
695
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8 checkpoints.

    Supports loading FP8 checkpoints with static weight scales (per-tensor or
    block-wise) and dynamic/static activation scales.

    Online quantization of FP16/BF16 checkpoints is not handled by this
    class: `create_weights` asserts that the checkpoint is FP8-serialized.
    The `Fp8OnlineMoEMethod` subclass covers the online case.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer; its `moe_config` is forwarded to the base class.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        """Record the quantization settings and select the FP8 MoE backend."""
        super().__init__(layer.moe_config)
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        # Block-quantized checkpoints name their scales `*_weight_scale_inv`,
        # per-tensor ones `*_weight_scale`.
        self.weight_scale_name = (
            "weight_scale_inv" if self.block_quant else "weight_scale"
        )

        # Set weight key and activation key for kernel compatibility
        if self.block_quant:
            weight_key = kFp8Static128BlockSym
            activation_key = kFp8Dynamic128Sym
        else:
            weight_key = kFp8StaticTensorSym
            activation_key = (
                kFp8StaticTensorSym
                if self.quant_config.activation_scheme == "static"
                else kFp8DynamicTensorSym
            )

        # Select Fp8 MoE backend
        self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
            config=self.moe,
            weight_key=weight_key,
            activation_key=activation_key,
            allow_vllm_cutlass=False,
        )

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 expert weights and their scales on `layer`.

        Registers `w13_weight` and `w2_weight` as `torch.float8_e4m3fn`
        parameters, the matching weight-scale parameters (named with
        `self.weight_scale_name`), and, for the static activation scheme,
        per-expert `w13_input_scale` / `w2_input_scale` parameters. For the
        non-static scheme both input scales are set to `None`.

        Args:
            layer: Module receiving the parameters and bookkeeping attributes.
            num_experts: Size of the leading (expert) dimension.
            hidden_size: Hidden dimension of the model.
            intermediate_size_per_partition: Intermediate dimension held by
                this partition.
            params_dtype: Original dtype, stored as `layer.orig_dtype`; the
                weights themselves are always created as FP8.
            **extra_weight_attrs: Attributes attached to every parameter via
                `set_weight_attrs`. A `quant_method` entry is added before the
                scale parameters are tagged.

        Raises:
            ValueError: In block-quant mode, if
                `intermediate_size_per_partition` is not divisible by
                `block_n`, or (when the TP world size is > 1) by `block_k`.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        assert self.quant_config.is_checkpoint_fp8_serialized
        params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # For per-tensor quant, the scales are per expert and weight.
            w13_scale_data = torch.ones(num_experts, 2, dtype=torch.float32)
            w2_scale_data = torch.ones(num_experts, dtype=torch.float32)
        else:
            # For block quant, the scales are per block (typically 128x128).
            # `block_n` / `block_k` were bound in the block-quant branch above.
            w13_scale_data = torch.ones(
                num_experts,
                2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                (hidden_size + block_k - 1) // block_k,
                dtype=torch.float32,
            )
            w2_scale_data = torch.ones(
                num_experts,
                (hidden_size + block_n - 1) // block_n,
                (intermediate_size_per_partition + block_k - 1) // block_k,
                dtype=torch.float32,
            )
        w13_weight_scale = torch.nn.Parameter(w13_scale_data, requires_grad=False)
        w2_weight_scale = torch.nn.Parameter(w2_scale_data, requires_grad=False)
        # Note: name is weight_scale for tensor, weight_scale_inv for block.
        layer.register_parameter(f"w13_{self.weight_scale_name}", w13_weight_scale)
        layer.register_parameter(f"w2_{self.weight_scale_name}", w2_weight_scale)

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def _setup_kernel(
        self,
        layer: FusedMoE,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor | None,
        w2_input_scale: torch.Tensor | None,
    ) -> None:
        """Convert weights to the backend's layout and build the MoE kernel.

        The converted weights and scales replace the corresponding parameters
        on `layer`. When `get_fused_moe_quant_config` returns a config,
        `self.moe_mk` and `self.use_inplace` are set from
        `make_fp8_moe_kernel`; otherwise they are left untouched.

        Args:
            layer: The MoE layer whose parameters are replaced.
            w13: Processed w13 weights.
            w2: Processed w2 weights.
            w13_scale: Weight scales for `w13`.
            w2_scale: Weight scales for `w2`.
            w13_input_scale: Activation scale for w13, or `None`.
            w2_input_scale: Activation scale for w2, or `None`.
        """
        # Shuffle weights to runtime format.
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
        replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)

        # Setup modular kernel for TP case and naive DP/EP case.
        # In non-naive DP/EP case, we will create a ModularKernelMethod.
        # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
        # in both cases.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
                experts_cls=self.experts_cls,
                routing_tables=layer._maybe_init_expert_routing_tables(),
                shared_experts=layer.shared_experts,
            )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Post-process loaded FP8 weights and set up the runtime kernel.

        Does nothing when `layer` carries a truthy
        `_already_called_process_weights_after_loading` attribute. Otherwise
        the weights and scales are optionally converted to the FNUZ FP8
        format, run through the per-tensor input/weight scale helpers where
        applicable, and finally handed to `_setup_kernel`.

        Args:
            layer: The MoE layer whose weights have been loaded.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Allow for accessing weights and scales in standard way.
        w13 = layer.w13_weight
        w2 = layer.w2_weight
        w13_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        w13_input_scale = layer.w13_input_scale
        w2_input_scale = layer.w2_input_scale

        # MI300x and MI325x use FNUZ format for FP8. Convert if needed.
        if current_platform.is_fp8_fnuz():
            w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w13,
                w13_scale,
                w13_input_scale,
            )
            w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w2,
                w2_scale,
                w2_input_scale,
            )

        # Per tensor kernels require single activation scale. Use the max.
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            assert w13_input_scale is not None and w2_input_scale is not None
            w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
                w13_input_scale, w2_input_scale
            )
            replace_parameter(layer, "w13_input_scale", w13_input_scale)
            replace_parameter(layer, "w2_input_scale", w2_input_scale)

        # Per tensor kernels require single weight scale for w13 per expert, but
        # on disk there is a scale for w1 and w3. Use the max to requantize.
        if not self.block_quant:
            shard_size = layer.intermediate_size_per_partition
            w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
                w13, w13_scale, shard_size, layer.local_num_experts
            )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
        )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Unsupported entry point for this method class.

        Raises:
            ValueError: Always; the kernel is built in `_setup_kernel` instead.
        """
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Unsupported entry point for this method class.

        Raises:
            ValueError: Always; the kernel is built in `_setup_kernel` instead.
        """
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the quant config from the scales currently stored on `layer`.

        Returns:
            `None` for the FlashInfer TRTLLM backend; otherwise the result of
            `make_fp8_moe_quant_config` for the layer's weight and input
            scales and `self.weight_block_size`.
        """
        # TRTLLM does not use Modular Kernel.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            return None

        w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        return make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=self.weight_block_size,
        )

    @property
    def supports_eplb(self) -> bool:
        """Always `True` for this method."""
        return True

    @property
    def allow_inplace(self) -> bool:
        """Always `True` for this method."""
        return True

    @property
    def is_monolithic(self) -> bool:
        """Whether the selected backend is FlashInfer TRTLLM."""
        return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the MoE forward pass through the FlashInfer TRTLLM path.

        Dispatches to the block-scale op when `self.block_quant` is set and
        to `apply_fi_trtllm_fp8_per_tensor_moe` otherwise.

        Args:
            layer: The MoE layer providing weights and routing settings.
            x: Input hidden states.
            router_logits: Router logits for `x`.

        Raises:
            NotImplementedError: If `layer.enable_eplb` is set.
        """
        assert self.is_monolithic
        assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM

        # TODO(rob): convert this to MK.
        if layer.enable_eplb:
            raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
        assert layer.activation == "silu", (
            f"Expected 'silu' activation but got {layer.activation}"
        )

        if self.block_quant:
            # Imported for its side effects only (hence the F401 waiver);
            # presumably this registers the custom op used below -- confirm.
            import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

            e_score_correction_bias = (
                layer.e_score_correction_bias.to(x.dtype)
                if layer.e_score_correction_bias is not None
                else None
            )
            routing_method_type = layer.routing_method_type
            return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                # Logits are cast to float32 only for DeepSeekV3 routing.
                routing_logits=router_logits.to(torch.float32)
                if routing_method_type == RoutingMethodType.DeepSeekV3
                else router_logits,
                routing_bias=e_score_correction_bias,
                x=x,
                w13_weight=layer.w13_weight,
                w13_weight_scale_inv=layer.w13_weight_scale_inv,
                w2_weight=layer.w2_weight,
                w2_weight_scale_inv=layer.w2_weight_scale_inv,
                global_num_experts=layer.global_num_experts,
                top_k=layer.top_k,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                intermediate_size=layer.intermediate_size_per_partition,
                expert_offset=layer.ep_rank * layer.local_num_experts,
                local_num_experts=layer.local_num_experts,
                block_shape=self.weight_block_size,
                routing_method_type=routing_method_type,
                routed_scaling=layer.routed_scaling_factor,
            )
        else:
            return apply_fi_trtllm_fp8_per_tensor_moe(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=layer.e_score_correction_bias,
                global_num_experts=layer.global_num_experts,
                top_k=layer.top_k,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the MoE forward pass through the modular kernel `self.moe_mk`.

        Args:
            layer: The MoE layer providing weights and routing settings.
            x: Input hidden states.
            topk_weights: Routing weights of the selected experts.
            topk_ids: Indices of the selected experts.
        """
        assert self.moe_mk is not None
        assert not self.is_monolithic
        return self.moe_mk(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )
1073

1074

1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
class Fp8OnlineMoEMethod(Fp8MoEMethod):
    """MoE method for online FP8 quantization.

    Supports loading FP16/BF16 model checkpoints and quantizing them to FP8
    with dynamic activation scaling. The weight scaling factors are computed
    after the model weights are loaded.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer, forwarded to `Fp8MoEMethod.__init__`.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        """Initialize the parent and check the config fits the online path."""
        super().__init__(quant_config, layer)
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register meta-device expert weights and a streaming weight loader.

        `w13_weight` and `w2_weight` are created on the meta device in
        `params_dtype`. The `weight_loader` from `extra_weight_attrs` is
        wrapped so that the real tensors are allocated on the first load
        call and `process_weights_after_loading` runs as soon as every
        element has been copied.

        Args:
            layer: Module receiving the parameters and bookkeeping attributes.
            num_experts: Size of the leading (expert) dimension.
            hidden_size: Hidden dimension of the model.
            intermediate_size_per_partition: Intermediate dimension held by
                this partition.
            params_dtype: Dtype of the unquantized checkpoint weights.
            **extra_weight_attrs: Attributes attached to every parameter; must
                contain a `weight_loader` entry.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # We are doing online quantization, patch the weight loader
        # to call `process_weights_after_loading` in a streaming fashion
        # as soon as the last weight chunk is loaded.
        weight_loader = extra_weight_attrs["weight_loader"]
        # create a new holder to prevent modifying behavior of any other
        # objects which might depend on the old one (a real shallow copy, so
        # the dict passed in is left untouched)
        new_extra_weight_attrs = dict(extra_weight_attrs)

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # add a counter to track how many elements we have updated
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

                # save the ids of original w13 and w2 so that we can
                # distinguish which one `param` should map to further
                # down in this file
                layer._w13_weight_orig_id = id(layer.w13_weight)
                layer._w2_weight_orig_id = id(layer.w2_weight)

                # when the first `loaded_weight` is about to be
                # loaded to `param`, materialize `param` just-in-time

                w13_weight = torch.nn.Parameter(
                    torch.empty_like(layer.w13_weight, device=layer._load_device),
                    requires_grad=False,
                )
                # `extra_weight_attrs` is rebound to the patched copy below,
                # before this loader can ever be invoked.
                set_weight_attrs(w13_weight, extra_weight_attrs)
                _copy_missing_attrs(layer.w13_weight, w13_weight)
                layer.register_parameter("w13_weight", w13_weight)

                w2_weight = torch.nn.Parameter(
                    torch.empty_like(layer.w2_weight, device=layer._load_device),
                    requires_grad=False,
                )
                set_weight_attrs(w2_weight, extra_weight_attrs)
                _copy_missing_attrs(layer.w2_weight, w2_weight)
                layer.register_parameter("w2_weight", w2_weight)
                del layer._load_device

            # refresh the reference to `param` to reflect just-in-time
            # materialization
            if id(param) == layer._w13_weight_orig_id:
                param = layer.w13_weight
            elif id(param) == layer._w2_weight_orig_id:
                param = layer.w2_weight

            # load the current weight chunk
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # if we have loaded all of the elements, call
            # process_weights_after_loading
            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Prevent the usual `process_weights_after_loading` call
                # from doing anything
                layer._already_called_process_weights_after_loading = True

                # Note that we keep `layer._loaded_numel`,
                # `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id`
                # around because if EP is on, weight loaders for non-local
                # experts will run but not actually copy any elements, and we
                # need to not re-initialize in that case.

            return res

        new_extra_weight_attrs["weight_loader"] = patched_weight_loader
        extra_weight_attrs = new_extra_weight_attrs

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                # materialized just-in-time in `patched_weight_loader`
                device="meta",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                # materialized just-in-time in `patched_weight_loader`
                device="meta",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)
        # stash the correct device for `patched_weight_loader`
        layer._load_device = torch.get_default_device()

        # WEIGHT_SCALES
        # Allocate one scale per expert for w13 and one per expert for w2.
        # They are filled in by `process_weights_after_loading` when the
        # weights are quantized.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # Activation scales are not stored: `__init__` asserts the dynamic
        # activation scheme.
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded expert weights to FP8 and set up the kernel.

        Does nothing when `layer` carries a truthy
        `_already_called_process_weights_after_loading` attribute. Weights
        still on the meta device (the `--load_format dummy` path) are first
        allocated on `layer._load_device` and dummy-initialized. Each local
        expert is then quantized with `ops.scaled_fp8_quant`, and the result
        is passed to `_setup_kernel`.

        Args:
            layer: The MoE layer whose weights have been loaded.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # deferred initialization of randomly initialized weights for the
        # `--load_format dummy` feature
        if layer.w13_weight.device == torch.device("meta"):
            w13_weight = torch.nn.Parameter(
                torch.empty_like(layer.w13_weight, device=layer._load_device),
                requires_grad=False,
            )
            set_weight_attrs(
                w13_weight, {"weight_loader": layer.w13_weight.weight_loader}
            )
            _copy_missing_attrs(layer.w13_weight, w13_weight)
            layer.register_parameter("w13_weight", w13_weight)
            initialize_single_dummy_weight(layer.w13_weight)
        if layer.w2_weight.device == torch.device("meta"):
            w2_weight = torch.nn.Parameter(
                torch.empty_like(layer.w2_weight, device=layer._load_device),
                requires_grad=False,
            )
            set_weight_attrs(
                w2_weight, {"weight_loader": layer.w2_weight.weight_loader}
            )
            _copy_missing_attrs(layer.w2_weight, w2_weight)
            layer.register_parameter("w2_weight", w2_weight)
            initialize_single_dummy_weight(layer.w2_weight)

        # Quantize the high-precision weights into freshly allocated FP8
        # tensors; the per-expert scales are written into the existing
        # scale parameters.
        fp8_dtype = current_platform.fp8_dtype()
        w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
        w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale

        for expert in range(layer.local_num_experts):
            w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
                layer.w13_weight[expert, :, :]
            )
            w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant(
                layer.w2_weight[expert, :, :]
            )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer,
            w13,
            w2,
            w13_scale,
            w2_scale,
            layer.w13_input_scale,
            layer.w2_input_scale,
        )
1281
1282


1283
1284
1285
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1286
1287
1288
    """

    def __init__(self, quant_config: Fp8Config):
1289
        super().__init__(quant_config)