fp8.py 47.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Any
5
6
7

import torch
from torch.nn import Module
8
from torch.utils._python_dispatch import TorchDispatchMode
9

10
import vllm.envs as envs
11
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
12
from vllm import _custom_ops as ops
13
from vllm._aiter_ops import rocm_aiter_ops
14
from vllm.distributed import get_tensor_model_parallel_world_size
15
from vllm.logger import init_logger
16
17
18
from vllm.model_executor.kernels.linear import (
    init_fp8_linear_kernel,
)
19
from vllm.model_executor.kernels.linear.scaled_mm import MarlinFP8ScaledMMLinearKernel
20
from vllm.model_executor.layers.attention import Attention
bnellnm's avatar
bnellnm committed
21
from vllm.model_executor.layers.fused_moe import (
22
23
24
25
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
26
from vllm.model_executor.layers.fused_moe.config import (
27
28
29
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
30
31
32
33
34
35
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)
36
37
38
39
40
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
41
from vllm.model_executor.layers.quantization import QuantizationMethods
42
from vllm.model_executor.layers.quantization.base_config import (
43
44
45
    QuantizationConfig,
    QuantizeMethodBase,
)
46
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
47
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
48
49
50
51
52
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
    maybe_post_process_fp8_weight_block,
53
    process_fp8_input_tensor_strategy_moe,
54
55
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
56
    process_fp8_weight_tensor_strategy_moe,
57
58
    validate_fp8_block_shape,
)
59
60
61
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
62
from vllm.model_executor.layers.quantization.utils.quant_utils import (
63
64
    GroupShape,
    is_layer_skipped,
65
    kFp8Dynamic128Sym,
66
67
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
68
    kFp8Static128BlockSym,
69
    kFp8StaticTensorSym,
70
)
71
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
72
73
74
75
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    normalize_e4m3fn_to_e4m3fnuz,
)
76
from vllm.model_executor.model_loader.weight_utils import initialize_single_dummy_weight
77
78
79
80
81
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
82
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
83
from vllm.platforms import current_platform
84
85
86
from vllm.utils.deep_gemm import (
    is_deep_gemm_supported,
)
87

88
89
90
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

91
92
93
94
# Activation quantization schemes accepted by `Fp8Config`.
ACTIVATION_SCHEMES: list[str] = ["static", "dynamic"]

logger = init_logger(__name__)

95

96
class Fp8Config(QuantizationConfig):
    """Quantization config for FP8.

    Args:
        is_checkpoint_fp8_serialized: Whether the checkpoint already stores
            FP8 weights. When False, the "online" methods are used, which
            quantize the weights while loading.
        activation_scheme: One of ``ACTIVATION_SCHEMES``.
        ignored_layers: Layer prefixes that must stay unquantized.
        weight_block_size: Optional 2-element block shape for block-wise
            weight quantization. It requires an FP8-serialized checkpoint and
            the "dynamic" activation scheme.

    Raises:
        ValueError: If the activation scheme is unknown or the block-wise
            settings are inconsistent.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []

        if weight_block_size is not None:
            # Block-wise quantization has extra constraints; they are checked
            # in this order so the first violated one is reported.
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now."
                )
            block_ndim = len(weight_block_size)
            if block_ndim != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {block_ndim} dimensions"
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now, but got "
                    f"{activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size
        # None lets the linear methods fall back to `is_deep_gemm_supported()`.
        self.use_deep_gemm: bool | None = None

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Registry name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Activation dtypes this config can run with."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum device capability required by this config."""
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """No extra config files are needed."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Translate ``ignored_layers`` from HF names to vLLM names."""
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a checkpoint quantization-config dict.

        ``modules_to_not_convert`` is used as a fallback when
        ``ignored_layers`` is missing or empty.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        serialized = "fp8" in quant_method
        scheme = cls.get_from_keys(config, ["activation_scheme"])
        skipped = cls.get_from_keys_or(config, ["ignored_layers"], None)
        block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        if not skipped:
            skipped = cls.get_from_keys_or(config, ["modules_to_not_convert"], None)
        return cls(
            is_checkpoint_fp8_serialized=serialized,
            activation_scheme=scheme,
            ignored_layers=skipped,
            weight_block_size=block_size,
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Return the quantize method for ``layer``, or None if not handled.

        Linear and MoE layers listed in ``ignored_layers`` get their
        unquantized methods. Otherwise, FP8-serialized checkpoints use the
        offline methods and other checkpoints use the online ones.
        """
        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
            linear_cls = (
                Fp8LinearMethod
                if self.is_checkpoint_fp8_serialized
                else Fp8OnlineLinearMethod
            )
            linear_method = linear_cls(self)
            linear_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return linear_method

        if isinstance(layer, FusedMoE):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            moe_cls = (
                Fp8MoEMethod
                if self.is_checkpoint_fp8_serialized
                else Fp8OnlineMoEMethod
            )
            return moe_cls(self, layer)

        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Map a compressed-tensors k/v/q/prob attention scale name to the
        param name vLLM expects.

        :param name: param name
        :return: matching param name for the attention scale in vLLM, or
            None if ``name`` is not such a scale
        """
        if name.endswith(".output_scale"):
            for marker, old_suffix, new_suffix in (
                (".k_proj", ".k_proj.output_scale", ".attn.k_scale"),
                (".v_proj", ".v_proj.output_scale", ".attn.v_scale"),
                (".q_proj", ".q_proj.output_scale", ".attn.q_scale"),
            ):
                if marker in name:
                    return name.replace(old_suffix, new_suffix)
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # Not an attention scale we know how to remap.
        return None

226

227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
class CopyNumelCounter(TorchDispatchMode):
    """Dispatch mode that counts elements written through ``aten.copy_``.

    While the mode is active, each ``copy_`` adds the element count of its
    destination tensor to ``copied_numel``. This makes weight-loading
    progress trackable even when the destination is a transformed view of
    the real parameter (for example one produced by ``narrow``).
    """

    def __init__(self):
        super().__init__()
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        result = func(*args, **(kwargs if kwargs is not None else {}))
        if func == torch.ops.aten.copy_.default:
            destination = args[0]
            self.copied_numel += destination.numel()
        return result


247
248
249
250
251
252
253
254
255
256
def _copy_missing_attrs(old: torch.Tensor, new: torch.Tensor) -> None:
    """Give ``new`` every attribute visible on ``old`` that ``new`` lacks.

    Attribute names come from ``dir()``, so anything already reachable on
    ``new`` is left untouched. The copy is done via ``set_weight_attrs``.
    """
    already_present = set(dir(new))
    missing = {
        attr_name: getattr(old, attr_name)
        for attr_name in dir(old)
        if attr_name not in already_present
    }
    set_weight_attrs(new, missing)


257
258
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Limitations:
    1. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.marlin_input_dtype = None

        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()

        # An explicit config value wins; otherwise auto-detect DeepGEMM.
        if self.quant_config.use_deep_gemm is not None:
            self.use_deep_gemm = self.quant_config.use_deep_gemm
        else:
            self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"

        # Use per-token quantization for better perf if dynamic and cutlass
        if self.act_q_static:
            activation_quant_key = kFp8StaticTensorSym
        elif cutlass_fp8_supported():
            activation_quant_key = kFp8DynamicTokenSym
        else:
            activation_quant_key = kFp8DynamicTensorSym

        if self.block_quant:
            weight_quant_key = kFp8Static128BlockSym
        else:
            weight_quant_key = kFp8StaticTensorSym

        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=activation_quant_key,
            weight_quant_key=weight_quant_key,
            out_dtype=self.out_dtype,
            module_name=self.__class__.__name__,
        )
        self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel)

        # The dedicated block-FP8 op is only built for the non-Marlin path.
        if self.block_quant and not self.use_marlin:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
                use_deep_gemm=self.use_deep_gemm,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 weight, its scale and (if static) the input scale.

        Block-quantized layers register ``weight_scale_inv``; per-tensor
        layers register ``weight_scale``. ``input_scale`` is only registered
        for the static activation scheme.
        """
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        weight = create_fp8_weight_parameter(
            output_size_per_partition, input_size_per_partition, weight_loader
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        if not self.block_quant:
            scale = create_fp8_scale_parameter(
                PerTensorScaleParameter,
                output_partition_sizes,
                input_size_per_partition,
                None,
                weight_loader,
            )
            layer.register_parameter("weight_scale", scale)
        else:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            scale = create_fp8_scale_parameter(
                BlockQuantScaleParameter,
                output_partition_sizes,
                input_size_per_partition,
                self.weight_block_size,
                weight_loader,
            )
            # The weight_scale_inv name is intentional for deepseekv3
            layer.register_parameter("weight_scale_inv", scale)

        # INPUT ACTIVATION SCALE
        if self.act_q_static:
            scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
            set_weight_attrs(scale, {"scale_type": "input_scale"})
            layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded checkpoint tensors into the layout the kernels use."""
        if self.use_marlin:
            # Only Marlin kernels support `marlin_input_dtype`; guard to avoid
            # AttributeError if backend selection changes.
            if hasattr(self.fp8_linear, "marlin_input_dtype"):
                self.fp8_linear.marlin_input_dtype = self.marlin_input_dtype
            self.fp8_linear.process_weights_after_loading(layer)
            return

        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static

            weight, weight_scale_inv = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )

            # Update layer with new values
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
        else:
            # Per-tensor FP8 checkpoint: a fused module carries one scale per
            # logical shard (N scales for N shards).
            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
                weight,
                weight_scale,
                layer.logical_widths,
                getattr(layer, "input_scale", None),
            )
            if self.act_q_static:
                assert input_scale is not None
                input_scale = input_scale.max()
            # Store the weight transposed; `apply` treats it as [K, N].
            weight = weight.t()

            # Update layer with new values.
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale", weight_scale.data)

        if input_scale is not None:
            replace_parameter(layer, "input_scale", input_scale)
        else:
            layer.input_scale = None

        if self.block_quant and self.use_deep_gemm:
            maybe_post_process_fp8_weight_block(layer)

    @staticmethod
    def _dequantize_weight(layer: Module, dtype: torch.dtype) -> torch.Tensor:
        """Dequantize a non-block FP8 weight to ``dtype`` for a plain GEMM.

        This expects the layout produced by the non-Marlin, non-block branch of
        ``process_weights_after_loading``: ``layer.weight`` is transposed to
        ``[K, N]`` (input features first). A non-scalar ``layer.weight_scale``
        therefore scales output columns (``N``), never rows.
        """
        weight = layer.weight.to(dtype)
        scale = layer.weight_scale.to(dtype)
        if scale.numel() == 1:
            # Per-tensor: simple scalar multiplication.
            return weight * scale

        num_out = weight.shape[1]
        flat_scale = scale.reshape(-1)
        if flat_scale.numel() == num_out:
            # One scale per output channel, broadcast down the rows.
            return weight * flat_scale.view(1, num_out)

        logical_widths = getattr(layer, "logical_widths", None)
        if (
            logical_widths
            and flat_scale.numel() == len(logical_widths)
            and sum(logical_widths) == num_out
        ):
            # One scale per fused logical shard (e.g. QKV): expand each scale
            # over the output columns of its shard.
            widths = torch.tensor(logical_widths, device=flat_scale.device)
            per_column = torch.repeat_interleave(flat_scale, widths)
            return weight * per_column.view(1, num_out)

        # Unrecognized layout: rely on default broadcasting.
        return weight * scale

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear layer on activations ``x``."""
        # In batch-invariant mode, block-quantized weights use the block FP8 op
        # (DeepGEMM when available), and per-tensor/per-channel weights are
        # dequantized for a plain GEMM. Marlin is excluded: its weights went
        # through the Marlin kernel's own post-processing (not the transpose
        # above), and `w8a8_block_fp8_linear` is never built for it.
        if envs.VLLM_BATCH_INVARIANT and not self.use_marlin:
            if self.block_quant:
                assert self.weight_block_size is not None
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale_inv,
                    input_scale=layer.input_scale,
                    bias=bias,
                )
            # Dequantize in the activation dtype so F.linear sees matching
            # dtypes for both bf16 and fp16 models.
            weight_dq = self._dequantize_weight(layer, x.dtype)
            return torch.nn.functional.linear(x, weight_dq.t(), bias)

        if self.use_marlin:
            return self.fp8_linear.apply_weights(layer, x, bias)

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply_weights(layer, x, bias)
496
497


498
499
500
501
class Fp8OnlineLinearMethod(Fp8LinearMethod):
    """Online variant of `Fp8LinearMethod`.

    Loads an fp16/bf16 checkpoint and quantizes the weights to FP8 while
    loading. The weight parameter is first allocated on the ``meta`` device
    and materialized just-in-time when its first chunk arrives.
    """

    uses_meta_device: bool = True

    @staticmethod
    def _materialize_weight(layer: torch.nn.Module, weight_loader) -> None:
        """Replace ``layer.weight`` with an empty tensor on ``layer._load_device``.

        The new parameter has the same shape and dtype as the current one and
        inherits any extra attributes attached to it.
        """
        materialized = ModelWeightParameter(
            data=torch.empty_like(layer.weight, device=layer._load_device),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        _copy_missing_attrs(layer.weight, materialized)
        layer.register_parameter("weight", materialized)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register a meta-device weight whose loader quantizes once it is full."""
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # The first chunk starts the element counter and swaps the meta
            # placeholder for real storage on the stashed device.
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0
                self._materialize_weight(layer, patched_weight_loader)
                del layer._load_device

            # `param` may still be the meta placeholder; use the live one.
            param = layer.weight

            numel_counter = CopyNumelCounter()
            with numel_counter:
                loader_result = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += numel_counter.copied_numel

            # Once every element has been written, quantize right away and
            # make the framework's later call a no-op. `_loaded_numel` is kept
            # so a second loader call does not re-materialize the weight.
            if layer._loaded_numel == layer.weight.numel():
                self.process_weights_after_loading(layer)
                layer._already_called_process_weights_after_loading = True

            return loader_result

        placeholder = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                # materialized just-in-time in `patched_weight_loader`
                device="meta",
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=patched_weight_loader,
        )
        # stash the correct device for `patched_weight_loader`
        layer._load_device = torch.get_default_device()
        layer.register_parameter("weight", placeholder)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded high-precision weight to FP8 (runs at most once)."""
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # A weight still on `meta` was never loaded (`--load_format dummy`):
        # materialize it and fill it with dummy values first.
        if layer.weight.device == torch.device("meta"):
            self._materialize_weight(layer, layer.weight.weight_loader)
            initialize_single_dummy_weight(layer.weight)

        # TODO(future): support block_quant in online quant path
        assert not self.block_quant

        layer.input_scale = None
        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)

        replace_parameter(layer, "weight", qweight.data)
        replace_parameter(layer, "weight_scale", weight_scale.data)

        if not self.use_marlin:
            # Non-Marlin kernels consume the transposed weight.
            replace_parameter(layer, "weight", qweight.t().data)
        else:
            # Only Marlin kernels support `marlin_input_dtype`; guard to avoid
            # AttributeError if backend selection changes.
            if hasattr(self.fp8_linear, "marlin_input_dtype"):
                self.fp8_linear.marlin_input_dtype = self.marlin_input_dtype
            self.fp8_linear.process_weights_after_loading(layer)

        # Prevent duplicate processing (e.g., during weight reload)
        layer._already_called_process_weights_after_loading = True

623

624
625
626
627
628
629
630
631
632
633
634
635
636
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    This class requires an FP8-serialized checkpoint (see `create_weights`).
    FP16/BF16 checkpoints that are quantized online are handled by the
    `Fp8OnlineMoEMethod` subclass.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer; its `moe_config` is passed to the base class.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        # Suffix of the registered scale parameters (`w13_<name>`,
        # `w2_<name>`): `weight_scale_inv` for block quant, else
        # `weight_scale`.
        self.weight_scale_name = (
            "weight_scale_inv" if self.block_quant else "weight_scale"
        )

        # Set weight key and activation key for kernel compatibility
        if self.block_quant:
            weight_key = kFp8Static128BlockSym
            activation_key = kFp8Dynamic128Sym
        else:
            weight_key = kFp8StaticTensorSym
            activation_key = (
                kFp8StaticTensorSym
                if self.quant_config.activation_scheme == "static"
                else kFp8DynamicTensorSym
            )

        # Select Fp8 MoE backend
        self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
            config=self.moe,
            weight_key=weight_key,
            activation_key=activation_key,
            allow_vllm_cutlass=False,
        )

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register FP8 expert weights, scales, and optional biases on `layer`.

        Args:
            layer: Module the parameters are registered on.
            num_experts: Leading dimension of every per-expert parameter.
            hidden_size: Model hidden dimension.
            intermediate_size_per_partition: Expert intermediate dimension of
                this partition; w13 stacks gate and up, so its output
                dimension is twice this value.
            params_dtype: Original parameter dtype, kept as
                `layer.orig_dtype` (used for biases). Weights themselves are
                always allocated as `torch.float8_e4m3fn`.
            **extra_weight_attrs: Attributes (e.g. `weight_loader`) attached
                to the registered parameters.

        Raises:
            ValueError: With block quantization, if
                `intermediate_size_per_partition` is not divisible by
                `block_n`, or (with tensor parallelism) by `block_k`.
        """
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        assert self.quant_config.is_checkpoint_fp8_serialized
        params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # BIASES (for models like GPT-OSS that have biased MoE)
        if self.moe.has_bias:
            w13_bias = torch.nn.Parameter(
                torch.zeros(
                    num_experts,
                    2 * intermediate_size_per_partition,
                    dtype=layer.orig_dtype,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_bias", w13_bias)
            set_weight_attrs(w13_bias, extra_weight_attrs)
            w2_bias = torch.nn.Parameter(
                torch.zeros(num_experts, hidden_size, dtype=layer.orig_dtype),
                requires_grad=False,
            )
            layer.register_parameter("w2_bias", w2_bias)
            set_weight_attrs(w2_bias, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # For per-tensor quant, the scales are per expert and weight.
            w13_scale_data = torch.ones(num_experts, 2, dtype=torch.float32)
            w2_scale_data = torch.ones(num_experts, dtype=torch.float32)
        else:
            # For block quant, the scales are per block (typically 128x128).
            w13_scale_data = torch.ones(
                num_experts,
                2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                (hidden_size + block_k - 1) // block_k,
                dtype=torch.float32,
            )
            w2_scale_data = torch.ones(
                num_experts,
                (hidden_size + block_n - 1) // block_n,
                (intermediate_size_per_partition + block_k - 1) // block_k,
                dtype=torch.float32,
            )
        w13_weight_scale = torch.nn.Parameter(w13_scale_data, requires_grad=False)
        w2_weight_scale = torch.nn.Parameter(w2_scale_data, requires_grad=False)
        # Note: name is weight_scale for tensor, weight_scale_inv for block.
        layer.register_parameter(f"w13_{self.weight_scale_name}", w13_weight_scale)
        layer.register_parameter(f"w2_{self.weight_scale_name}", w2_weight_scale)

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            # Dynamic activation scaling: no input scales are stored.
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def _setup_kernel(
        self,
        layer: FusedMoE,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor | None,
        w2_input_scale: torch.Tensor | None,
    ) -> None:
        """Convert weights to the backend's runtime format and build the kernel.

        The converted weights and scales are written back onto `layer`.
        `self.moe_kernel` is built only when `get_fused_moe_quant_config`
        returns a truthy config.
        """
        # Shuffle weights to runtime format.
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
        replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)

        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.moe_kernel = make_fp8_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
                experts_cls=self.experts_cls,
                routing_tables=layer._maybe_init_expert_routing_tables(),
                shared_experts=layer.shared_experts,
            )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize loaded FP8 expert weights for the selected backend.

        The steps are:
        - Convert to FNUZ on platforms that require it.
        - Reduce static input scales to the single scale per-tensor kernels
          need.
        - For per-tensor quant, requantize w13 so it has one scale per
          expert.
        - Set up the kernel.

        Runs at most once per layer.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Allow for accessing weights and scales in standard way.
        w13 = layer.w13_weight
        w2 = layer.w2_weight
        w13_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        w13_input_scale = layer.w13_input_scale
        w2_input_scale = layer.w2_input_scale

        # MI300x and MI325x use FNUZ format for FP8. Convert if needed.
        if current_platform.is_fp8_fnuz():
            w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w13,
                w13_scale,
                w13_input_scale,
            )
            w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w2,
                w2_scale,
                w2_input_scale,
            )

        # Per tensor kernels require single activation scale. Use the max.
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            assert w13_input_scale is not None and w2_input_scale is not None
            w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
                w13_input_scale, w2_input_scale
            )
            replace_parameter(layer, "w13_input_scale", w13_input_scale)
            replace_parameter(layer, "w2_input_scale", w2_input_scale)

        # Per tensor kernels require single weight scale for w13 per expert, but
        # on disk there is a scale for w1 and w3. Use the max to requantize.
        if not self.block_quant:
            shard_size = layer.intermediate_size_per_partition
            w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
                w13, w13_scale, shard_size, layer.local_num_experts
            )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
        )

        # Prevent duplicate processing (e.g., during weight reload)
        layer._already_called_process_weights_after_loading = True

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
        """Not supported: the kernel is built in `_setup_kernel` instead.

        Raises:
            ValueError: Always.
        """
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
        """Build the fused-MoE quant config from the layer's current scales.

        When the model has MoE biases and a config was produced, the layer's
        `w13_bias` / `w2_bias` are attached to it.
        """
        w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        quant_config = make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=self.weight_block_size,
        )

        # Inject biases into the quant config if the model has them
        # (e.g. GPT-OSS biased MoE)
        if quant_config is not None and self.moe.has_bias:
            w13_bias = getattr(layer, "w13_bias", None)
            w2_bias = getattr(layer, "w2_bias", None)
            if w13_bias is not None:
                quant_config._w1.bias = w13_bias
            if w2_bias is not None:
                quant_config._w2.bias = w2_bias

        return quant_config

    @property
    def supports_eplb(self) -> bool:
        """Report that this MoE method supports EPLB."""
        return True

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the monolithic kernel on raw router logits.

        The layer's routing settings are forwarded along with the logits.
        """
        assert self.is_monolithic
        assert self.moe_kernel is not None
        return self.moe_kernel.apply_monolithic(
            x,
            layer.w13_weight,
            layer.w2_weight,
            router_logits,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            e_score_correction_bias=layer.e_score_correction_bias,
            routed_scaling_factor=layer.routed_scaling_factor,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the modular kernel with precomputed top-k weights and ids."""
        assert not self.is_monolithic
        assert self.moe_kernel is not None
        return self.moe_kernel.apply(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            shared_experts_input=shared_experts_input,
        )
983

984

985
986
987
988
989
990
991
992
993
994
class Fp8OnlineMoEMethod(Fp8MoEMethod):
    """MoE method for online FP8 quantization.
    Supports loading unquantized FP16/BF16 model checkpoints and quantizing
    them to FP8 with dynamic activation scaling. The weight scaling factor
    will be initialized after the model weights are loaded.

    Expert weights are created on the meta device and materialized
    just-in-time by a patched weight loader (see `create_weights`).

    Args:
        quant_config: The quantization config.
        layer: The MoE layer; forwarded to `Fp8MoEMethod`.
    """

    uses_meta_device: bool = True

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(quant_config, layer)
        # This path is only valid for the following config:
        # - the checkpoint is not FP8-serialized,
        # - activation scaling is dynamic,
        # - block quantization is not used.
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register high-precision expert weights for online FP8 quantization.

        `w13_weight` / `w2_weight` are allocated on the meta device. The
        first call to the patched weight loader materializes both on the
        default device captured here. Once every element of both weights has
        been copied in, the loader calls `process_weights_after_loading`
        right away, so the layer is quantized as soon as its last chunk
        arrives.

        Args:
            layer: Module the parameters are registered on.
            num_experts: Leading dimension of every per-expert parameter.
            hidden_size: Model hidden dimension.
            intermediate_size_per_partition: Expert intermediate dimension of
                this partition (w13 output dimension is twice this value).
            params_dtype: dtype of the unquantized weights and biases.
            **extra_weight_attrs: Must contain `weight_loader`; attached to
                the registered parameters, with the loader replaced by the
                patched version for weights and scales.
        """
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # We are doing online quantization, patch the weight loader
        # to call `process_weights_after_loading` in a streaming fashion
        # as soon as the last weight chunk is loaded.
        weight_loader = extra_weight_attrs["weight_loader"]
        # create a new holder to prevent modifying behavior of any other
        # objects which might depend on the old one
        new_extra_weight_attrs = dict(extra_weight_attrs)

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # add a counter to track how many elements we have updated
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

                # save the ids of original w13 and w2 so that we can
                # distinguish which one `param` should map to further
                # down in this file
                layer._w13_weight_orig_id = id(layer.w13_weight)
                layer._w2_weight_orig_id = id(layer.w2_weight)

                # when the first `loaded_weight` is about to be
                # loaded to `param`, materialize `param` just-in-time

                w13_weight = torch.nn.Parameter(
                    torch.empty_like(layer.w13_weight, device=layer._load_device),
                    requires_grad=False,
                )
                set_weight_attrs(w13_weight, extra_weight_attrs)
                _copy_missing_attrs(layer.w13_weight, w13_weight)
                layer.register_parameter("w13_weight", w13_weight)

                w2_weight = torch.nn.Parameter(
                    torch.empty_like(layer.w2_weight, device=layer._load_device),
                    requires_grad=False,
                )
                set_weight_attrs(w2_weight, extra_weight_attrs)
                _copy_missing_attrs(layer.w2_weight, w2_weight)
                layer.register_parameter("w2_weight", w2_weight)
                del layer._load_device

            # refresh the reference to `param` to reflect just-in-time
            # materialization
            if id(param) == layer._w13_weight_orig_id:
                param = layer.w13_weight
            elif id(param) == layer._w2_weight_orig_id:
                param = layer.w2_weight

            # load the current weight chunk
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # if we have loaded all of the elements, call
            # process_weights_after_loading
            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Prevent the usual `process_weights_after_loading` call
                # from doing anything
                layer._already_called_process_weights_after_loading = True

                # Note that we keep `layer._loaded_numel`,
                # `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id`
                # around because if EP is on, weight loaders for non-local
                # experts will run but not actually copy any elements, and we
                # need to not re-initialize in that case.

            return res

        new_extra_weight_attrs["weight_loader"] = patched_weight_loader
        # Rebinding also updates what `patched_weight_loader` sees through
        # its closure when it attaches attrs to the materialized weights.
        extra_weight_attrs = new_extra_weight_attrs

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                # materialized just-in-time in `patched_weight_loader`
                device="meta",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                # materialized just-in-time in `patched_weight_loader`
                device="meta",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # stash the correct device for `patched_weight_loader`
        layer._load_device = torch.get_default_device()

        # BIASES (for models like GPT-OSS that have biased MoE)
        if self.moe.has_bias:
            # Use the original weight_loader (not patched) for biases
            orig_extra_weight_attrs = dict(extra_weight_attrs)
            orig_extra_weight_attrs["weight_loader"] = weight_loader
            w13_bias = torch.nn.Parameter(
                torch.zeros(
                    num_experts,
                    2 * intermediate_size_per_partition,
                    dtype=layer.orig_dtype,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_bias", w13_bias)
            set_weight_attrs(w13_bias, orig_extra_weight_attrs)
            w2_bias = torch.nn.Parameter(
                torch.zeros(num_experts, hidden_size, dtype=layer.orig_dtype),
                requires_grad=False,
            )
            layer.register_parameter("w2_bias", w2_bias)
            set_weight_attrs(w2_bias, orig_extra_weight_attrs)

        # WEIGHT_SCALES
        # One scale per expert for w13 (gate and up together) and one for w2.
        # They are filled in by `process_weights_after_loading`, which
        # quantizes each expert's w13 and w2 with a single scale each.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # Dynamic activation scaling: no input scales are stored.
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded high-precision expert weights to FP8.

        The steps are:
        - Materialize and randomly initialize any weights still on the meta
          device (`--load_format dummy`).
        - Quantize each local expert's w13 and w2 with its own scale.
        - Set up the kernel.

        Runs at most once per layer.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # deferred initialization of randomly initialized weights for the
        # `--load_format dummy` feature
        if layer.w13_weight.device == torch.device("meta"):
            w13_weight = torch.nn.Parameter(
                torch.empty_like(layer.w13_weight, device=layer._load_device),
                requires_grad=False,
            )
            set_weight_attrs(
                w13_weight, {"weight_loader": layer.w13_weight.weight_loader}
            )
            _copy_missing_attrs(layer.w13_weight, w13_weight)
            layer.register_parameter("w13_weight", w13_weight)
            initialize_single_dummy_weight(layer.w13_weight)
        if layer.w2_weight.device == torch.device("meta"):
            w2_weight = torch.nn.Parameter(
                torch.empty_like(layer.w2_weight, device=layer._load_device),
                requires_grad=False,
            )
            set_weight_attrs(
                w2_weight, {"weight_loader": layer.w2_weight.weight_loader}
            )
            _copy_missing_attrs(layer.w2_weight, w2_weight)
            layer.register_parameter("w2_weight", w2_weight)
            initialize_single_dummy_weight(layer.w2_weight)

        # If checkpoint is fp16, quantize in place.
        fp8_dtype = current_platform.fp8_dtype()
        w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
        w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale

        # One scale per expert: each expert's w13 and w2 are quantized
        # independently.
        for expert in range(layer.local_num_experts):
            w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
                layer.w13_weight[expert, :, :]
            )
            w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant(
                layer.w2_weight[expert, :, :]
            )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer,
            w13,
            w2,
            w13_scale,
            w2_scale,
            layer.w13_input_scale,
            layer.w2_input_scale,
        )

        # Prevent duplicate processing (e.g., during weight reload)
        layer._already_called_process_weights_after_loading = True

1217

1218
1219
1220
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1221
1222
1223
    """

    def __init__(self, quant_config: Fp8Config):
1224
        super().__init__(quant_config)