"vllm/vscode:/vscode.git/clone" did not exist on "0a56bcc03de0857be464c3f8783258d590cbc762"
fp8.py 47.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Any, Optional
5
6
7

import torch
from torch.nn import Module
8
from torch.utils._python_dispatch import TorchDispatchMode
9

10
import vllm.envs as envs
11
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
12
from vllm import _custom_ops as ops
13
from vllm._aiter_ops import rocm_aiter_ops
14
from vllm.attention.layer import Attention
15
from vllm.distributed import get_tensor_model_parallel_world_size
16
from vllm.logger import init_logger
17
from vllm.model_executor.layers.batch_invariant import (
18
    vllm_is_batch_invariant,
19
)
bnellnm's avatar
bnellnm committed
20
from vllm.model_executor.layers.fused_moe import (
21
22
23
24
25
    FusedMoE,
    FusedMoEActivationFormat,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
26
    FusedMoERouter,
27
28
    FusedMoeWeightScaleSupported,
)
29
from vllm.model_executor.layers.fused_moe.config import (
30
    FusedMoEQuantConfig,
31
    RoutingMethodType,
32
33
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
34
35
36
37
38
39
40
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    Fp8MoeBackend,
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)
41
42
43
44
45
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
46
from vllm.model_executor.layers.quantization import QuantizationMethods
47
from vllm.model_executor.layers.quantization.base_config import (
48
49
50
    QuantizationConfig,
    QuantizeMethodBase,
)
51
52
53
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
    init_fp8_linear_kernel,
)
54
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
55
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
56
    apply_fi_trtllm_fp8_per_tensor_moe,
57
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
58
59
    select_cutlass_fp8_gemm_impl,
)
60
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
61
62
63
64
65
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
    maybe_post_process_fp8_weight_block,
66
    process_fp8_input_tensor_strategy_moe,
67
68
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
69
    process_fp8_weight_tensor_strategy_moe,
70
71
    validate_fp8_block_shape,
)
72
73
74
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
75
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
76
77
78
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
)
79
from vllm.model_executor.layers.quantization.utils.quant_utils import (
80
81
    GroupShape,
    is_layer_skipped,
82
83
84
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
    kFp8StaticTensorSym,
85
)
86
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
87
88
89
90
91
92
93
94
95
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    normalize_e4m3fn_to_e4m3fnuz,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
96
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
97
from vllm.platforms import current_platform
98
99
100
from vllm.utils.deep_gemm import (
    is_deep_gemm_supported,
)
101

102
103
104
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

105
106
107
108
logger = init_logger(__name__)

# Activation quantization schemes accepted by Fp8Config; any other value is
# rejected with a ValueError in its constructor.
ACTIVATION_SCHEMES = ["static", "dynamic"]

109

110
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Args:
        is_checkpoint_fp8_serialized: Whether the checkpoint already stores
            FP8 weights. When False, weights are quantized after loading.
        activation_scheme: One of ``ACTIVATION_SCHEMES`` ("static" or
            "dynamic").
        ignored_layers: Layer prefixes that stay unquantized; ``None`` is
            stored as an empty list.
        weight_block_size: Optional two-element block shape that enables
            block-wise weight quantization.

    Raises:
        ValueError: On an unsupported activation scheme, or on block-wise
            settings that are not supported (non-serialized checkpoint, a
            block shape that is not 2-D, or a non-dynamic activation scheme).
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []

        if weight_block_size is not None:
            self._check_block_quant_settings(
                weight_block_size, is_checkpoint_fp8_serialized, activation_scheme
            )
        self.weight_block_size = weight_block_size

    @staticmethod
    def _check_block_quant_settings(
        block_size: list[int], serialized: bool, scheme: str
    ) -> None:
        """Raise ValueError when block-wise quantization cannot be used."""
        if not serialized:
            raise ValueError(
                "The block-wise quantization only supports fp8-serialized "
                "checkpoint for now."
            )
        num_dims = len(block_size)
        if num_dims != 2:
            raise ValueError(
                "The quantization block size of weight must have 2 "
                f"dimensions, but got {num_dims} dimensions"
            )
        if scheme != "dynamic":
            raise ValueError(
                "The block-wise quantization only supports "
                "dynamic activation scheme for now, but got "
                f"{scheme} activation scheme."
            )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Translate ``ignored_layers`` through the HF -> vLLM name mapper."""
        if self.ignored_layers is None:
            return
        self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an Fp8Config from a checkpoint's quantization config dict."""
        quant_method = cls.get_from_keys(config, ["quant_method"])
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored = cls.get_from_keys_or(config, ["ignored_layers"], None)
        block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        # Fall back to the `modules_to_not_convert` key when `ignored_layers`
        # is missing or empty.
        if not ignored:
            ignored = cls.get_from_keys_or(config, ["modules_to_not_convert"], None)
        return cls(
            is_checkpoint_fp8_serialized="fp8" in quant_method,
            activation_scheme=activation_scheme,
            ignored_layers=ignored,
            weight_block_size=block_size,
        )

    def _is_prefix_skipped(self, prefix: str) -> bool:
        """Return whether the layer at `prefix` is listed as unquantized."""
        return is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the XPU (IPEX) quantization method for `layer`, or None."""
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        # The XPU methods receive a fresh config carrying the same settings.
        xpu_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if self._is_prefix_skipped(prefix):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(xpu_config)
        if isinstance(layer, FusedMoE):
            if self._is_prefix_skipped(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return XPUFp8MoEMethod(xpu_config, layer)
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantization method for `layer`, or None if unsupported."""
        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)

        if isinstance(layer, LinearBase):
            if self._is_prefix_skipped(prefix):
                return UnquantizedLinearMethod()
            linear_method = Fp8LinearMethod(self)
            linear_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return linear_method

        if isinstance(layer, FusedMoE):
            if self._is_prefix_skipped(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            # Serialized FP8 checkpoints load FP8 weights directly; otherwise
            # the online method quantizes higher-precision weights.
            moe_method_cls = (
                Fp8MoEMethod
                if self.is_checkpoint_fp8_serialized
                else Fp8OnlineMoEMethod
            )
            return moe_method_cls(self, layer)

        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM, or None
        """
        if name.endswith(".output_scale"):
            # (marker, checkpoint suffix, vLLM suffix), checked in order.
            renames = (
                (".k_proj", ".k_proj.output_scale", ".attn.k_scale"),
                (".v_proj", ".v_proj.output_scale", ".attn.v_scale"),
                (".q_proj", ".q_proj.output_scale", ".attn.q_scale"),
            )
            for marker, old_suffix, new_suffix in renames:
                if marker in name:
                    return name.replace(old_suffix, new_suffix)
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

272

273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
class CopyNumelCounter(TorchDispatchMode):
    """
    Dispatch mode that counts the elements written by ``aten.copy_``.

    While the mode is active, every ``torch.ops.aten.copy_.default`` call adds
    the element count of its destination tensor (the first positional
    argument) to ``copied_numel``. Because the destination may be a view
    (e.g. produced by ``narrow``), this tracks weight loading even when the
    loader transforms the target before copying.
    """

    def __init__(self):
        super().__init__()
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        result = func(*args, **(kwargs or {}))
        is_copy = func == torch.ops.aten.copy_.default
        if is_copy:
            destination = args[0]
            self.copied_numel += destination.numel()
        return result


293
294
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale. Weights use either a per-tensor scale
    (``weight_scale``) or, when the config sets ``weight_block_size``,
    block-wise scales (``weight_scale_inv``).

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Non-block weights are requantized to a single per-tensor scale due to
       torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.marlin_input_dtype = None
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        if vllm_is_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"

        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if self.act_q_static:
                activation_quant_key = kFp8StaticTensorSym
            elif cutlass_fp8_supported():
                activation_quant_key = kFp8DynamicTokenSym
            else:
                activation_quant_key = kFp8DynamicTensorSym

            self.fp8_linear = init_fp8_linear_kernel(
                activation_quant_key=activation_quant_key,
                weight_quant_key=kFp8StaticTensorSym,
                out_dtype=torch.get_default_dtype(),
                module_name=self.__class__.__name__,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight (and, for FP8 checkpoints, scale) parameters.

        For FP8-serialized checkpoints this registers ``weight`` plus either
        ``weight_scale`` (per-tensor) or ``weight_scale_inv`` (block-wise),
        and ``input_scale`` when activations use static scales. Otherwise
        ``weight`` is allocated in ``params_dtype``. Its loader calls
        ``process_weights_after_loading`` once every element has been copied.
        """
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(
                output_size_per_partition, input_size_per_partition, weight_loader
            )
        else:

            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
                # track how many elements we have updated
                if not hasattr(layer, "_loaded_numel"):
                    layer._loaded_numel = 0

                # load the current weight chunk
                copy_numel_counter = CopyNumelCounter()
                with copy_numel_counter:
                    res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
                layer._loaded_numel += copy_numel_counter.copied_numel

                # if we have loaded all of the elements, call
                # process_weights_after_loading
                target_loaded_numel = layer.weight.numel()
                if layer._loaded_numel == target_loaded_numel:
                    self.process_weights_after_loading(layer)

                    # Delete the bookkeeping
                    del layer._loaded_numel
                    # Prevent the usual `process_weights_after_loading` call from doing
                    # anything
                    layer._already_called_process_weights_after_loading = True

                return res

            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=patched_weight_loader,
            )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = create_fp8_scale_parameter(
                    PerTensorScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    None,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(
                    BlockQuantScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    self.weight_block_size,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize/repack the loaded weights into the layout `apply` expects.

        Non-block weights end up transposed (``weight.t()``), i.e. with the
        output dimension last. This is a no-op when the patched online loader
        has already run it.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale_inv = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )

            # Update layer with new values
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)

        # If checkpoint not serialized fp8, quantize the weights.
        else:
            if not self.quant_config.is_checkpoint_fp8_serialized:
                qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
                weight = qweight.t()

            # If checkpoint is fp8 per-tensor, handle that there are N scales for N
            # shards in a fused module
            else:
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If using w8a8, torch._scaled_mm needs per tensor, so
                # requantize the logical shards as a single weight.
                if not self.use_marlin:
                    weight, weight_scale, input_scale = (
                        process_fp8_weight_tensor_strategy(
                            weight,
                            weight_scale,
                            layer.logical_widths,
                            getattr(layer, "input_scale", None),
                        )
                    )
                    if self.act_q_static:
                        assert input_scale is not None
                        input_scale = input_scale.max()
                weight = weight.t()

            # Update layer with new values.
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale", weight_scale.data)

        if input_scale is not None:
            replace_parameter(layer, "input_scale", input_scale)
        else:
            layer.input_scale = None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer)

    @staticmethod
    def _dequant_weight_to_bf16(layer: torch.nn.Module) -> torch.Tensor:
        """Dequantize a non-block FP8 weight to BF16 (batch-invariant path).

        After `process_weights_after_loading` the weight is stored transposed
        as ``[K, N]`` (output dimension last). A 1-D per-output-channel scale
        therefore has to broadcast along the last dimension, not along rows.
        """
        weight = layer.weight.to(torch.bfloat16)
        scale = layer.weight_scale.to(torch.bfloat16)

        if scale.numel() == 1:
            # Per-tensor: simple scalar multiplication
            return weight * scale

        if scale.dim() == 1 and weight.dim() == 2:
            size_k, size_n = weight.shape
            num_scales = scale.shape[0]
            if num_scales == size_n:
                # One scale per output channel (column of the [K, N] weight).
                return weight * scale
            logical_widths = getattr(layer, "logical_widths", None)
            if (
                logical_widths
                and num_scales == len(logical_widths)
                and sum(logical_widths) == size_n
            ):
                # One scale per logical shard of a fused module (e.g. QKV):
                # repeat each shard's scale across that shard's columns.
                widths = torch.tensor(logical_widths, device=scale.device)
                return weight * torch.repeat_interleave(scale, widths)
            if num_scales == size_k:
                # Per-row scaling (kept for backward compatibility).
                return weight * scale.unsqueeze(1)

        # Fallback: rely on standard broadcasting rules.
        return weight * scale

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ W (+ bias)`` with the FP8 weights stored on `layer`."""
        # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
        # we will use BF16 dequant when DeepGEMM is not supported.
        if vllm_is_batch_invariant():
            if self.block_quant:
                assert self.weight_block_size is not None
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale_inv,
                    input_scale=layer.input_scale,
                    bias=bias,
                )
            # per-tensor/channel: dequant to BF16 and run GEMM
            weight_bf16 = self._dequant_weight_to_bf16(layer)
            # F.linear requires matching dtypes: run the GEMM in BF16, then
            # cast back so FP16/FP32 activations don't fail (no-op for BF16).
            out = torch.nn.functional.linear(
                x.to(torch.bfloat16),
                weight_bf16.t(),
                None if bias is None else bias.to(torch.bfloat16),
            )
            return out.to(x.dtype)

        if self.use_marlin:
            if self.block_quant:
                weight_scale = layer.weight_scale_inv
            else:
                weight_scale = layer.weight_scale

            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                input_dtype=self.marlin_input_dtype,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply_weights(layer, x, bias)
608
609


610
611
612
613
614
615
616
617
618
619
620
621
622
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8-serialized checkpoints.

    Supports FP8 checkpoints with either per-tensor weight scales
    (``w13_weight_scale`` / ``w2_weight_scale``) or block-wise weight scales
    (``w13_weight_scale_inv`` / ``w2_weight_scale_inv``), combined with
    dynamic or static activation scales.

    Online quantization of FP16/BF16 checkpoints is handled by the
    ``Fp8OnlineMoEMethod`` subclass; ``create_weights`` here asserts that the
    checkpoint is already FP8-serialized.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer this method is attached to; its ``moe_config``
            and ``moe_parallel_config`` are used to select the FP8 MoE backend.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        # Checkpoint parameter suffix differs by scheme: "weight_scale" for
        # per-tensor, "weight_scale_inv" for block-quantized checkpoints.
        self.weight_scale_name = (
            "weight_scale_inv" if self.block_quant else "weight_scale"
        )

        self.fp8_backend = select_fp8_moe_backend(
            block_quant=self.block_quant,
            tp_size=layer.moe_parallel_config.tp_size,
            with_lora_support=self.moe.is_lora_enabled,
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            if self.block_quant and self.weight_block_size != [128, 128]:
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports block "
                    "size [128, 128]."
                )
            if layer.activation != "silu":
                # f-string so the offending activation name is actually shown.
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
                    f"activation function, but got {layer.activation}."
                )

        dynamic_per_token = (
            not self.block_quant and self.quant_config.activation_scheme != "static"
        )
        if dynamic_per_token and self.fp8_backend in [
            Fp8MoeBackend.FLASHINFER_TRTLLM,
            Fp8MoeBackend.FLASHINFER_CUTLASS,
        ]:
            raise NotImplementedError(
                "FlashInfer FP8 MoE backend does not support dynamic per token "
                "activation quantization."
            )

        # Populated by `_setup_kernel` once weights are in runtime format.
        self.kernel: mk.FusedMoEModularKernel | None = None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register FP8 weights, weight scales and input scales on ``layer``.

        Weights are allocated as ``torch.float8_e4m3fn``. Weight scales are
        per expert (per-tensor scheme) or per block (block scheme). Input
        scales are registered only for the static activation scheme;
        otherwise ``layer.w13_input_scale`` / ``layer.w2_input_scale`` are
        set to ``None``.

        Raises:
            ValueError: if block quantization is used and the partitioned
                intermediate size is not divisible by the block dimensions.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        assert self.quant_config.is_checkpoint_fp8_serialized
        params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # For per-tensor quant, the scales are per expert and weight.
            w13_scale_data = torch.ones(num_experts, 2, dtype=torch.float32)
            w2_scale_data = torch.ones(num_experts, dtype=torch.float32)
        else:
            # For block quant, the scales are per block (typically 128x128).
            w13_scale_data = torch.ones(
                num_experts,
                2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                (hidden_size + block_k - 1) // block_k,
                dtype=torch.float32,
            )
            w2_scale_data = torch.ones(
                num_experts,
                (hidden_size + block_n - 1) // block_n,
                (intermediate_size_per_partition + block_k - 1) // block_k,
                dtype=torch.float32,
            )
        w13_weight_scale = torch.nn.Parameter(w13_scale_data, requires_grad=False)
        w2_weight_scale = torch.nn.Parameter(w2_scale_data, requires_grad=False)
        # Note: name is weight_scale for tensor, weight_scale_inv for block.
        layer.register_parameter(f"w13_{self.weight_scale_name}", w13_weight_scale)
        layer.register_parameter(f"w2_{self.weight_scale_name}", w2_weight_scale)

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)
        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def _setup_kernel(
        self,
        layer: Module,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor | None,
        w2_input_scale: torch.Tensor | None,
    ) -> None:
        """Convert weights to the backend's runtime format and build the kernel.

        The converted weights/scales replace the corresponding parameters on
        ``layer``. A modular kernel is built only when
        ``get_fused_moe_quant_config`` returns a config (it returns ``None``
        for the FlashInfer TRT-LLM backend).
        """
        # Shuffle weights to runtime format.
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
        replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)

        # Setup modular kernel for TP case.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            self.kernel, self.use_inplace = make_fp8_moe_kernel(
                layer=layer,
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
            )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Post-process loaded FP8 weights into kernel-ready form.

        Steps: optional FNUZ conversion, collapsing static input scales,
        requantizing w13 to a single per-tensor scale per expert (non-block
        only), then ``_setup_kernel``. No-op if the layer is flagged as
        already processed.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Allow for accessing weights and scales in standard way.
        w13 = layer.w13_weight
        w2 = layer.w2_weight
        w13_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        w13_input_scale = layer.w13_input_scale
        w2_input_scale = layer.w2_input_scale

        # MI300x and MI325x use FNUZ format for FP8. Convert if needed.
        if current_platform.is_fp8_fnuz():
            w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w13,
                w13_scale,
                w13_input_scale,
            )
            w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w2,
                w2_scale,
                w2_input_scale,
            )

        # Per tensor kernels require single activation scale. Use the max.
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            assert w13_input_scale is not None and w2_input_scale is not None
            w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
                w13_input_scale, w2_input_scale
            )
            replace_parameter(layer, "w13_input_scale", w13_input_scale)
            replace_parameter(layer, "w2_input_scale", w2_input_scale)

        # Per tensor kernels require single weight scale for w13 per expert, but
        # on disk there is a scale for w1 and w3. Use the max to requantize.
        if not self.block_quant:
            shard_size = layer.intermediate_size_per_partition
            w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
                w13, w13_scale, shard_size, layer.local_num_experts
            )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
        )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for this backend, if any.

        AITER, Marlin and FlashInfer TRT-LLM return ``None``; FlashInfer
        CUTLASS builds its own; every other backend defers to the base class.
        """
        if self.fp8_backend in [
            Fp8MoeBackend.AITER,
            Fp8MoeBackend.MARLIN,
            Fp8MoeBackend.FLASHINFER_TRTLLM,
        ]:
            return None
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Pick the experts implementation matching the backend and the
        activation format produced by ``prepare_finalize``.

        Raises:
            NotImplementedError: for the Marlin and ROCm AITER backends.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedDeepGemmExperts,
            BatchedTritonExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
        )

        if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
            raise NotImplementedError(
                "Marlin and ROCm AITER are not supported with all2all yet."
            )

        assert self.moe_quant_config is not None

        if (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        ):
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None

            experts_impl = (
                BatchedDeepGemmExperts
                if self.fp8_backend == Fp8MoeBackend.DEEPGEMM
                else BatchedTritonExperts
            )
            logger.debug(
                "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                experts_impl.__name__,
                self.__class__.__name__,
                max_num_tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return experts_impl(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
            )
        elif self.moe.is_lora_enabled:
            return TritonExperts(quant_config=self.moe_quant_config)
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            # Select GEMM experts with block-scale when weights are block-quantized
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        elif self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonOrDeepGemmExperts(self.moe_quant_config)
        else:
            assert self.fp8_backend == Fp8MoeBackend.TRITON
            logger.debug(
                "TritonExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonExperts(self.moe_quant_config)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 quant config from ``layer``'s scales.

        Returns ``None`` for FlashInfer TRT-LLM, which does not use the
        modular kernel.
        """
        # TRTLLM does not use Modular Kernel.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            return None

        w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        return make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=self.weight_block_size,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def allow_inplace(self) -> bool:
        return True

    def apply(
        self,
        layer: FusedMoE,
        router: FusedMoERouter,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the MoE forward pass.

        The FlashInfer TRT-LLM backend calls its fused op directly with the
        raw ``router_logits`` and returns that result. All other backends
        select experts through ``router`` and run the modular kernel built in
        ``_setup_kernel``.
        """
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            # TODO(rob): convert this to MK.
            if layer.enable_eplb:
                raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
            )

            if self.block_quant:
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                e_score_correction_bias = (
                    layer.e_score_correction_bias.to(x.dtype)
                    if layer.e_score_correction_bias is not None
                    else None
                )
                routing_method_type = layer.routing_method_type
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32)
                    if routing_method_type == RoutingMethodType.DeepSeekV3
                    else router_logits,
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routing_method_type=routing_method_type,
                    routed_scaling=layer.routed_scaling_factor,
                )
            else:
                # Return directly: TRTLLM builds no modular kernel
                # (`self.kernel` stays None), so falling through to the
                # generic path below would fail.
                return apply_fi_trtllm_fp8_per_tensor_moe(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=layer.e_score_correction_bias,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    apply_router_weight_on_input=layer.apply_router_weight_on_input,
                )

        topk_weights, topk_ids = router.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        assert self.kernel is not None
        result = self.kernel(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )

        return result
1072
1073


1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
class Fp8OnlineMoEMethod(Fp8MoEMethod):
    """MoE method for online FP8 quantization.

    Loads FP16/BF16 checkpoint weights and quantizes them to FP8 with a single
    per-tensor scale per expert, using dynamic activation scaling.
    Quantization runs as soon as the last weight element has been copied in,
    via the patched weight loader installed in ``create_weights``.

    Args:
        quant_config: The quantization config. Must describe a checkpoint
            that is not FP8-serialized, with dynamic activations and no
            block quantization.
        layer: The MoE layer this method is attached to.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(quant_config, layer)
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register high-precision weights plus per-expert FP8 scale slots.

        ``extra_weight_attrs["weight_loader"]`` is wrapped so that
        ``process_weights_after_loading`` fires once every element of
        ``w13_weight`` and ``w2_weight`` has been loaded.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # We are doing online quantization, patch the weight loaded
        # to call `process_weights_after_loading` in a streaming fashion
        # as soon as the last weight chunk is loaded.
        weight_loader = extra_weight_attrs["weight_loader"]

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # add a counter to track how many elements we have updated
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

            # load the current weight chunk
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # if we have loaded all of the elements, call
            # process_weights_after_loading
            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Delete the bookkeeping
                del layer._loaded_numel
                # Prevent the usual `process_weights_after_loading` call
                # from doing anything
                layer._already_called_process_weights_after_loading = True

            return res

        # Build a new dict rather than mutating the incoming one, so no other
        # holder of the original attrs sees the patched loader.
        extra_weight_attrs = {
            **extra_weight_attrs,
            "weight_loader": patched_weight_loader,
        }

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # One per-tensor scale per expert for each of w13 and w2. The values
        # are filled in by `process_weights_after_loading` when it quantizes
        # the high-precision weights.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # Dynamic activation scaling: no input scales.
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded high-precision weights to FP8 and set up the
        kernel. No-op if the layer is flagged as already processed."""
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Quantize each expert's weights into freshly allocated FP8 tensors;
        # the per-expert scales are written into the existing scale params.
        fp8_dtype = current_platform.fp8_dtype()
        w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
        w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale

        for expert in range(layer.local_num_experts):
            w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
                layer.w13_weight[expert, :, :]
            )
            w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant(
                layer.w2_weight[expert, :, :]
            )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer,
            w13,
            w2,
            w13_scale,
            w2_scale,
            layer.w13_input_scale,
            layer.w2_input_scale,
        )
1212
1213


1214
1215
1216
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1217
1218
1219
    """

    def __init__(self, quant_config: Fp8Config):
1220
        super().__init__(quant_config)