fp8.py 47.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Any, Optional
5
6
7

import torch
from torch.nn import Module
8
from torch.utils._python_dispatch import TorchDispatchMode
9

10
import vllm.envs as envs
11
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
12
from vllm import _custom_ops as ops
13
from vllm._aiter_ops import rocm_aiter_ops
14
from vllm.attention.layer import Attention
15
from vllm.distributed import get_tensor_model_parallel_world_size
16
from vllm.logger import init_logger
17
from vllm.model_executor.layers.batch_invariant import (
18
    vllm_is_batch_invariant,
19
)
bnellnm's avatar
bnellnm committed
20
from vllm.model_executor.layers.fused_moe import (
21
22
23
24
25
26
27
    FusedMoE,
    FusedMoEActivationFormat,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported,
)
28
from vllm.model_executor.layers.fused_moe.config import (
29
    FusedMoEQuantConfig,
30
    RoutingMethodType,
31
)
32
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
33
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
34
35
36
37
38
39
40
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    Fp8MoeBackend,
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)
41
42
43
44
45
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
46
from vllm.model_executor.layers.quantization import QuantizationMethods
47
from vllm.model_executor.layers.quantization.base_config import (
48
49
50
    QuantizationConfig,
    QuantizeMethodBase,
)
51
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
52
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
53
    apply_fi_trtllm_fp8_per_tensor_moe,
54
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
55
56
    select_cutlass_fp8_gemm_impl,
)
57
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
58
59
60
61
62
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
    maybe_post_process_fp8_weight_block,
63
    process_fp8_input_tensor_strategy_moe,
64
65
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
66
    process_fp8_weight_tensor_strategy_moe,
67
68
    validate_fp8_block_shape,
)
69
70
71
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
72
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
73
74
75
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
)
76
from vllm.model_executor.layers.quantization.utils.quant_utils import (
77
78
79
    GroupShape,
    is_layer_skipped,
)
80
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
81
82
83
84
85
86
87
88
89
90
91
    Fp8LinearOp,
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
92
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
93
from vllm.platforms import current_platform
94
95
96
from vllm.utils.deep_gemm import (
    is_deep_gemm_supported,
)
97

98
99
100
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

101
102
103
104
# Module-level logger for this file.
logger = init_logger(__name__)

# Activation quantization schemes accepted by Fp8Config.__init__; any other
# value raises ValueError there.
ACTIVATION_SCHEMES = ["static", "dynamic"]

105

106
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Args:
        is_checkpoint_fp8_serialized: Whether the checkpoint already stores
            FP8 weights. When False, higher-precision weights are quantized
            online after loading.
        activation_scheme: Activation quantization scheme; must be one of
            ``ACTIVATION_SCHEMES``.
        ignored_layers: Layer prefixes that must be left unquantized.
        weight_block_size: Optional 2-element block shape that enables
            block-wise weight quantization.

    Raises:
        ValueError: If ``activation_scheme`` is unsupported, or if
            ``weight_block_size`` is given for a non-FP8 checkpoint, does not
            have exactly 2 dimensions, or is combined with a non-"dynamic"
            activation scheme.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions"
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now, but got "
                    f"{activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Translate ``ignored_layers`` from HF names to vLLM names."""
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a checkpoint quantization config dict.

        The checkpoint counts as FP8-serialized when ``quant_method`` contains
        "fp8". ``modules_to_not_convert`` is used as a fallback when
        ``ignored_layers`` is missing or empty.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        if not ignored_layers:
            ignored_layers = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None
            )
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """XPU counterpart of ``get_quant_method`` using the IPEX FP8 methods.

        Returns an unquantized method for skipped linear/MoE layers, the XPU
        FP8 method otherwise, an ``Fp8KVCacheMethod`` for attention layers,
        and ``None`` for anything else.
        """
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        # The XPU methods receive their own copy of this config.
        fp8_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(fp8_config)
        elif isinstance(layer, FusedMoE):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedFusedMoEMethod(layer.moe_config)

            return XPUFp8MoEMethod(fp8_config, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantization method for ``layer``.

        XPU is delegated to ``get_xpu_quant_method``. Skipped linear/MoE
        layers get unquantized methods. Quantized MoE layers use
        ``Fp8MoEMethod`` for FP8-serialized checkpoints and
        ``Fp8OnlineMoEMethod`` otherwise. Attention layers get
        ``Fp8KVCacheMethod``, and any other layer gets ``None``.
        """
        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)
        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
            quant_method = Fp8LinearMethod(self)
            quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return quant_method
        elif isinstance(layer, FusedMoE):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            if self.is_checkpoint_fp8_serialized:
                moe_quant_method = Fp8MoEMethod(self, layer)
            else:
                moe_quant_method = Fp8OnlineMoEMethod(self, layer)
            return moe_quant_method
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Map a checkpoint attention output-scale name to vLLM's param name.

        Handles ``.k_proj/.v_proj/.q_proj.output_scale`` (compressed-tensors
        style), mapped to ``.attn.k_scale/.attn.v_scale/.attn.q_scale``, and
        ``self_attn.prob_output_scale``, mapped to ``.attn.prob_scale``.

        :param name: param name
        :return: matching param name for the attention scale in vLLM, or
            None if ``name`` matches none of the patterns above
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

268

269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
class CopyNumelCounter(TorchDispatchMode):
    """
    Dispatch mode that sums the element counts of every ``aten.copy_``
    destination seen while it is active.

    Weight loaders may reshape or slice (e.g. ``narrow``) a parameter before
    copying into it. Counting at the ``copy_`` level therefore measures how
    much of the destination was actually written, whatever view was used.
    """

    def __init__(self):
        super().__init__()
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        call_kwargs = {} if kwargs is None else kwargs
        result = func(*args, **call_kwargs)
        is_copy = func == torch.ops.aten.copy_.default
        if is_copy:
            # The destination of copy_(dst, src) is the first positional arg.
            destination = args[0]
            self.copied_numel += destination.numel()
        return result


289
290
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading FP16/BF16 checkpoints, whose weights are quantized
    to FP8 online with dynamic activation scaling. The weight scaling factor
    is computed after the model weights are loaded.

    Weight quantization modes:
    1. Per-tensor: FP8 checkpoints may carry one scale per logical shard of a
       fused module. Unless Marlin is used, the shards are requantized to a
       single weight via ``process_fp8_weight_tensor_strategy``.
    2. Block-wise: enabled by ``weight_block_size``. Requires the dynamic
       activation scheme.

    Without block quantization, dynamic activations are quantized per token
    when CUTLASS FP8 is supported, and per tensor otherwise. GPUs below
    capability 89 (or with VLLM_TEST_FORCE_FP8_MARLIN set) use the Marlin
    weight-only kernel, except on ROCm or in batch-invariant mode.

    Limitations:
    1. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.marlin_input_dtype = None
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        if vllm_is_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        if self.weight_block_size:
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if not self.act_q_static and cutlass_fp8_supported():
                self.act_q_group_shape = GroupShape.PER_TOKEN
            else:
                self.act_q_group_shape = GroupShape.PER_TENSOR

        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create and register the parameters of ``layer``.

        For FP8-serialized checkpoints this registers ``weight`` (via
        ``create_fp8_weight_parameter``) and its scale: ``weight_scale`` for
        per-tensor, ``weight_scale_inv`` for block-wise. It also registers
        ``input_scale``, which is a parameter for the static scheme and
        ``None`` otherwise.

        For other checkpoints only ``weight`` is registered, in
        ``params_dtype``. Its loader is wrapped so that
        ``process_weights_after_loading`` runs as soon as every element has
        been copied in.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(
                output_size_per_partition, input_size_per_partition, weight_loader
            )
        else:

            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
                # track how many elements we have updated
                if not hasattr(layer, "_loaded_numel"):
                    layer._loaded_numel = 0

                # load the current weight chunk; elements are counted at the
                # copy_ level because the loader may narrow/reshape first
                copy_numel_counter = CopyNumelCounter()
                with copy_numel_counter:
                    res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
                layer._loaded_numel += copy_numel_counter.copied_numel

                # if we have loaded all of the elements, call
                # process_weights_after_loading
                target_loaded_numel = layer.weight.numel()
                if layer._loaded_numel == target_loaded_numel:
                    self.process_weights_after_loading(layer)

                    # Delete the bookkeeping
                    del layer._loaded_numel
                    # Prevent the usual `process_weights_after_loading` call from doing
                    # anything
                    layer._already_called_process_weights_after_loading = True

                return res

            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=patched_weight_loader,
            )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = create_fp8_scale_parameter(
                    PerTensorScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    None,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(
                    BlockQuantScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    self.weight_block_size,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded parameters into the layout the chosen kernel expects.

        - Block-wise: weights and ``weight_scale_inv`` go through
          ``process_fp8_weight_block_strategy``.
        - Non-serialized checkpoints: weights are quantized online with
          ``ops.scaled_fp8_quant``.
        - Serialized per-tensor checkpoints: shard scales are requantized to a
          single scale unless Marlin is used.

        Non-block weights are stored transposed. Marlin layers are then
        repacked and lose ``input_scale``.

        This is a no-op if the patched weight loader already ran it.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale_inv = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )

            # Update layer with new values
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)

        else:
            # If checkpoint not serialized fp8, quantize the weights.
            if not self.quant_config.is_checkpoint_fp8_serialized:
                qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
                weight = qweight.t()

            # If checkpoint is fp8 per-tensor, handle that there are N scales for N
            # shards in a fused module
            else:
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If using w8a8, torch._scaled_mm needs per tensor, so
                # requantize the logical shards as a single weight.
                if not self.use_marlin:
                    weight, weight_scale, input_scale = (
                        process_fp8_weight_tensor_strategy(
                            weight,
                            weight_scale,
                            layer.logical_widths,
                            getattr(layer, "input_scale", None),
                        )
                    )
                    if self.act_q_static:
                        assert input_scale is not None
                        input_scale = input_scale.max()
                weight = weight.t()

            # Update layer with new values.
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale", weight_scale.data)

        if input_scale is not None:
            replace_parameter(layer, "input_scale", input_scale)
        else:
            layer.input_scale = None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the FP8 linear layer for input ``x`` (plus optional ``bias``).

        Dispatch order:
        1. Batch-invariant mode: the block-wise W8A8 op, or BF16
           dequantization followed by ``torch.nn.functional.linear``.
        2. Marlin.
        3. Block-wise W8A8.
        4. ``Fp8LinearOp``.
        """
        # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
        # we will use BF16 dequant when DeepGEMM is not supported.
        if vllm_is_batch_invariant():
            if self.block_quant:
                assert self.weight_block_size is not None
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale_inv,
                    input_scale=layer.input_scale,
                    bias=bias,
                )
            else:
                # per-tensor/channel: dequant to BF16 and run GEMM
                weight_fp8 = layer.weight.to(torch.bfloat16)
                weight_scale = layer.weight_scale.to(torch.bfloat16)
                if weight_scale.numel() == 1:
                    # Per-tensor: simple scalar multiplication
                    weight_bf16 = weight_fp8 * weight_scale
                else:
                    # Multiple scales (fused modules like QKV)
                    # Try to infer correct broadcasting
                    # weight is [K, N], scale could be [num_logical_weights]
                    # Need to figure out how to broadcast - for now just try
                    # direct multiplication
                    # NOTE(review): when K == N, a length-N per-output-channel
                    # scale also passes this check and would be applied along
                    # rows (K) instead of columns; confirm the layout of
                    # weight_scale on this path.
                    if (
                        weight_scale.dim() == 1
                        and weight_scale.shape[0] == weight_fp8.shape[0]
                    ):
                        # Per-row scaling
                        weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
                    else:
                        # Fallback
                        weight_bf16 = weight_fp8 * weight_scale
                return torch.nn.functional.linear(x, weight_bf16.t(), bias)

        if self.use_marlin:
            if self.block_quant:
                weight_scale = layer.weight_scale_inv
            else:
                weight_scale = layer.weight_scale

            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                input_dtype=self.marlin_input_dtype,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )
613
614


615
616
617
618
619
620
621
622
623
624
625
626
627
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8-serialized checkpoints.

    Supports loading FP8 checkpoints with either per-tensor weight scales
    (one scale per expert/weight) or block-wise weight scales, combined with
    dynamic or static activation scales.

    Unserialized FP16/BF16 checkpoints (online quantization) are handled by
    the ``Fp8OnlineMoEMethod`` subclass, which quantizes weights after they
    are loaded.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer this method is attached to; its ``moe_config``,
            ``moe_parallel_config`` and ``activation`` are consulted when
            selecting and validating the FP8 MoE backend.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        # Block-quantized checkpoints store inverse scales under a different
        # parameter name; every scale lookup below goes through this name.
        self.weight_scale_name = (
            "weight_scale_inv" if self.block_quant else "weight_scale"
        )
        self.fp8_backend = select_fp8_moe_backend(
            block_quant=self.block_quant,
            tp_size=layer.moe_parallel_config.tp_size,
            with_lora_support=self.moe.is_lora_enabled,
        )

        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            if self.block_quant and self.weight_block_size != [128, 128]:
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports block "
                    "size [128, 128]."
                )
            if layer.activation != "silu":
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
                    f"activation function, but got {layer.activation}."
                )

        # Per-tensor weights with a non-static activation scheme means the
        # activations are quantized dynamically per token.
        dynamic_per_token = (
            not self.block_quant and self.quant_config.activation_scheme != "static"
        )
        if dynamic_per_token and self.fp8_backend in [
            Fp8MoeBackend.FLASHINFER_TRTLLM,
            Fp8MoeBackend.FLASHINFER_CUTLASS,
        ]:
            raise NotImplementedError(
                "FlashInfer FP8 MoE backend does not support dynamic per token "
                "activation quantization."
            )

        # Built lazily in `_setup_kernel` once weights are in runtime format.
        self.kernel: mk.FusedMoEModularKernel | None = None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register FP8 expert weights, weight scales and input scales.

        Registers on ``layer``:
          * ``w13_weight``: (num_experts, 2 * intermediate, hidden) FP8.
          * ``w2_weight``: (num_experts, hidden, intermediate) FP8.
          * ``w13_<scale_name>`` / ``w2_<scale_name>``: float32 scales, per
            expert (and per w1/w3 shard for w13) for per-tensor quant, or per
            block for block quant.
          * ``w13_input_scale`` / ``w2_input_scale``: per-expert float32
            parameters for the static activation scheme, otherwise ``None``.

        Args:
            layer: The MoE layer to attach parameters to.
            num_experts: Number of experts to allocate.
            hidden_size: Model hidden dimension.
            intermediate_size_per_partition: Expert intermediate size on this
                tensor-parallel rank.
            params_dtype: Original (unquantized) dtype; recorded as
                ``layer.orig_dtype``. Weights are always allocated in FP8.
            **extra_weight_attrs: Attributes (e.g. ``weight_loader``) set on
                every registered parameter.

        Raises:
            ValueError: If block quantization is used and the partition sizes
                are not divisible by the block dimensions.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        assert self.quant_config.is_checkpoint_fp8_serialized
        params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # For per-tensor quant, the scales are per expert and weight.
            w13_scale_data = torch.ones(num_experts, 2, dtype=torch.float32)
            w2_scale_data = torch.ones(num_experts, dtype=torch.float32)
        else:
            # For block quant, the scales are per block (typically 128x128).
            w13_scale_data = torch.ones(
                num_experts,
                2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                (hidden_size + block_k - 1) // block_k,
                dtype=torch.float32,
            )
            w2_scale_data = torch.ones(
                num_experts,
                (hidden_size + block_n - 1) // block_n,
                (intermediate_size_per_partition + block_k - 1) // block_k,
                dtype=torch.float32,
            )
        w13_weight_scale = torch.nn.Parameter(w13_scale_data, requires_grad=False)
        w2_weight_scale = torch.nn.Parameter(w2_scale_data, requires_grad=False)
        # Note: name is weight_scale for tensor, weight_scale_inv for block.
        layer.register_parameter(f"w13_{self.weight_scale_name}", w13_weight_scale)
        layer.register_parameter(f"w2_{self.weight_scale_name}", w2_weight_scale)

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)
        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def _setup_kernel(
        self,
        layer: Module,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor | None,
        w2_input_scale: torch.Tensor | None,
    ) -> None:
        """Convert weights to the backend's runtime layout and build the kernel.

        Replaces the layer's weight/scale parameters with the converted
        tensors, then (when the backend has a modular-kernel quant config)
        creates ``self.kernel`` and ``self.use_inplace``.
        """
        # Shuffle weights to runtime format.
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
        replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)

        # Setup modular kernel for TP case. Backends without a modular-kernel
        # quant config (FLASHINFER_TRTLLM) leave `self.kernel` as None.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            self.kernel, self.use_inplace = make_fp8_moe_kernel(
                layer=layer,
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
            )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Post-process loaded FP8 weights/scales into kernel-ready form.

        Converts to FNUZ on platforms that require it, collapses static input
        scales and per-shard w13 scales to the single scale per-tensor kernels
        expect, and finally calls ``_setup_kernel``. Does nothing if the layer
        is flagged as already processed.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Allow for accessing weights and scales in standard way.
        w13 = layer.w13_weight
        w2 = layer.w2_weight
        w13_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        w13_input_scale = layer.w13_input_scale
        w2_input_scale = layer.w2_input_scale

        # MI300x and MI325x use FNUZ format for FP8. Convert if needed.
        if current_platform.is_fp8_fnuz():
            w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w13,
                w13_scale,
                w13_input_scale,
            )
            w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w2,
                w2_scale,
                w2_input_scale,
            )

        # Per tensor kernels require single activation scale. Use the max.
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            assert w13_input_scale is not None and w2_input_scale is not None
            w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
                w13_input_scale, w2_input_scale
            )
            replace_parameter(layer, "w13_input_scale", w13_input_scale)
            replace_parameter(layer, "w2_input_scale", w2_input_scale)

        # Per tensor kernels require single weight scale for w13 per expert, but
        # on disk there is a scale for w1 and w3. Use the max to requantize.
        if not self.block_quant:
            shard_size = layer.intermediate_size_per_partition
            w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
                w13, w13_scale, shard_size, layer.local_num_experts
            )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
        )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for this backend, if any.

        AITER, Marlin and FlashInfer TRTLLM do not use one (``None``);
        FlashInfer CUTLASS builds its own; everything else defers to the
        base class.
        """
        if self.fp8_backend in [
            Fp8MoeBackend.AITER,
            Fp8MoeBackend.MARLIN,
            Fp8MoeBackend.FLASHINFER_TRTLLM,
        ]:
            return None
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Pick the experts (GEMM) implementation matching ``prepare_finalize``.

        Raises:
            NotImplementedError: For the Marlin and AITER backends, which are
                not supported with all2all.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedDeepGemmExperts,
            BatchedTritonExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
        )

        if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
            raise NotImplementedError(
                "Marlin and ROCm AITER are not supported with all2all yet."
            )

        assert self.moe_quant_config is not None

        if (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        ):
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None

            experts_impl = (
                BatchedDeepGemmExperts
                if self.fp8_backend == Fp8MoeBackend.DEEPGEMM
                else BatchedTritonExperts
            )
            logger.debug(
                "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                experts_impl.__name__,
                self.__class__.__name__,
                max_num_tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return experts_impl(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
            )
        elif self.moe.is_lora_enabled:
            return TritonExperts(quant_config=self.moe_quant_config)
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            # Select GEMM experts with block-scale when weights are block-quantized
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        elif self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonOrDeepGemmExperts(self.moe_quant_config)
        else:
            assert self.fp8_backend == Fp8MoeBackend.TRITON
            logger.debug(
                "TritonExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonExperts(self.moe_quant_config)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the modular-kernel quant config from the layer's scales.

        Returns ``None`` for FlashInfer TRTLLM, which does not go through the
        modular kernel.
        """
        # TRTLLM does not use Modular Kernel.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            return None

        w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        return make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=self.weight_block_size,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def allow_inplace(self) -> bool:
        return True

    def apply(
        self,
        layer: FusedMoE,
        router: FusedMoERouter,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the FP8 MoE forward pass.

        FlashInfer TRTLLM is dispatched directly (it performs its own
        routing). All other backends select experts via ``router`` and run
        the modular kernel built in ``_setup_kernel``.
        """
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            # TODO(rob): convert this to MK.
            if layer.enable_eplb:
                raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
            )

            if self.block_quant:
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                e_score_correction_bias = (
                    layer.e_score_correction_bias.to(x.dtype)
                    if layer.e_score_correction_bias is not None
                    else None
                )
                routing_method_type = layer.routing_method_type
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32)
                    if routing_method_type == RoutingMethodType.DeepSeekV3
                    else router_logits,
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routing_method_type=routing_method_type,
                    routed_scaling=layer.routed_scaling_factor,
                )

            # Per-tensor TRTLLM path. This must return here: TRTLLM never
            # builds a modular kernel (get_fused_moe_quant_config returns
            # None), so falling through would trip `self.kernel` assertion.
            return apply_fi_trtllm_fp8_per_tensor_moe(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=layer.e_score_correction_bias,
                global_num_experts=layer.global_num_experts,
                top_k=layer.top_k,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

        topk_weights, topk_ids = router.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        assert self.kernel is not None
        return self.kernel(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )


class Fp8OnlineMoEMethod(Fp8MoEMethod):
    """MoE method for online FP8 quantization.

    Loads unquantized FP16/BF16 expert weights and quantizes each expert to
    FP8 with a dynamically computed per-tensor scale once all weight elements
    have been loaded. Activations use dynamic scaling.

    Args:
        quant_config: The quantization config. Must not be FP8-serialized,
            must use the "dynamic" activation scheme and no weight blocks.
        layer: The MoE layer this method is attached to.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(quant_config, layer)
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register high-precision expert weights and per-expert FP8 scales.

        Weights are allocated in ``params_dtype``; the ``weight_loader`` from
        ``extra_weight_attrs`` is wrapped so that quantization runs as soon as
        the last weight element has been copied in.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # We are doing online quantization, patch the weight loaded
        # to call `process_weights_after_loading` in a streaming fashion
        # as soon as the last weight chunk is loaded.
        weight_loader = extra_weight_attrs["weight_loader"]
        # Create a new holder to prevent modifying behavior of any other
        # objects which might depend on the old one.
        new_extra_weight_attrs = dict(extra_weight_attrs)

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # add a counter to track how many elements we have updated
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

            # load the current weight chunk
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # if we have loaded all of the elements, call
            # process_weights_after_loading
            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Delete the bookkeeping
                del layer._loaded_numel
                # Prevent the usual `process_weights_after_loading` call
                # from doing anything
                layer._already_called_process_weights_after_loading = True

            return res

        new_extra_weight_attrs["weight_loader"] = patched_weight_loader
        extra_weight_attrs = new_extra_weight_attrs

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # One scale per expert for the fused w13 weight and one for w2. They
        # are filled in by `process_weights_after_loading` when each expert is
        # quantized.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # Dynamic activation scheme: no static input scales.
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize each expert's weights to FP8 and set up the kernel.

        Does nothing if the layer is flagged as already processed (the
        patched weight loader sets this after running it once).
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # The checkpoint is high precision: quantize each expert into new FP8
        # tensors, writing one scale per expert into the registered scales.
        fp8_dtype = current_platform.fp8_dtype()
        w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
        w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale

        for expert in range(layer.local_num_experts):
            w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
                layer.w13_weight[expert, :, :]
            )
            w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant(
                layer.w2_weight[expert, :, :]
            )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer,
            w13,
            w2,
            w13_scale,
            w2_scale,
            layer.w13_input_scale,
            layer.w2_input_scale,
        )


class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1221
1222
1223
    """

    def __init__(self, quant_config: Fp8Config):
1224
        super().__init__(quant_config)