# fp8.py (47.1 KB) -- text captured from a repository blame view ("Newer" / "Older" columns).
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Any, Optional
5
6
7

import torch
from torch.nn import Module
8
from torch.utils._python_dispatch import TorchDispatchMode
9

10
import vllm.envs as envs
11
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
12
from vllm import _custom_ops as ops
13
from vllm._aiter_ops import rocm_aiter_ops
14
from vllm.attention.layer import Attention
15
from vllm.distributed import get_tensor_model_parallel_world_size
16
from vllm.logger import init_logger
17
from vllm.model_executor.layers.batch_invariant import (
18
    vllm_is_batch_invariant,
19
)
# (blame annotation: bnellnm committed)
20
from vllm.model_executor.layers.fused_moe import (
21
22
23
24
25
26
27
    FusedMoE,
    FusedMoEActivationFormat,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported,
)
28
from vllm.model_executor.layers.fused_moe.config import (
29
    FusedMoEQuantConfig,
30
    RoutingMethodType,
31
32
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
33
34
35
36
37
38
39
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    Fp8MoeBackend,
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)
40
41
42
43
44
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
45
from vllm.model_executor.layers.quantization import QuantizationMethods
46
from vllm.model_executor.layers.quantization.base_config import (
47
48
49
    QuantizationConfig,
    QuantizeMethodBase,
)
50
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
51
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
52
    apply_fi_trtllm_fp8_per_tensor_moe,
53
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
54
55
    select_cutlass_fp8_gemm_impl,
)
56
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
57
58
59
60
61
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
    maybe_post_process_fp8_weight_block,
62
    process_fp8_input_tensor_strategy_moe,
63
64
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
65
    process_fp8_weight_tensor_strategy_moe,
66
67
    validate_fp8_block_shape,
)
68
69
70
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
71
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
72
73
74
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
)
75
from vllm.model_executor.layers.quantization.utils.quant_utils import (
76
77
78
    GroupShape,
    is_layer_skipped,
)
79
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
80
81
82
83
84
85
86
87
88
89
90
    Fp8LinearOp,
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
91
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
92
from vllm.platforms import current_platform
93
94
95
from vllm.utils.deep_gemm import (
    is_deep_gemm_supported,
)
96

97
98
99
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

100
101
102
103
# Activation quantization schemes accepted by Fp8Config; any other value is
# rejected with a ValueError in Fp8Config.__init__.
ACTIVATION_SCHEMES = ["static", "dynamic"]

# Module-level logger, named after this module.
logger = init_logger(__name__)

104

105
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Attributes:
        is_checkpoint_fp8_serialized: Whether the checkpoint already stores
            FP8 weights (as opposed to being quantized after loading).
        activation_scheme: One of ``ACTIVATION_SCHEMES``.
        ignored_layers: Layer names handed to ``is_layer_skipped``; matching
            layers get an unquantized method.
        weight_block_size: Optional two-element block shape enabling
            block-wise weight quantization.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        """Validate and store the FP8 quantization settings.

        Raises:
            ValueError: If ``activation_scheme`` is not in
                ``ACTIVATION_SCHEMES``, or if ``weight_block_size`` is given
                together with a non-serialized checkpoint, a length other
                than 2, or a non-dynamic activation scheme.
        """
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        # Normalize None (and an empty list) to a fresh empty list.
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions"
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now, but got "
                    f"{activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the name of this quantization method (``"fp8"``)."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes supported by this method."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required by this method."""
        # presumably encodes compute capability 7.5 as an int -- TODO confirm
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the config filenames for this method (none)."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper") -> None:
        """Rewrite ``ignored_layers`` through ``hf_to_vllm_mapper.apply_list``."""
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a quantization config dict.

        Reads ``quant_method`` and ``activation_scheme`` via
        ``get_from_keys`` and the optional ``ignored_layers`` /
        ``weight_block_size`` via ``get_from_keys_or``. When
        ``ignored_layers`` is missing or empty, ``modules_to_not_convert`` is
        used instead.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        # The checkpoint counts as fp8-serialized when "fp8" occurs in the
        # quant_method value.
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        if not ignored_layers:
            ignored_layers = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None
            )
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def _is_ignored_layer(self, prefix: str) -> bool:
        """Return whether the layer at ``prefix`` is excluded from quantization."""
        return is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Select the quantization method for ``layer`` on XPU.

        Returns an unquantized method for ignored linear / MoE layers, the XPU
        FP8 method otherwise, ``Fp8KVCacheMethod`` for attention layers, and
        ``None`` for any other layer type.
        """
        # Imported lazily so the XPU-specific module is only loaded when this
        # path is actually taken.
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        # The XPU methods receive a fresh Fp8Config carrying the same
        # settings rather than `self`.
        fp8_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if self._is_ignored_layer(prefix):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(fp8_config)
        elif isinstance(layer, FusedMoE):
            if self._is_ignored_layer(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)

            return XPUFp8MoEMethod(fp8_config, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Select the quantization method for ``layer``.

        Delegates to ``get_xpu_quant_method`` on XPU. Otherwise returns an
        unquantized method for ignored linear / MoE layers, the FP8 linear or
        MoE method for the remaining ones, ``Fp8KVCacheMethod`` for attention
        layers, and ``None`` for any other layer type.
        """
        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)
        if isinstance(layer, LinearBase):
            if self._is_ignored_layer(prefix):
                return UnquantizedLinearMethod()
            quant_method = Fp8LinearMethod(self)
            quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return quant_method
        elif isinstance(layer, FusedMoE):
            if self._is_ignored_layer(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            # Serialized FP8 checkpoints and checkpoints quantized after
            # loading use different MoE methods.
            if self.is_checkpoint_fp8_serialized:
                moe_quant_method = Fp8MoEMethod(self, layer)
            else:
                moe_quant_method = Fp8OnlineMoEMethod(self, layer)
            return moe_quant_method
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        # Checked in this order: k, v, q. The first projection whose marker
        # occurs in an `.output_scale` name decides the result.
        projection_targets = (
            (".k_proj", ".attn.k_scale"),
            (".v_proj", ".attn.v_scale"),
            (".q_proj", ".attn.q_scale"),
        )
        if name.endswith(".output_scale"):
            for projection, target in projection_targets:
                if projection in name:
                    return name.replace(f"{projection}.output_scale", target)
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

267

268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
class CopyNumelCounter(TorchDispatchMode):
    """
    Dispatch mode that accumulates the number of elements written by `copy_`.

    While the mode is active, every `aten.copy_` call adds the element count
    of its destination tensor to `copied_numel`. This lets weight loading be
    tracked even when the destination is a transformed view of the parameter
    (for example one produced by `narrow`).
    """

    def __init__(self):
        super().__init__()
        # Running total of destination elements seen so far.
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        call_kwargs = {} if kwargs is None else kwargs
        # Always run the real op first; this mode only observes.
        result = func(*args, **call_kwargs)
        is_copy = func == torch.ops.aten.copy_.default
        if is_copy:
            # args[0] is the tensor being written into.
            destination = args[0]
            self.copied_numel = self.copied_numel + destination.numel()
        return result


288
289
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    NOTE(review): the block-quant path below (``weight_block_size``) appears
    to go beyond limitation 1 -- confirm whether this list is still current.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        """Pick the kernel path (Marlin / block FP8 / per-tensor FP8)."""
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.marlin_input_dtype = None
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        # Marlin is also disabled in batch-invariant mode.
        if vllm_is_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        if self.weight_block_size:
            # Activation groups are sized from the first weight block dim.
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if not self.act_q_static and cutlass_fp8_supported():
                self.act_q_group_shape = GroupShape.PER_TOKEN
            else:
                self.act_q_group_shape = GroupShape.PER_TENSOR

        # Exactly one of `w8a8_block_fp8_linear` / `fp8_linear` is created.
        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.act_q_static,
                act_quant_group_shape=self.act_q_group_shape,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and scale parameters on ``layer``.

        For fp8-serialized checkpoints this registers ``weight``, a weight
        scale (``weight_scale`` or, for block quantization,
        ``weight_scale_inv``) and ``input_scale``. Otherwise only ``weight``
        is registered, in ``params_dtype``, with a loader that triggers
        ``process_weights_after_loading`` once every element has been copied.

        The ``weight_loader`` entry of ``extra_weight_attrs`` is forwarded to
        the created parameters.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        if self.quant_config.is_checkpoint_fp8_serialized:
            weight = create_fp8_weight_parameter(
                output_size_per_partition, input_size_per_partition, weight_loader
            )
        else:

            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
                # track how many elements we have updated
                if not hasattr(layer, "_loaded_numel"):
                    layer._loaded_numel = 0

                # load the current weight chunk
                copy_numel_counter = CopyNumelCounter()
                with copy_numel_counter:
                    res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
                layer._loaded_numel += copy_numel_counter.copied_numel

                # if we have loaded all of the elements, call
                # process_weights_after_loading
                target_loaded_numel = layer.weight.numel()
                if layer._loaded_numel == target_loaded_numel:
                    self.process_weights_after_loading(layer)

                    # Delete the bookkeeping
                    del layer._loaded_numel
                    # Prevent the usual `process_weights_after_loading` call from doing
                    # anything
                    layer._already_called_process_weights_after_loading = True

                return res

            # For non-serialized checkpoints, use original dtype
            weight = ModelWeightParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=patched_weight_loader,
            )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = create_fp8_scale_parameter(
                    PerTensorScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    None,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert not self.act_q_static
                assert self.weight_block_size is not None
                scale = create_fp8_scale_parameter(
                    BlockQuantScaleParameter,
                    output_partition_sizes,
                    input_size_per_partition,
                    self.weight_block_size,
                    weight_loader,
                )
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.act_q_static:
                scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded parameters into the layout the kernels expect.

        Does nothing when ``layer`` is already flagged with
        ``_already_called_process_weights_after_loading``.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale_inv = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )

            # Update layer with new values
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)

        # If checkpoint not serialized fp8, quantize the weights.
        else:
            if not self.quant_config.is_checkpoint_fp8_serialized:
                qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
                weight = qweight.t()

            # If checkpoint is fp8 per-tensor, handle that there are N scales for N
            # shards in a fused module
            else:
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If using w8a8, torch._scaled_mm needs per tensor, so
                # requantize the logical shards as a single weight.
                if not self.use_marlin:
                    weight, weight_scale, input_scale = (
                        process_fp8_weight_tensor_strategy(
                            weight,
                            weight_scale,
                            layer.logical_widths,
                            getattr(layer, "input_scale", None),
                        )
                    )
                    if self.act_q_static:
                        assert input_scale is not None
                        input_scale = input_scale.max()
                weight = weight.t()

            # Update layer with new values.
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale", weight_scale.data)

        if input_scale is not None:
            replace_parameter(layer, "input_scale", input_scale)
        else:
            layer.input_scale = None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer)

    def _apply_block_quant(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None,
    ) -> torch.Tensor:
        """Run the block-quantized FP8 GEMM for ``layer`` on ``x``."""
        assert self.weight_block_size is not None

        return self.w8a8_block_fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale_inv,
            input_scale=layer.input_scale,
            bias=bias,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the linear layer output for ``x``.

        Path selection, in order: batch-invariant mode (block FP8 GEMM, or a
        BF16 dequantize-and-matmul fallback), Marlin, block FP8 GEMM, and
        finally the per-tensor / per-token FP8 GEMM.
        """
        # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
        # we will use BF16 dequant when DeepGEMM is not supported.
        if vllm_is_batch_invariant():
            if self.block_quant:
                return self._apply_block_quant(layer, x, bias)

            # per-tensor/channel: dequant to BF16 and run GEMM
            # NOTE(review): the dequantized weight is always BF16, so this
            # presumably expects `x` to be BF16 as well -- confirm.
            weight_fp8 = layer.weight.to(torch.bfloat16)
            weight_scale = layer.weight_scale.to(torch.bfloat16)
            # A 1-D scale with one entry per weight row is applied per row.
            # A scalar scale, or any other shape, relies on plain
            # broadcasting.
            if (
                weight_scale.numel() != 1
                and weight_scale.dim() == 1
                and weight_scale.shape[0] == weight_fp8.shape[0]
            ):
                weight_scale = weight_scale.unsqueeze(1)
            weight_bf16 = weight_fp8 * weight_scale
            return torch.nn.functional.linear(x, weight_bf16.t(), bias)

        if self.use_marlin:
            if self.block_quant:
                weight_scale = layer.weight_scale_inv
            else:
                weight_scale = layer.weight_scale

            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                input_dtype=self.marlin_input_dtype,
                bias=bias,
            )

        if self.block_quant:
            return self._apply_block_quant(layer, x, bias)

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )
612
613


614
615
616
617
618
619
620
621
622
623
624
625
626
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8-serialized checkpoints.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale, using either per-tensor or block-wise
    weight quantization (selected by ``quant_config.weight_block_size``).

    This class requires ``quant_config.is_checkpoint_fp8_serialized`` (it is
    asserted in ``create_weights``). Online quantization of FP16/BF16
    checkpoints is handled by ``Fp8OnlineMoEMethod``.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer this method is attached to. Its ``moe_config``,
            ``moe_parallel_config`` and ``activation`` are read at
            construction time.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        """Select the FP8 MoE backend and validate it against the config.

        Raises:
            NotImplementedError: If the selected FlashInfer backend does not
                support the configured block size, activation function, or
                dynamic per-token activation quantization.
        """
        super().__init__(layer.moe_config)
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        # Block-quantized checkpoints name their scales `weight_scale_inv`;
        # per-tensor checkpoints use `weight_scale`.
        self.weight_scale_name = (
            "weight_scale_inv" if self.block_quant else "weight_scale"
        )
        self.fp8_backend = select_fp8_moe_backend(
            block_quant=self.block_quant,
            tp_size=layer.moe_parallel_config.tp_size,
            with_lora_support=self.moe.is_lora_enabled,
        )

        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            if self.block_quant and self.weight_block_size != [128, 128]:
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports block "
                    "size [128, 128]."
                )
            if layer.activation != "silu":
                # The f-prefix is required so the offending activation name
                # is interpolated into the message.
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
                    f"activation function, but got {layer.activation}."
                )

        dynamic_per_token = (
            not self.block_quant and self.quant_config.activation_scheme != "static"
        )
        if dynamic_per_token and self.fp8_backend in [
            Fp8MoeBackend.FLASHINFER_TRTLLM,
            Fp8MoeBackend.FLASHINFER_CUTLASS,
        ]:
            raise NotImplementedError(
                "FlashInfer FP8 MoE backend does not support dynamic per token "
                "activation quantization."
            )

        # Populated by `_setup_kernel` once the weights have been processed.
        self.kernel: mk.FusedMoEModularKernel | None = None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register FP8 weights, weight scales and input scales on ``layer``.

        Args:
            layer: Module that receives the parameters.
            num_experts: Number of experts (leading dim of every parameter).
            hidden_size: Model hidden size.
            intermediate_size_per_partition: Per-partition intermediate size.
            params_dtype: Original parameter dtype; stored on the layer as
                ``orig_dtype``. The weights themselves are allocated as
                ``torch.float8_e4m3fn``.
            **extra_weight_attrs: Attributes attached to every parameter via
                ``set_weight_attrs``.

        Raises:
            ValueError: For block quantization, if the intermediate size is
                not divisible by the quantization block shape.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        assert self.quant_config.is_checkpoint_fp8_serialized
        params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # For per-tensor quant, the scales are per expert and weight.
            w13_scale_data = torch.ones(num_experts, 2, dtype=torch.float32)
            w2_scale_data = torch.ones(num_experts, dtype=torch.float32)
        else:
            # For block quant, the scales are per block (typically 128x128).
            # Ceil-division so a partial trailing block still gets a scale.
            w13_scale_data = torch.ones(
                num_experts,
                2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                (hidden_size + block_k - 1) // block_k,
                dtype=torch.float32,
            )
            w2_scale_data = torch.ones(
                num_experts,
                (hidden_size + block_n - 1) // block_n,
                (intermediate_size_per_partition + block_k - 1) // block_k,
                dtype=torch.float32,
            )
        w13_weight_scale = torch.nn.Parameter(w13_scale_data, requires_grad=False)
        w2_weight_scale = torch.nn.Parameter(w2_scale_data, requires_grad=False)
        # Note: name is weight_scale for tensor, weight_scale_inv for block.
        layer.register_parameter(f"w13_{self.weight_scale_name}", w13_weight_scale)
        layer.register_parameter(f"w2_{self.weight_scale_name}", w2_weight_scale)

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)
        else:
            # Dynamic activation scheme: no input scales are stored.
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def _setup_kernel(
        self,
        layer: Module,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor | None,
        w2_input_scale: torch.Tensor | None,
    ) -> None:
        """Convert weights to the backend format and build the MoE kernel.

        The converted weights and scales replace the corresponding parameters
        on ``layer``. ``self.kernel`` and ``self.use_inplace`` are only set
        when ``get_fused_moe_quant_config`` returns a config.
        """
        # Shuffle weights to runtime format.
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
        replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)

        # Setup modular kernel for TP case.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            self.kernel, self.use_inplace = make_fp8_moe_kernel(
                layer=layer,
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
            )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Post-process loaded weights/scales and set up the runtime kernel.

        Does nothing if ``layer._already_called_process_weights_after_loading``
        is set.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Allow for accessing weights and scales in standard way.
        w13 = layer.w13_weight
        w2 = layer.w2_weight
        w13_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        w13_input_scale = layer.w13_input_scale
        w2_input_scale = layer.w2_input_scale

        # MI300x and MI325x use FNUZ format for FP8. Convert if needed.
        if current_platform.is_fp8_fnuz():
            w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w13,
                w13_scale,
                w13_input_scale,
            )
            w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w2,
                w2_scale,
                w2_input_scale,
            )

        # Per tensor kernels require single activation scale. Use the max.
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            assert w13_input_scale is not None and w2_input_scale is not None
            w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
                w13_input_scale, w2_input_scale
            )
            replace_parameter(layer, "w13_input_scale", w13_input_scale)
            replace_parameter(layer, "w2_input_scale", w2_input_scale)

        # Per tensor kernels require single weight scale for w13 per expert, but
        # on disk there is a scale for w1 and w3. Use the max to requantize.
        if not self.block_quant:
            shard_size = layer.intermediate_size_per_partition
            w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
                w13, w13_scale, shard_size, layer.local_num_experts
            )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
        )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for the selected backend.

        Returns ``None`` for the AITER, MARLIN and FLASHINFER_TRTLLM backends,
        a FlashInfer CUTLASS prepare/finalize for FLASHINFER_CUTLASS, and
        defers to the base class for every other backend.
        """
        if self.fp8_backend in [
            Fp8MoeBackend.AITER,
            Fp8MoeBackend.MARLIN,
            Fp8MoeBackend.FLASHINFER_TRTLLM,
        ]:
            return None
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Pick the experts implementation matching backend and activations.

        Raises:
            NotImplementedError: For the MARLIN and AITER backends.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedDeepGemmExperts,
            BatchedTritonExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
        )

        if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
            raise NotImplementedError(
                "Marlin and ROCm AITER are not supported with all2all yet."
            )

        assert self.moe_quant_config is not None

        if (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        ):
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None

            experts_impl = (
                BatchedDeepGemmExperts
                if self.fp8_backend == Fp8MoeBackend.DEEPGEMM
                else BatchedTritonExperts
            )
            logger.debug(
                "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                experts_impl.__name__,
                self.__class__.__name__,
                max_num_tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return experts_impl(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
            )
        elif self.moe.is_lora_enabled:
            return TritonExperts(quant_config=self.moe_quant_config)
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            # Select GEMM experts with block-scale when weights are block-quantized
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        elif self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonOrDeepGemmExperts(self.moe_quant_config)
        else:
            assert self.fp8_backend == Fp8MoeBackend.TRITON
            logger.debug(
                "TritonExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonExperts(self.moe_quant_config)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the quant config from the layer's scales.

        Returns ``None`` for the FLASHINFER_TRTLLM backend.
        """
        # TRTLLM does not use Modular Kernel.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            return None

        w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        return make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=self.weight_block_size,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def allow_inplace(self) -> bool:
        return True

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the fused MoE computation on ``x``.

        The FLASHINFER_TRTLLM backend is dispatched directly to the FlashInfer
        ops (block-scale or per-tensor) and returns from within that branch.
        Every other backend selects experts via ``layer.select_experts`` and
        runs the modular kernel built in ``_setup_kernel``.

        Raises:
            NotImplementedError: If EPLB is enabled with FLASHINFER_TRTLLM.
        """
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            # TODO(rob): convert this to MK.
            if layer.enable_eplb:
                raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
            )

            if self.block_quant:
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                e_score_correction_bias = (
                    layer.e_score_correction_bias.to(x.dtype)
                    if layer.e_score_correction_bias is not None
                    else None
                )
                routing_method_type = layer.routing_method_type
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32)
                    if routing_method_type == RoutingMethodType.DeepSeekV3
                    else router_logits,
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routing_method_type=routing_method_type,
                    routed_scaling=layer.routed_scaling_factor,
                )

            # Return directly: this backend has no modular kernel
            # (`get_fused_moe_quant_config` returns None, so `self.kernel`
            # stays None). Falling through would discard this result and trip
            # the `self.kernel is not None` assertion below.
            return apply_fi_trtllm_fp8_per_tensor_moe(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=layer.e_score_correction_bias,
                global_num_experts=layer.global_num_experts,
                top_k=layer.top_k,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

        topk_weights, topk_ids = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        assert self.kernel is not None
        result = self.kernel(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )

        return result
1074
1075


1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
class Fp8OnlineMoEMethod(Fp8MoEMethod):
    """MoE method for online FP8 quantization.
    Supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config. Must describe a checkpoint
            that is not FP8-serialized, with the "dynamic" activation scheme
            and no weight block size (all three are asserted).
        layer: The MoE layer this method is attached to.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(quant_config, layer)
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register high-precision weights plus per-expert scale placeholders.

        Weights are allocated in ``params_dtype``. The ``weight_loader`` passed
        in ``extra_weight_attrs`` is wrapped so that
        ``process_weights_after_loading`` runs as soon as the number of copied
        elements equals the combined size of ``w13_weight`` and ``w2_weight``.

        Args:
            layer: Module that receives the parameters.
            num_experts: Number of experts (leading dim of every parameter).
            hidden_size: Model hidden size.
            intermediate_size_per_partition: Per-partition intermediate size.
            params_dtype: Dtype of the checkpoint weights.
            **extra_weight_attrs: Attributes attached to every parameter via
                ``set_weight_attrs``; must contain ``"weight_loader"``.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # We are doing online quantization, patch the weight loaded
        # to call `process_weights_after_loading` in a streaming fashion
        # as soon as the last weight chunk is loaded.
        weight_loader = extra_weight_attrs["weight_loader"]
        # Create a new holder (a shallow copy, not an alias) to prevent
        # modifying behavior of any other objects which might depend on the
        # old one.
        new_extra_weight_attrs = dict(extra_weight_attrs)

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # add a counter to track how many elements we have updated
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

            # load the current weight chunk
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # if we have loaded all of the elements, call
            # process_weights_after_loading
            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Delete the bookkeeping
                del layer._loaded_numel
                # Prevent the usual `process_weights_after_loading` call
                # from doing anything
                layer._already_called_process_weights_after_loading = True

            return res

        new_extra_weight_attrs["weight_loader"] = patched_weight_loader
        extra_weight_attrs = new_extra_weight_attrs

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate one scale per expert for w13 and one per expert for w2.
        # They are filled in by `process_weights_after_loading`, which
        # quantizes each expert's w13 and w2 tensor as a whole.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # Dynamic activation scheme: no input scales are stored.
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded weights to FP8 and set up the runtime kernel.

        Does nothing if ``layer._already_called_process_weights_after_loading``
        is set.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # The checkpoint is high precision (e.g. fp16): quantize it into
        # freshly allocated FP8 buffers of the platform's FP8 dtype.
        fp8_dtype = current_platform.fp8_dtype()
        w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
        w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
        # The scale parameters are written in place, one entry per expert.
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale

        for expert in range(layer.local_num_experts):
            w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
                layer.w13_weight[expert, :, :]
            )
            w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant(
                layer.w2_weight[expert, :, :]
            )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer,
            w13,
            w2,
            w13_scale,
            w2_scale,
            layer.w13_input_scale,
            layer.w2_input_scale,
        )
1214
1215


1216
1217
1218
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1219
1220
1221
    """

    def __init__(self, quant_config: Fp8Config):
1222
        super().__init__(quant_config)