fp8.py 48.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Any
5
6
7

import torch
from torch.nn import Module
8
from torch.utils._python_dispatch import TorchDispatchMode
9

10
import vllm.envs as envs
11
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
12
from vllm import _custom_ops as ops
13
from vllm._aiter_ops import rocm_aiter_ops
14
from vllm.distributed import get_tensor_model_parallel_world_size
15
from vllm.logger import init_logger
16
17
18
from vllm.model_executor.kernels.linear import (
    init_fp8_linear_kernel,
)
19
from vllm.model_executor.layers.attention import Attention
20
from vllm.model_executor.layers.batch_invariant import (
21
    vllm_is_batch_invariant,
22
)
bnellnm's avatar
bnellnm committed
23
from vllm.model_executor.layers.fused_moe import (
24
25
26
27
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
28
from vllm.model_executor.layers.fused_moe.config import (
29
30
31
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
32
33
34
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
zhuwenwen's avatar
zhuwenwen committed
35
    make_fp8_moe_kernel_for_mkm,
36
37
38
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)
39
40
41
42
43
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
44
from vllm.model_executor.layers.quantization import QuantizationMethods
45
from vllm.model_executor.layers.quantization.base_config import (
46
47
48
    QuantizationConfig,
    QuantizeMethodBase,
)
49
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
50
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
51
52
53
54
55
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
    maybe_post_process_fp8_weight_block,
56
    process_fp8_input_tensor_strategy_moe,
57
58
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
59
    process_fp8_weight_tensor_strategy_moe,
60
61
    validate_fp8_block_shape,
)
62
63
64
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
65
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
66
67
68
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
)
69
from vllm.model_executor.layers.quantization.utils.quant_utils import (
70
71
    GroupShape,
    is_layer_skipped,
72
    kFp8Dynamic128Sym,
73
74
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
75
    kFp8Static128BlockSym,
76
    kFp8StaticTensorSym,
77
)
78
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
79
80
81
82
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    normalize_e4m3fn_to_e4m3fnuz,
)
83
from vllm.model_executor.model_loader.weight_utils import initialize_single_dummy_weight
84
85
86
87
88
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
89
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
90
from vllm.platforms import current_platform
91
92
93
from vllm.utils.deep_gemm import (
    is_deep_gemm_supported,
)
94

95
96
97
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

98
99
100
101
logger = init_logger(__name__)

# Activation quantization schemes accepted by `Fp8Config`; its
# `activation_scheme` argument is validated against this list.
ACTIVATION_SCHEMES = ["static", "dynamic"]


class Fp8Config(QuantizationConfig):
    """Quantization config for FP8.

    Args:
        is_checkpoint_fp8_serialized: True when the checkpoint already stores
            FP8 weights; False means weights are quantized online at load time.
        activation_scheme: Activation scale scheme, one of
            ``ACTIVATION_SCHEMES`` ("static" or "dynamic").
        ignored_layers: Layer prefixes that stay unquantized (``None`` is
            stored as an empty list).
        weight_block_size: Optional 2-element block shape that enables
            block-wise weight quantization.

    Raises:
        ValueError: For an unknown activation scheme, or when
            ``weight_block_size`` is given together with a non-serialized
            checkpoint, does not have exactly 2 entries, or is combined with a
            non-dynamic activation scheme.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []

        # Block-wise quantization has extra preconditions; they are checked in
        # order so the first violated rule is the one reported.
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now."
                )
            block_rank = len(weight_block_size)
            if block_rank != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {block_rank} dimensions"
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now, but got "
                    f"{activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Registry name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Activation dtypes accepted by this config (bf16 and fp16)."""
        return [torch.bfloat16, torch.float16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum device capability required by this config."""
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """This config reads no extra config files."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ``ignored_layers`` from HF names to vLLM names in place."""
        if self.ignored_layers is None:
            return
        self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a checkpoint quantization-config dict.

        The checkpoint counts as FP8-serialized when its ``quant_method``
        contains "fp8". If ``ignored_layers`` is absent or empty,
        ``modules_to_not_convert`` is used instead.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        fp8_serialized = "fp8" in quant_method
        scheme = cls.get_from_keys(config, ["activation_scheme"])
        skipped = cls.get_from_keys_or(config, ["ignored_layers"], None)
        block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        # Some checkpoints name the skip list "modules_to_not_convert".
        skipped = skipped or cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
        return cls(
            is_checkpoint_fp8_serialized=fp8_serialized,
            activation_scheme=scheme,
            ignored_layers=skipped,
            weight_block_size=block_size,
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Choose the quantize method for ``layer``.

        Linear and fused-MoE layers whose prefix is ignored get the
        unquantized method. Otherwise the offline (serialized FP8) or online
        variant is chosen from ``is_checkpoint_fp8_serialized``. Attention
        layers get the FP8 KV-cache method. Anything else returns None.
        """
        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
            linear_cls = (
                Fp8LinearMethod
                if self.is_checkpoint_fp8_serialized
                else Fp8OnlineLinearMethod
            )
            linear_method = linear_cls(self)
            linear_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return linear_method

        if isinstance(layer, FusedMoE):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            moe_cls = (
                Fp8MoEMethod
                if self.is_checkpoint_fp8_serialized
                else Fp8OnlineMoEMethod
            )
            return moe_cls(self, layer)

        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)

        return None

    def get_cache_scale(self, name: str) -> str | None:
        """Translate a checkpoint attention-scale name to vLLM's naming.

        ``<...>.{k,v,q}_proj.output_scale`` maps to
        ``<...>.attn.{k,v,q}_scale``. A name ending in
        ``self_attn.prob_output_scale`` maps to ``...attn.prob_scale``.

        :param name: param name from the checkpoint
        :return: the vLLM param name, or None if ``name`` matches no pattern
        """
        if name.endswith(".output_scale"):
            # Checked in k, v, q order; the first projection found wins.
            for proj_marker, old_suffix, new_suffix in (
                (".k_proj", ".k_proj.output_scale", ".attn.k_scale"),
                (".v_proj", ".v_proj.output_scale", ".attn.v_scale"),
                (".q_proj", ".q_proj.output_scale", ".attn.q_scale"),
            ):
                if proj_marker in name:
                    return name.replace(old_suffix, new_suffix)
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # Not a recognised KV-cache scale name.
        return None

232

233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
class CopyNumelCounter(TorchDispatchMode):
    """
    Dispatch mode that sums the element count of every ``aten.copy_``
    destination tensor seen while the mode is active.

    The count is taken from the tensor actually written to, so it stays
    correct when a weight loader reshapes or slices (e.g. ``narrow``) the
    destination before copying into it.
    """

    def __init__(self):
        super().__init__()
        # Running total of destination elements written via `copy_`.
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Run the intercepted op unchanged; this mode only observes.
        result = func(*args, **(kwargs or {}))
        if func == torch.ops.aten.copy_.default:
            destination = args[0]
            self.copied_numel += destination.numel()
        return result


253
254
255
256
257
258
259
260
261
262
def _copy_missing_attrs(old: torch.Tensor, new: torch.Tensor) -> None:
    """Copy onto ``new`` every attribute that ``old`` has and ``new`` lacks.

    Attribute presence is decided by ``dir()``; the collected attributes are
    applied in one ``set_weight_attrs`` call.
    """
    already_present = set(dir(new))
    missing = {
        attr_name: getattr(old, attr_name)
        for attr_name in dir(old)
        if attr_name not in already_present
    }
    set_weight_attrs(new, missing)


263
264
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Limitations:
    1. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization.
        # The caller may overwrite this (Fp8Config.get_quant_method does).
        self.marlin_input_dtype = None
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Marlin is never used on ROCm/XPU or in batch-invariant mode.
        if current_platform.is_rocm() or current_platform.is_xpu():
            self.use_marlin = False
        if vllm_is_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"

        if self.block_quant:
            # Block quantization only supports dynamic activation scales
            # (Fp8Config rejects other schemes when a block size is set).
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if self.act_q_static:
                activation_quant_key = kFp8StaticTensorSym
            elif cutlass_fp8_supported():
                activation_quant_key = kFp8DynamicTokenSym
            else:
                activation_quant_key = kFp8DynamicTensorSym

            self.fp8_linear = init_fp8_linear_kernel(
                activation_quant_key=activation_quant_key,
                weight_quant_key=kFp8StaticTensorSym,
                out_dtype=self.out_dtype,
                module_name=self.__class__.__name__,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 weight, its scale and, for static activation
        quantization, the input scale on ``layer``.

        Block-quantized layers register the weight scale as
        ``weight_scale_inv``; all others register ``weight_scale``.
        """
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        weight = create_fp8_weight_parameter(
            output_size_per_partition, input_size_per_partition, weight_loader
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        if not self.block_quant:
            scale = create_fp8_scale_parameter(
                PerTensorScaleParameter,
                output_partition_sizes,
                input_size_per_partition,
                None,
                weight_loader,
            )
            layer.register_parameter("weight_scale", scale)
        else:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            scale = create_fp8_scale_parameter(
                BlockQuantScaleParameter,
                output_partition_sizes,
                input_size_per_partition,
                self.weight_block_size,
                weight_loader,
            )
            # The weight_scale_inv name is intentional for deepseekv3
            layer.register_parameter("weight_scale_inv", scale)

        # INPUT ACTIVATION SCALE
        if self.act_q_static:
            scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
            set_weight_attrs(scale, {"scale_type": "input_scale"})
            layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded FP8 weight and scales into the kernel layout.

        Block path: the weight and ``weight_scale_inv`` are post-processed by
        ``process_fp8_weight_block_strategy``.

        Per-tensor path: without Marlin, the per-shard scales are
        requantized into a single weight. The weight is then transposed.

        ``size_k_first`` is passed to the Marlin preparation step. It is True
        on the per-tensor path (the weight is transposed) and False for block
        quantization. With Marlin, ``input_scale`` is removed because
        activations are not quantized.
        """
        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale_inv = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )

            # Update layer with new values
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)

        else:
            # Per-tensor FP8 checkpoint: a fused module (e.g. QKV) carries N
            # scales for its N logical shards.
            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
                    weight,
                    weight_scale,
                    layer.logical_widths,
                    getattr(layer, "input_scale", None),
                )
                if self.act_q_static:
                    assert input_scale is not None
                    # Collapse per-shard static input scales to one scalar.
                    input_scale = input_scale.max()
            weight = weight.t()

            # Update layer with new values.
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale", weight_scale.data)

        if input_scale is not None:
            replace_parameter(layer, "input_scale", input_scale)
        else:
            layer.input_scale = None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear layer on ``x``.

        Dispatch order:
        1. batch-invariant mode;
        2. Marlin;
        3. block-quantized W8A8;
        4. the per-tensor ``fp8_linear`` kernel.
        """
        # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
        # we will use BF16 dequant when DeepGEMM is not supported.
        if vllm_is_batch_invariant():
            if self.block_quant:
                assert self.weight_block_size is not None
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale_inv,
                    input_scale=layer.input_scale,
                    bias=bias,
                )
            else:
                # per-tensor/channel: dequant to BF16 and run GEMM.
                # After process_weights_after_loading the weight is stored
                # transposed, i.e. [K, N] with the output dim last.
                weight_fp8 = layer.weight.to(torch.bfloat16)
                weight_scale = layer.weight_scale.to(torch.bfloat16)
                if weight_scale.numel() == 1:
                    # Per-tensor: simple scalar multiplication
                    weight_bf16 = weight_fp8 * weight_scale
                elif weight_scale.numel() == weight_fp8.shape[1]:
                    # One scale per output channel: scales index N (the last
                    # dim), not K, so broadcast along columns. This handles
                    # both [N] and [N, 1] scale tensors.
                    weight_bf16 = weight_fp8 * weight_scale.reshape(1, -1)
                else:
                    # Fallback: rely on standard broadcasting.
                    weight_bf16 = weight_fp8 * weight_scale
                # F.linear needs matching dtypes. Cast the dequantized weight
                # to the activation dtype (a no-op for bf16; fixes fp16).
                return torch.nn.functional.linear(
                    x, weight_bf16.t().to(x.dtype), bias
                )

        if self.use_marlin:
            if self.block_quant:
                weight_scale = layer.weight_scale_inv
            else:
                weight_scale = layer.weight_scale

            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                input_dtype=self.marlin_input_dtype,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply_weights(layer, x, bias)
518
519


520
521
522
523
class Fp8OnlineLinearMethod(Fp8LinearMethod):
    """Online version of Fp8LinearMethod: loads an fp16/bf16 checkpoint and
    quantizes the weights to FP8 during loading.

    The full-precision ``weight`` is first created on the ``meta`` device. It
    is materialized on the real device only when the first checkpoint chunk
    arrives. Once every element has been copied in, the layer is quantized
    right away from inside the weight loader.
    """

    uses_meta_device: bool = True

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register a meta-device ``weight`` whose loader materializes it
        just-in-time and quantizes the layer once fully loaded."""
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # WEIGHT
        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # track how many elements we have updated
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

                # when the first `loaded_weight` is about to be
                # loaded to `param`, materialize `param` just-in-time
                weight = ModelWeightParameter(
                    data=torch.empty_like(layer.weight, device=layer._load_device),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=patched_weight_loader,
                )
                _copy_missing_attrs(layer.weight, weight)
                layer.register_parameter("weight", weight)
                del layer._load_device

            # refresh the reference to `param` to reflect just-in-time
            # materialization
            param = layer.weight

            # load the current weight chunk, counting the elements actually
            # written so partial (sharded/narrowed) loads are tracked
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # if we have loaded all of the elements, call
            # process_weights_after_loading
            target_loaded_numel = layer.weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Prevent the usual `process_weights_after_loading` call from doing
                # anything
                layer._already_called_process_weights_after_loading = True

                # Note that we keep `layer._loaded_numel` around just in case
                # there is logic added to vllm in the future which calls a
                # weight loader twice - we do not want to re-initialize in
                # that case.

            return res

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                # materialized just-in-time in `patched_weight_loader`
                device="meta",
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=patched_weight_loader,
        )
        # stash the correct device for `patched_weight_loader`
        layer._load_device = torch.get_default_device()
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the full-precision weight to FP8, at most once per layer.

        Returns immediately if the layer was already processed. The flag is
        set either here or by the patched weight loader. If the weight is
        still on ``meta`` (``--load_format dummy``), it is materialized and
        dummy-initialized first. The weight is quantized with
        ``ops.scaled_fp8_quant(..., scale=None)`` and stored transposed, and
        it is prepared for Marlin when that kernel is in use.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # deferred initialization of randomly initialized weights for the
        # `--load_format dummy` feature
        if layer.weight.device == torch.device("meta"):
            weight = ModelWeightParameter(
                data=torch.empty_like(layer.weight, device=layer._load_device),
                input_dim=1,
                output_dim=0,
                weight_loader=layer.weight.weight_loader,
            )
            _copy_missing_attrs(layer.weight, weight)
            layer.register_parameter("weight", weight)
            initialize_single_dummy_weight(layer.weight)

        # TODO(future): support block_quant in online quant path
        assert not self.block_quant

        # Online quantization uses dynamic activation scales only.
        layer.input_scale = None
        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
        weight = qweight.t()

        # Update layer with new values.
        replace_parameter(layer, "weight", weight.data)
        replace_parameter(layer, "weight_scale", weight_scale.data)

        if self.use_marlin:
            size_k_first = True
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations are not quantized for marlin; `layer.input_scale`
            # is left as None here.
            # NOTE(review): Fp8LinearMethod deletes `layer.input_scale` on its
            # marlin path - confirm the difference here is intentional.

        # Prevent duplicate processing (e.g., during weight reload)
        layer._already_called_process_weights_after_loading = True

643

644
645
646
647
648
649
650
651
652
653
654
655
656
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8-serialized checkpoints.

    Supports loading FP8 checkpoints with static weight scales (per-tensor,
    or block-wise when ``weight_block_size`` is configured) and either
    dynamic or static activation scales.

    FP16/BF16 checkpoints that are quantized online are handled by the
    ``Fp8OnlineMoEMethod`` subclass, which overrides weight creation; this
    class's ``create_weights`` asserts that the checkpoint is FP8-serialized.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer; its ``moe_config`` is forwarded to the base
            class.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None

        # Parameter-name suffix of the weight scales: block-quantized
        # checkpoints use `weight_scale_inv`, per-tensor ones `weight_scale`.
        self.weight_scale_name = (
            "weight_scale_inv" if self.block_quant else "weight_scale"
        )

        # Set weight key and activation key for kernel compatibility
        if self.block_quant:
            weight_key = kFp8Static128BlockSym
            activation_key = kFp8Dynamic128Sym
        else:
            weight_key = kFp8StaticTensorSym
            activation_key = (
                kFp8StaticTensorSym
                if self.quant_config.activation_scheme == "static"
                else kFp8DynamicTensorSym
            )

        # Select Fp8 MoE backend
        self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
            config=self.moe,
            weight_key=weight_key,
            activation_key=activation_key,
            allow_vllm_cutlass=False,
        )

        # Legacy kernel slot. The kernel actually used by `apply` and
        # `apply_monolithic` is built in `_setup_kernel` (after
        # process-weights) and stored in `self.moe_kernel`; this attribute is
        # kept for any code that still reads it.
        self.kernel: mk.FusedMoEModularKernel | None = None

    @property
    def topk_indices_dtype(self) -> torch.dtype | None:
        """Return ``prepare_finalize.topk_indices_dtype()`` of the active kernel.

        Prefers the kernel built by `_setup_kernel` (``self.moe_kernel``) and
        falls back to the legacy ``self.kernel`` slot. Returns None while no
        kernel has been built.
        """
        # Previously this only consulted `self.kernel`, which is never
        # populated here, so it always returned None.
        kernel = getattr(self, "moe_kernel", None)
        if kernel is None:
            kernel = self.kernel
        if kernel is None:
            return None
        return kernel.prepare_finalize.topk_indices_dtype()

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register FP8 expert weights, optional biases and scales on ``layer``.

        With E = ``num_experts``, H = ``hidden_size`` and
        I = ``intermediate_size_per_partition`` this creates:
          - ``w13_weight`` of shape (E, 2*I, H) and ``w2_weight`` of shape
            (E, H, I), both ``torch.float8_e4m3fn``;
          - ``w13_bias`` (E, 2*I) and ``w2_bias`` (E, H) in the original
            dtype when the MoE config has biases;
          - weight scales named ``w13_<scale>`` / ``w2_<scale>`` where
            ``<scale>`` is ``self.weight_scale_name``: (E, 2) / (E,) for
            per-tensor quant, per-block grids for block quant;
          - per-expert ``w13_input_scale`` / ``w2_input_scale`` for the static
            activation scheme (set to None otherwise).

        Raises:
            AssertionError: If the checkpoint is not FP8-serialized.
            ValueError: If block quantization is enabled and the partition
                size is not divisible by the block size.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        assert self.quant_config.is_checkpoint_fp8_serialized
        params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # BIASES (for models like GPT-OSS that have biased MoE)
        if self.moe.has_bias:
            w13_bias = torch.nn.Parameter(
                torch.zeros(
                    num_experts,
                    2 * intermediate_size_per_partition,
                    dtype=layer.orig_dtype,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_bias", w13_bias)
            set_weight_attrs(w13_bias, extra_weight_attrs)
            w2_bias = torch.nn.Parameter(
                torch.zeros(num_experts, hidden_size, dtype=layer.orig_dtype),
                requires_grad=False,
            )
            layer.register_parameter("w2_bias", w2_bias)
            set_weight_attrs(w2_bias, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # For per-tensor quant, the scales are per expert and weight.
            w13_scale_data = torch.ones(num_experts, 2, dtype=torch.float32)
            w2_scale_data = torch.ones(num_experts, dtype=torch.float32)
        else:
            # For block quant, the scales are per block (typically 128x128).
            w13_scale_data = torch.ones(
                num_experts,
                2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                (hidden_size + block_k - 1) // block_k,
                dtype=torch.float32,
            )
            w2_scale_data = torch.ones(
                num_experts,
                (hidden_size + block_n - 1) // block_n,
                (intermediate_size_per_partition + block_k - 1) // block_k,
                dtype=torch.float32,
            )
        w13_weight_scale = torch.nn.Parameter(w13_scale_data, requires_grad=False)
        w2_weight_scale = torch.nn.Parameter(w2_scale_data, requires_grad=False)
        # Note: name is weight_scale for tensor, weight_scale_inv for block.
        layer.register_parameter(f"w13_{self.weight_scale_name}", w13_weight_scale)
        layer.register_parameter(f"w2_{self.weight_scale_name}", w2_weight_scale)

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)
        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def _setup_kernel(
        self,
        layer: Module,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor | None,
        w2_input_scale: torch.Tensor | None,
    ) -> None:
        """Convert weights to the backend's runtime format and build the kernel.

        The converted weights and scales replace the layer's parameters, and
        ``self.moe_quant_config`` is refreshed from the layer. ``self.moe_kernel``
        is built only when a quant config is available and either all2all
        kernels are disabled or the naive all2all kernels are in use; in the
        remaining case it is left untouched here.
        """
        # Shuffle weights to runtime format.
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
        replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)

        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config and (
            (not self.moe.moe_parallel_config.use_all2all_kernels)
            or self.moe.moe_parallel_config.use_naive_all2all_kernels
        ):
            assert self.experts_cls is not None
            self.moe_kernel = make_fp8_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
                experts_cls=self.experts_cls,
            )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Normalize loaded FP8 weights/scales and build the MoE kernel.

        - On FNUZ FP8 platforms, converts weights and scales from e4m3fn to
          e4m3fnuz.
        - For the static activation scheme, reduces the input scales with
          `process_fp8_input_tensor_strategy_moe` and stores them on the layer.
        - For per-tensor (non-block) quant, uses
          `process_fp8_weight_tensor_strategy_moe` to fold the separate w1/w3
          scales into a single w13 scale per expert, requantizing w13.
        - Finally calls `_setup_kernel`.

        Runs at most once per layer (guarded by a flag on ``layer``).
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Allow for accessing weights and scales in standard way.
        w13 = layer.w13_weight
        w2 = layer.w2_weight
        w13_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        w13_input_scale = layer.w13_input_scale
        w2_input_scale = layer.w2_input_scale

        # MI300x and MI325x use FNUZ format for FP8. Convert if needed.
        if current_platform.is_fp8_fnuz():
            w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w13,
                w13_scale,
                w13_input_scale,
            )
            w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w2,
                w2_scale,
                w2_input_scale,
            )

        # Per tensor kernels require single activation scale. Use the max.
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            assert w13_input_scale is not None and w2_input_scale is not None
            w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
                w13_input_scale, w2_input_scale
            )
            replace_parameter(layer, "w13_input_scale", w13_input_scale)
            replace_parameter(layer, "w2_input_scale", w2_input_scale)

        # Per tensor kernels require single weight scale for w13 per expert, but
        # on disk there is a scale for w1 and w3. Use the max to requantize.
        if not self.block_quant:
            shard_size = layer.intermediate_size_per_partition
            w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
                w13, w13_scale, shard_size, layer.local_num_experts
            )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
        )

        # Prevent duplicate processing (e.g., during weight reload)
        layer._already_called_process_weights_after_loading = True

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
        """Not supported: this method builds its kernel in `_setup_kernel`.

        Raises:
            ValueError: Always.
        """
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config from the layer's scales.

        Reads the weight scales (``w13_/w2_`` + ``self.weight_scale_name``)
        and the input scales from ``layer``. When the MoE config has biases
        and a config was produced, the layer's ``w13_bias`` / ``w2_bias`` are
        attached to it.

        Returns:
            The config from `make_fp8_moe_quant_config`, which may be None.
        """
        w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        quant_config = make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=self.weight_block_size,
        )

        # Inject biases into the quant config if the model has them
        # (e.g. GPT-OSS biased MoE)
        if quant_config is not None and self.moe.has_bias:
            w13_bias = getattr(layer, "w13_bias", None)
            w2_bias = getattr(layer, "w2_bias", None)
            if w13_bias is not None:
                quant_config._w1.bias = w13_bias
            if w2_bias is not None:
                quant_config._w2.bias = w2_bias

        return quant_config

    @property
    def supports_eplb(self) -> bool:
        return True

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the monolithic kernel, which performs routing from the logits."""
        assert self.is_monolithic
        assert self.moe_kernel is not None
        return self.moe_kernel.apply_monolithic(
            x,
            layer.w13_weight,
            layer.w2_weight,
            router_logits,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            e_score_correction_bias=layer.e_score_correction_bias,
            routed_scaling_factor=layer.routed_scaling_factor,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the modular kernel with precomputed top-k weights and ids."""
        assert not self.is_monolithic
        assert self.moe_kernel is not None
        return self.moe_kernel.apply(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            shared_experts_input=shared_experts_input,
        )


class Fp8OnlineMoEMethod(Fp8MoEMethod):
    """MoE method for online FP8 quantization.

    Loads FP16/BF16 checkpoints and quantizes the expert weights to FP8 after
    loading, with dynamic activation scaling. The expert weights are created
    on the ``meta`` device and materialized just-in-time by a patched weight
    loader. That loader also triggers ``process_weights_after_loading`` as
    soon as every element of ``w13_weight`` and ``w2_weight`` has been copied.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer.
    """

    # Expert weights are created on the `meta` device (see `create_weights`).
    uses_meta_device: bool = True

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(quant_config, layer)
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register meta-device expert weights plus biases and scales.

        ``w13_weight`` (E, 2*I, H) and ``w2_weight`` (E, H, I) keep
        ``params_dtype`` and live on ``meta`` until the first chunk is loaded.
        Their ``weight_loader`` is wrapped so that loading streams into
        ``process_weights_after_loading``. Biases (if any) use the original,
        unpatched loader. One FP32 scale per expert is allocated for w13 and
        for w2.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # We are doing online quantization, patch the weight loader
        # to call `process_weights_after_loading` in a streaming fashion
        # as soon as the last weight chunk is loaded.
        weight_loader = extra_weight_attrs["weight_loader"]
        # Create a new holder (a real copy, not an alias) so that installing
        # the patched loader does not modify `extra_weight_attrs`, which keeps
        # the original loader for the biases below.
        patched_weight_attrs = dict(extra_weight_attrs)

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # add a counter to track how many elements we have updated
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

                # save the ids of original w13 and w2 so that we can
                # distinguish which one `param` should map to further
                # down in this function
                layer._w13_weight_orig_id = id(layer.w13_weight)
                layer._w2_weight_orig_id = id(layer.w2_weight)

                # when the first `loaded_weight` is about to be
                # loaded to `param`, materialize `param` just-in-time
                w13_weight = torch.nn.Parameter(
                    torch.empty_like(layer.w13_weight, device=layer._load_device),
                    requires_grad=False,
                )
                set_weight_attrs(w13_weight, patched_weight_attrs)
                _copy_missing_attrs(layer.w13_weight, w13_weight)
                layer.register_parameter("w13_weight", w13_weight)

                w2_weight = torch.nn.Parameter(
                    torch.empty_like(layer.w2_weight, device=layer._load_device),
                    requires_grad=False,
                )
                set_weight_attrs(w2_weight, patched_weight_attrs)
                _copy_missing_attrs(layer.w2_weight, w2_weight)
                layer.register_parameter("w2_weight", w2_weight)
                del layer._load_device

            # refresh the reference to `param` to reflect just-in-time
            # materialization
            if id(param) == layer._w13_weight_orig_id:
                param = layer.w13_weight
            elif id(param) == layer._w2_weight_orig_id:
                param = layer.w2_weight

            # load the current weight chunk
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # if we have loaded all of the elements, call
            # process_weights_after_loading
            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Prevent the usual `process_weights_after_loading` call
                # from doing anything
                layer._already_called_process_weights_after_loading = True

                # Note that we keep `layer._loaded_numel`,
                # `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id`
                # around because if EP is on, weight loaders for non-local
                # experts will run but not actually copy any elements, and we
                # need to not re-initialize in that case.

            return res

        patched_weight_attrs["weight_loader"] = patched_weight_loader

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                # materialized just-in-time in `patched_weight_loader`
                device="meta",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, patched_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                # materialized just-in-time in `patched_weight_loader`
                device="meta",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, patched_weight_attrs)

        # stash the correct device for `patched_weight_loader`
        layer._load_device = torch.get_default_device()

        # BIASES (for models like GPT-OSS that have biased MoE)
        if self.moe.has_bias:
            # Biases use the original weight_loader (not patched): they are
            # not quantized and do not count towards `_loaded_numel`.
            w13_bias = torch.nn.Parameter(
                torch.zeros(
                    num_experts,
                    2 * intermediate_size_per_partition,
                    dtype=layer.orig_dtype,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_bias", w13_bias)
            set_weight_attrs(w13_bias, extra_weight_attrs)
            w2_bias = torch.nn.Parameter(
                torch.zeros(num_experts, hidden_size, dtype=layer.orig_dtype),
                requires_grad=False,
            )
            layer.register_parameter("w2_bias", w2_bias)
            set_weight_attrs(w2_bias, extra_weight_attrs)

        # WEIGHT_SCALES
        # One scale per expert for each of w13 and w2; the values are
        # computed by online quantization in `process_weights_after_loading`.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w13_weight_scale, patched_weight_attrs)
        set_weight_attrs(w2_weight_scale, patched_weight_attrs)

        # Dynamic activation scaling: no static input scales.
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded expert weights to FP8 and build the kernel.

        If the weights were never loaded (still on ``meta``, e.g. with
        ``--load_format dummy``), they are first materialized on
        ``layer._load_device`` and filled with dummy values. Each local
        expert's w13 and w2 are then quantized with ``ops.scaled_fp8_quant``,
        writing one scale per expert into ``w13_weight_scale`` /
        ``w2_weight_scale``, and the results are handed to `_setup_kernel`.

        Runs at most once per layer (guarded by a flag on ``layer``).
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # deferred initialization of randomly initialized weights for the
        # `--load_format dummy` feature
        if layer.w13_weight.device == torch.device("meta"):
            w13_weight = torch.nn.Parameter(
                torch.empty_like(layer.w13_weight, device=layer._load_device),
                requires_grad=False,
            )
            set_weight_attrs(
                w13_weight, {"weight_loader": layer.w13_weight.weight_loader}
            )
            _copy_missing_attrs(layer.w13_weight, w13_weight)
            layer.register_parameter("w13_weight", w13_weight)
            initialize_single_dummy_weight(layer.w13_weight)
        if layer.w2_weight.device == torch.device("meta"):
            w2_weight = torch.nn.Parameter(
                torch.empty_like(layer.w2_weight, device=layer._load_device),
                requires_grad=False,
            )
            set_weight_attrs(
                w2_weight, {"weight_loader": layer.w2_weight.weight_loader}
            )
            _copy_missing_attrs(layer.w2_weight, w2_weight)
            layer.register_parameter("w2_weight", w2_weight)
            initialize_single_dummy_weight(layer.w2_weight)

        # If checkpoint is fp16, quantize in place.
        fp8_dtype = current_platform.fp8_dtype()
        w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
        w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale

        for expert in range(layer.local_num_experts):
            w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
                layer.w13_weight[expert, :, :]
            )
            w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant(
                layer.w2_weight[expert, :, :]
            )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer,
            w13,
            w2,
            w13_scale,
            w2_scale,
            layer.w13_input_scale,
            layer.w2_input_scale,
        )

        # Prevent duplicate processing (e.g., during weight reload)
        layer._already_called_process_weights_after_loading = True


class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1255
1256
1257
    """

    def __init__(self, quant_config: Fp8Config):
zhuwenwen's avatar
zhuwenwen committed
1258
        super().__init__(quant_config)