"vscode:/vscode.git/clone" did not exist on "69ff99fdcddf4d6dbcebf5f750b121dd171b86a3"
fp8.py 39.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Any
5
6
7

import torch
from torch.nn import Module
8
from torch.utils._python_dispatch import TorchDispatchMode
9

10
import vllm.envs as envs
11
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
12
from vllm import _custom_ops as ops
13
from vllm.config import get_current_vllm_config
14
from vllm.distributed import get_tensor_model_parallel_world_size
15
from vllm.logger import init_logger
16
17
18
from vllm.model_executor.kernels.linear import (
    init_fp8_linear_kernel,
)
19
from vllm.model_executor.kernels.linear.scaled_mm import MarlinFP8ScaledMMLinearKernel
20
from vllm.model_executor.layers.attention import Attention
bnellnm's avatar
bnellnm committed
21
from vllm.model_executor.layers.fused_moe import (
22
23
24
25
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
26
from vllm.model_executor.layers.fused_moe.config import (
27
28
29
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
30
31
32
33
34
35
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)
36
37
38
39
40
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
41
from vllm.model_executor.layers.quantization import QuantizationMethods
42
from vllm.model_executor.layers.quantization.base_config import (
43
44
45
    QuantizationConfig,
    QuantizeMethodBase,
)
46
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
47
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
48
49
50
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
51
    process_fp8_input_tensor_strategy_moe,
52
    process_fp8_weight_tensor_strategy,
53
    process_fp8_weight_tensor_strategy_moe,
54
55
    validate_fp8_block_shape,
)
56
57
58
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
59
from vllm.model_executor.layers.quantization.utils.quant_utils import (
60
    GroupShape,
61
    create_fp8_quant_key,
62
    is_layer_skipped,
63
    kFp8Dynamic128Sym,
64
65
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
66
    kFp8Static128BlockSym,
67
    kFp8StaticTensorSym,
68
)
69
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
70
71
72
73
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    normalize_e4m3fn_to_e4m3fnuz,
)
74
75
76
from vllm.model_executor.model_loader.reload.layerwise import (
    initialize_online_processing,
)
77
78
79
80
81
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
82
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
83
from vllm.platforms import current_platform
84
85
86
from vllm.utils.deep_gemm import (
    is_deep_gemm_supported,
)
87

88
89
90
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

91
92
93
94
# Activation scaling schemes accepted by Fp8Config; any other value makes
# Fp8Config.__init__ raise ValueError.
ACTIVATION_SCHEMES = ["static", "dynamic"]

# Module-level logger.
logger = init_logger(__name__)

95

96
class Fp8Config(QuantizationConfig):
    """Quantization config for FP8.

    Records whether the checkpoint is already FP8-serialized, which
    activation scaling scheme to use, which layers to leave unquantized and
    (optionally) the 2-D block size used for block-wise weight quantization.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        """Validate and store the FP8 settings.

        Raises:
            ValueError: if ``activation_scheme`` is not in ACTIVATION_SCHEMES,
                or if ``weight_block_size`` is given for a non-serialized
                checkpoint, with a rank other than 2, or together with a
                non-dynamic activation scheme.
        """
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []

        if weight_block_size is not None:
            self._validate_block_config(
                weight_block_size, is_checkpoint_fp8_serialized, activation_scheme
            )
        self.weight_block_size = weight_block_size
        # None means "not decided here": Fp8LinearMethod then falls back to
        # is_deep_gemm_supported().
        self.use_deep_gemm: bool | None = None

    @staticmethod
    def _validate_block_config(
        block_size: list[int], serialized: bool, scheme: str
    ) -> None:
        """Raise ValueError when block-wise quantization cannot be used."""
        if not serialized:
            raise ValueError(
                "The block-wise quantization only supports fp8-serialized "
                "checkpoint for now."
            )
        num_dims = len(block_size)
        if num_dims != 2:
            raise ValueError(
                "The quantization block size of weight must have 2 "
                f"dimensions, but got {num_dims} dimensions"
            )
        if scheme != "dynamic":
            raise ValueError(
                "The block-wise quantization only supports "
                "dynamic activation scheme for now, but got "
                f"{scheme} activation scheme."
            )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite ignored layer names from HF naming to vLLM naming."""
        if self.ignored_layers is None:
            return
        self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an Fp8Config from a checkpoint's quantization config dict.

        ``quant_method`` and ``activation_scheme`` are required. When
        ``ignored_layers`` is missing or empty, ``modules_to_not_convert`` is
        used instead.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        ignored_layers = ignored_layers or cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
        return cls(
            is_checkpoint_fp8_serialized="fp8" in quant_method,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def _skips_layer(self, prefix: str):
        """Return whether ``prefix`` is in the ignored (unquantized) layers."""
        return is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Pick the quantization method for ``layer``.

        Linear and FusedMoE layers use the offline (FP8-serialized) or online
        (quantize while loading) method depending on the checkpoint. Skipped
        layers get the unquantized method. Attention layers get the FP8
        KV-cache method. Anything else returns None.
        """
        if isinstance(layer, LinearBase):
            if self._skips_layer(prefix):
                return UnquantizedLinearMethod()
            linear_cls = (
                Fp8LinearMethod
                if self.is_checkpoint_fp8_serialized
                else Fp8OnlineLinearMethod
            )
            linear_method = linear_cls(self)
            linear_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return linear_method

        if isinstance(layer, FusedMoE):
            if self._skips_layer(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            moe_cls = (
                Fp8MoEMethod
                if self.is_checkpoint_fp8_serialized
                else Fp8OnlineMoEMethod
            )
            return moe_cls(self, layer)

        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)

        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Map a compressed-tensors style k/v/q cache-scale (or attention prob
        scale) parameter name to the name vLLM expects.

        :param name: param name
        :return: the vLLM param name, or None if ``name`` is not a cache scale
        """
        if name.endswith(".output_scale"):
            # (projection marker, checkpoint suffix, vLLM suffix), checked in
            # this order.
            renames = (
                (".k_proj", ".k_proj.output_scale", ".attn.k_scale"),
                (".v_proj", ".v_proj.output_scale", ".attn.v_scale"),
                (".q_proj", ".q_proj.output_scale", ".attn.q_scale"),
            )
            for marker, old_suffix, new_suffix in renames:
                if marker in name:
                    return name.replace(old_suffix, new_suffix)
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        return None

226

227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
class CopyNumelCounter(TorchDispatchMode):
    """
    Dispatch mode that accumulates the element count of every tensor written
    by ``aten.copy_``. It measures how much of a weight has been loaded even
    when the loader writes into transformed views (e.g. ``narrow``) instead of
    the full tensor.
    """

    def __init__(self):
        super().__init__()
        # Running total of elements written via copy_ while the mode is active.
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        result = func(*args, **(kwargs or {}))
        is_copy = func == torch.ops.aten.copy_.default
        if is_copy:
            # args[0] is the tensor being written to (``self`` of copy_).
            self.copied_numel += args[0].numel()
        return result


247
248
249
250
251
252
253
254
255
256
def _copy_missing_attrs(old: torch.Tensor, new: torch.Tensor) -> None:
    """Attach to ``new`` every attribute listed by ``dir(old)`` but not ``dir(new)``.

    Values are read from ``old`` and applied with ``set_weight_attrs``.
    """
    present_on_new = set(dir(new))
    missing = {
        attr_name: getattr(old, attr_name)
        for attr_name in dir(old)
        if attr_name not in present_on_new
    }
    set_weight_attrs(new, missing)


257
258
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Limitations:
    1. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()
        self.input_dtype = get_current_vllm_config().model_config.dtype

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.marlin_input_dtype = None
        self.use_marlin = False

        # An explicit config value wins; otherwise probe for DeepGEMM support.
        if self.quant_config.use_deep_gemm is not None:
            self.use_deep_gemm = self.quant_config.use_deep_gemm
        else:
            self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"

        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None

            self.activation_quant_key = create_fp8_quant_key(
                static=self.act_q_static,
                group_shape=GroupShape(1, self.weight_block_size[0]),
            )
            self.weight_quant_key = create_fp8_quant_key(
                static=True, group_shape=GroupShape(*self.weight_block_size)
            )
        else:
            self.weight_quant_key = kFp8StaticTensorSym
            # Use per-token quantization for better perf if dynamic and cutlass
            if self.act_q_static:
                self.activation_quant_key = kFp8StaticTensorSym
            elif cutlass_fp8_supported():
                self.activation_quant_key = kFp8DynamicTokenSym
            else:
                self.activation_quant_key = kFp8DynamicTensorSym

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 weight, its scale and (if static) the input scale.

        Also selects the FP8 linear kernel for this layer's weight shape and
        records whether that kernel is the Marlin fallback.
        """
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        weight = create_fp8_weight_parameter(
            output_size_per_partition, input_size_per_partition, weight_loader
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        if not self.block_quant:
            scale = create_fp8_scale_parameter(
                PerTensorScaleParameter,
                output_partition_sizes,
                input_size_per_partition,
                None,
                weight_loader,
            )
            layer.register_parameter("weight_scale", scale)
        else:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            scale = create_fp8_scale_parameter(
                BlockQuantScaleParameter,
                output_partition_sizes,
                input_size_per_partition,
                self.weight_block_size,
                weight_loader,
            )
            # The weight_scale_inv name is intentional for deepseekv3
            layer.register_parameter("weight_scale_inv", scale)

        # INPUT ACTIVATION SCALE
        if self.act_q_static:
            scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
            set_weight_attrs(scale, {"scale_type": "input_scale"})
            layer.register_parameter("input_scale", scale)

        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=self.activation_quant_key,
            weight_quant_key=self.weight_quant_key,
            weight_shape=layer.weight.shape,
            input_dtype=self.input_dtype,
            out_dtype=self.out_dtype,
            module_name=self.__class__.__name__,
        )

        self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded checkpoint tensors into the kernel's runtime format.

        Marlin and block-quant layers delegate to the selected kernel. The
        per-tensor path requantizes fused logical shards to a single scale,
        stores the weight transposed and, for static activations, keeps the
        max input scale.
        """
        if self.use_marlin:
            # Only Marlin kernels support `marlin_input_dtype`; guard to avoid
            # AttributeError if backend selection changes.
            if hasattr(self.fp8_linear, "marlin_input_dtype"):
                self.fp8_linear.marlin_input_dtype = self.marlin_input_dtype
            self.fp8_linear.process_weights_after_loading(layer)
            return

        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            self.fp8_linear.process_weights_after_loading(layer)
        else:
            # If checkpoint is fp8 per-tensor, handle that there are N scales for N
            # shards in a fused module
            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
                weight,
                weight_scale,
                layer.logical_widths,
                getattr(layer, "input_scale", None),
            )
            if self.act_q_static:
                assert input_scale is not None
                input_scale = input_scale.max()
            # Stored transposed: [K (input), N (output)].
            weight = weight.t()

            # Update layer with new values.
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale", weight_scale.data)

        if input_scale is not None:
            replace_parameter(layer, "input_scale", input_scale)
        else:
            layer.input_scale = None

    @staticmethod
    def _dequantize_weight_bf16(layer: torch.nn.Module) -> torch.Tensor:
        """Return the non-block FP8 ``layer.weight`` dequantized to bfloat16.

        After ``process_weights_after_loading`` the weight is stored as
        ``weight.t()``, i.e. ``[K, N]`` (input x output features). Weight scales
        belong to output columns, so any non-scalar scale broadcasts along the
        last (N) dimension.
        """
        weight = layer.weight.to(torch.bfloat16)
        scale = layer.weight_scale.to(torch.bfloat16)
        if scale.numel() == 1:
            # Per-tensor: simple scalar multiplication
            return weight * scale

        out_features = weight.shape[-1]
        flat_scale = scale.reshape(-1)
        if flat_scale.numel() == out_features:
            # One scale per output column (also covers [N, 1] scales).
            return weight * flat_scale

        logical_widths = getattr(layer, "logical_widths", None)
        if (
            logical_widths
            and flat_scale.numel() == len(logical_widths)
            and sum(logical_widths) == out_features
        ):
            # One scale per logical shard of a fused module (e.g. QKV): expand
            # each shard's scale across that shard's output columns.
            widths = torch.tensor(logical_widths, device=flat_scale.device)
            return weight * torch.repeat_interleave(flat_scale, widths)

        # Unknown layout: fall back to plain broadcasting.
        return weight * scale

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ W^T (+ bias)`` with the FP8 weight of ``layer``."""
        # In batch-invariant mode, non-block layers dequantize to BF16 and run
        # a plain GEMM. Block-quant layers keep using the selected FP8 kernel.
        if envs.VLLM_BATCH_INVARIANT and not self.block_quant:
            weight_bf16 = self._dequantize_weight_bf16(layer)
            return torch.nn.functional.linear(x, weight_bf16.t(), bias)

        # Marlin and every other backend are driven through the same kernel
        # interface.
        return self.fp8_linear.apply_weights(layer, x, bias)
475
476


477
478
# TODO(future PR): remove this class in favor of
# online/fp8.py::Fp8PerTensorOnlineLinearMethod
class Fp8OnlineLinearMethod(Fp8LinearMethod):
    """Online variant of Fp8LinearMethod.

    Registers a full-precision weight placeholder on the meta device and
    quantizes it to FP8 in ``process_weights_after_loading``.
    """

    uses_meta_device: bool = True

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the full-precision weight and, for this exact class, the kernel."""
        total_out = sum(output_partition_sizes)
        loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_out
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # Meta-device placeholder, materialized and processed during loading.
        placeholder = torch.empty(
            total_out,
            input_size_per_partition,
            device="meta",
            dtype=params_dtype,
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=placeholder,
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        initialize_online_processing(layer)

        # TODO: remove this check once the following RFC is resolved.
        # https://github.com/vllm-project/vllm/issues/33314
        # Subclasses (e.g. Mxfp8OnlineLinearMethod) only need the weight
        # registration above and manage their own kernel.
        if type(self) is Fp8OnlineLinearMethod:
            self.fp8_linear = init_fp8_linear_kernel(
                activation_quant_key=self.activation_quant_key,
                weight_quant_key=self.weight_quant_key,
                weight_shape=layer.weight.shape,
                input_dtype=self.input_dtype,
                out_dtype=self.out_dtype,
                module_name=self.__class__.__name__,
            )
            self.use_marlin = isinstance(
                self.fp8_linear, MarlinFP8ScaledMMLinearKernel
            )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded weight to FP8 exactly once per layer."""
        # A flag on the layer prevents duplicate processing (e.g. on reload).
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # TODO(future): support block_quant in online quant path
        assert not self.block_quant

        layer.input_scale = None
        fp8_weight, fp8_scale = ops.scaled_fp8_quant(layer.weight, scale=None)

        replace_parameter(layer, "weight", fp8_weight.data)
        replace_parameter(layer, "weight_scale", fp8_scale.data)

        if not self.use_marlin:
            # Non-Marlin path stores the weight transposed, like the
            # offline Fp8LinearMethod.
            replace_parameter(layer, "weight", fp8_weight.t().data)
        else:
            # Only Marlin kernels support `marlin_input_dtype`; guard to avoid
            # AttributeError if backend selection changes.
            if hasattr(self.fp8_linear, "marlin_input_dtype"):
                self.fp8_linear.marlin_input_dtype = self.marlin_input_dtype
            self.fp8_linear.process_weights_after_loading(layer)

        layer._already_called_process_weights_after_loading = True

563

564
565
566
567
568
569
570
571
572
573
574
575
576
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

577
578
    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        """Store the FP8 settings and pick an MoE backend for this layer."""
        super().__init__(layer.moe_config)
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        # Suffix of the registered scale parameters (w13_<name>, w2_<name>).
        self.weight_scale_name = (
            "weight_scale_inv" if self.block_quant else "weight_scale"
        )

        # Quant keys describing weights/activations; used to select a
        # compatible kernel backend.
        static_acts = self.quant_config.activation_scheme == "static"
        if self.block_quant:
            weight_key, activation_key = kFp8Static128BlockSym, kFp8Dynamic128Sym
        elif static_acts:
            weight_key, activation_key = kFp8StaticTensorSym, kFp8StaticTensorSym
        else:
            weight_key, activation_key = kFp8StaticTensorSym, kFp8DynamicTensorSym

        self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
            config=self.moe,
            weight_key=weight_key,
            activation_key=activation_key,
            allow_vllm_cutlass=False,
        )

606
607
608
609
610
611
612
613
614
    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register FP8 expert weights, optional biases and scales on ``layer``.

        Only FP8-serialized checkpoints are accepted (asserted). Weights are
        allocated as float8_e4m3fn. Biases keep the original ``params_dtype``.

        Raises:
            ValueError: for block quant, if the intermediate size is not
                divisible by block_n, or (with tensor parallelism) by block_k.
        """
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        assert self.quant_config.is_checkpoint_fp8_serialized
        weight_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = self.weight_block_size[0], self.weight_block_size[1]
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            # Required by row parallel
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        def _expert_weight(*shape: int) -> torch.nn.Parameter:
            # Uninitialized per-expert FP8 tensor; filled by the weight loader.
            return torch.nn.Parameter(
                torch.empty(num_experts, *shape, dtype=weight_dtype),
                requires_grad=False,
            )

        # WEIGHTS
        w13_weight = _expert_weight(2 * intermediate_size_per_partition, hidden_size)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = _expert_weight(hidden_size, intermediate_size_per_partition)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # BIASES (for models like GPT-OSS that have biased MoE)
        if self.moe.has_bias:
            bias_specs = (
                ("w13_bias", (num_experts, 2 * intermediate_size_per_partition)),
                ("w2_bias", (num_experts, hidden_size)),
            )
            for bias_name, bias_shape in bias_specs:
                bias = torch.nn.Parameter(
                    torch.zeros(*bias_shape, dtype=layer.orig_dtype),
                    requires_grad=False,
                )
                layer.register_parameter(bias_name, bias)
                set_weight_attrs(bias, extra_weight_attrs)

        # WEIGHT_SCALES
        if self.block_quant:

            def _ceil_div(numer: int, denom: int) -> int:
                return (numer + denom - 1) // denom

            # For block quant, the scales are per block (typically 128x128).
            w13_scale_data = torch.ones(
                num_experts,
                2 * _ceil_div(intermediate_size_per_partition, block_n),
                _ceil_div(hidden_size, block_k),
                dtype=torch.float32,
            )
            w2_scale_data = torch.ones(
                num_experts,
                _ceil_div(hidden_size, block_n),
                _ceil_div(intermediate_size_per_partition, block_k),
                dtype=torch.float32,
            )
        else:
            # For per-tensor quant, the scales are per expert and weight.
            w13_scale_data = torch.ones(num_experts, 2, dtype=torch.float32)
            w2_scale_data = torch.ones(num_experts, dtype=torch.float32)

        w13_weight_scale = torch.nn.Parameter(w13_scale_data, requires_grad=False)
        w2_weight_scale = torch.nn.Parameter(w2_scale_data, requires_grad=False)
        # Note: name is weight_scale for tensor, weight_scale_inv for block.
        layer.register_parameter(f"w13_{self.weight_scale_name}", w13_weight_scale)
        layer.register_parameter(f"w2_{self.weight_scale_name}", w2_weight_scale)

        # Tag the scales with their quantization granularity so the loader
        # handles them correctly.
        scale_kind = (
            FusedMoeWeightScaleSupported.BLOCK
            if self.block_quant
            else FusedMoeWeightScaleSupported.TENSOR
        )
        extra_weight_attrs.update({"quant_method": scale_kind.value})
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme != "static":
            layer.w13_input_scale = None
            layer.w2_input_scale = None
            return

        assert not self.block_quant
        for input_scale_name in ("w13_input_scale", "w2_input_scale"):
            input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter(input_scale_name, input_scale)
            set_weight_attrs(input_scale, extra_weight_attrs)
745

746
    def _setup_kernel(
        self,
        layer: FusedMoE,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor | None,
        w2_input_scale: torch.Tensor | None,
    ) -> None:
        """Convert processed FP8 MoE weights to runtime format and build the kernel.

        The weights/scales are passed through
        ``convert_to_fp8_moe_kernel_format`` for ``self.fp8_backend``, written
        back onto ``layer`` with ``replace_parameter``, and then (if a quant
        config can be built) ``self.moe_kernel`` is created.

        Args:
            layer: The FusedMoE layer whose weight/scale parameters are replaced.
            w13: Fused w1/w3 expert weights (stored as ``w13_weight``).
            w2: w2 expert weights (stored as ``w2_weight``).
            w13_scale: Weight scale(s) for ``w13``.
            w2_scale: Weight scale(s) for ``w2``.
            w13_input_scale: Activation scale for ``w13``, or None.
            w2_input_scale: Activation scale for ``w2``, or None.
        """
        # Shuffle weights to runtime format.
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
        replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)

        # The quant config is read from the (just replaced) layer parameters;
        # the kernel is only constructed when a config is available.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.moe_kernel = make_fp8_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
                experts_cls=self.experts_cls,
                routing_tables=layer._maybe_init_expert_routing_tables(),
                shared_experts=layer.shared_experts,
            )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Post-process loaded FP8 checkpoint weights and set up the MoE kernel.

        Steps, in order:
          1. On FNUZ platforms, convert weights/scales/input scales from
             e4m3fn to e4m3fnuz.
          2. For static activation scaling, collapse the input scales (max)
             and write them back onto ``layer``.
          3. For non-block quantization, requantize w13 so it has a single
             weight scale per expert.
          4. Convert to runtime format and build the kernel (``_setup_kernel``).
        """
        # Allow for accessing weights and scales in standard way.
        w13 = layer.w13_weight
        w2 = layer.w2_weight
        w13_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        w13_input_scale = layer.w13_input_scale
        w2_input_scale = layer.w2_input_scale

        # MI300x and MI325x use FNUZ format for FP8. Convert if needed.
        if current_platform.is_fp8_fnuz():
            w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w13,
                w13_scale,
                w13_input_scale,
            )
            w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w2,
                w2_scale,
                w2_input_scale,
            )

        # Per tensor kernels require single activation scale. Use the max.
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            assert w13_input_scale is not None and w2_input_scale is not None
            w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
                w13_input_scale, w2_input_scale
            )
            replace_parameter(layer, "w13_input_scale", w13_input_scale)
            replace_parameter(layer, "w2_input_scale", w2_input_scale)

        # Per tensor kernels require single weight scale for w13 per expert, but
        # on disk there is a scale for w1 and w3. Use the max to requantize.
        if not self.block_quant:
            shard_size = layer.intermediate_size_per_partition
            w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
                w13, w13_scale, shard_size, layer.local_num_experts
            )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
        )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
        """Not supported by this method; always raises.

        This method builds its kernel through the newer modular-kernel
        initialization path (see ``_setup_kernel``), so this hook must not be
        invoked.

        Raises:
            ValueError: Always.
        """
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 MoE quant config from the layer's scale parameters.

        Weight scales are read from ``w13_<weight_scale_name>`` and
        ``w2_<weight_scale_name>``; activation scales from ``w13_input_scale``
        and ``w2_input_scale`` (which may be None). When the MoE has biases,
        ``w13_bias`` / ``w2_bias`` are attached to the returned config.

        Returns:
            The config produced by ``make_fp8_moe_quant_config``, or None when
            that helper returns None (callers truth-test the result).
        """
        w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        quant_config = make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=self.weight_block_size,
        )

        # Inject biases into the quant config if the model has them
        # (e.g. GPT-OSS biased MoE)
        if quant_config is not None and self.moe.has_bias:
            w13_bias = getattr(layer, "w13_bias", None)
            w2_bias = getattr(layer, "w2_bias", None)
            if w13_bias is not None:
                quant_config._w1.bias = w13_bias
            if w2_bias is not None:
                quant_config._w2.bias = w2_bias

        return quant_config

    @property
    def supports_eplb(self) -> bool:
        """Whether this MoE method supports EPLB; always True for FP8."""
        return True

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor:
        """Run the MoE through ``moe_kernel.apply_monolithic`` from router logits.

        Unlike ``apply``, this hands the raw ``router_logits`` to the kernel,
        together with the routing options stored on ``layer`` (expert groups,
        score-correction bias, routed scaling factor), instead of precomputed
        top-k weights/ids. Only valid when ``self.is_monolithic`` is True.
        """
        assert self.is_monolithic
        assert self.moe_kernel is not None
        return self.moe_kernel.apply_monolithic(
            x,
            layer.w13_weight,
            layer.w2_weight,
            router_logits,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            e_score_correction_bias=layer.e_score_correction_bias,
            routed_scaling_factor=layer.routed_scaling_factor,
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor:
        """Run the MoE through ``moe_kernel.apply`` with precomputed top-k routing.

        Only valid when ``self.is_monolithic`` is False. ``shared_experts_input``
        is forwarded to the kernel unchanged.
        """
        assert not self.is_monolithic
        assert self.moe_kernel is not None
        return self.moe_kernel.apply(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            shared_experts_input=shared_experts_input,
        )


# TODO(future PR): remove this class in favor of
# online/fp8.py::Fp8PerTensorOnlineMoEMethod
class Fp8OnlineMoEMethod(Fp8MoEMethod):
    """MoE method for online FP8 quantization.

    Supports loading unquantized FP16/BF16 model checkpoints with dynamic
    activation scaling. Weights are registered on the ``meta`` device in the
    original dtype and quantized to FP8 (one weight scale per expert for each
    of w13 and w2) in ``process_weights_after_loading``.

    Args:
        quant_config: The quantization config.
        layer: The layer this method is attached to (forwarded to
            ``Fp8MoEMethod.__init__``).
    """

    uses_meta_device: bool = True

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(quant_config, layer)
        # Online quantization only: the checkpoint must not already be FP8,
        # activations use dynamic scaling, and no weight block size is set.
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``params_dtype`` expert weights (and optional biases) on ``meta``.

        The tensors are placeholders that are materialized and processed during
        loading; ``initialize_online_processing(layer)`` is called last to
        prepare the layer for that. FP8 quantization itself happens in
        ``process_weights_after_loading``.
        """
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                device="meta",  # materialized and processed during loading
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                device="meta",  # materialized and processed during loading
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # BIASES (for models like GPT-OSS that have biased MoE)
        if self.moe.has_bias:
            w13_bias = torch.nn.Parameter(
                torch.zeros(
                    num_experts,
                    2 * intermediate_size_per_partition,
                    device="meta",  # materialized and processed during loading
                    dtype=layer.orig_dtype,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_bias", w13_bias)
            set_weight_attrs(w13_bias, extra_weight_attrs)

            w2_bias = torch.nn.Parameter(
                torch.zeros(
                    num_experts,
                    hidden_size,
                    device="meta",  # materialized and processed during loading
                    dtype=layer.orig_dtype,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w2_bias", w2_bias)
            set_weight_attrs(w2_bias, extra_weight_attrs)

        initialize_online_processing(layer)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded weights to FP8 per expert and set up the kernel.

        Each expert's ``w13_weight`` / ``w2_weight`` slice is quantized with
        ``ops.scaled_fp8_quant``, filling one scale per expert. Input scales
        are set to None (activation scaling is dynamic). The method is a no-op
        once ``layer._already_called_process_weights_after_loading`` is set.
        """
        # TODO(@ksayers): inplace fp8 quant kernel, initialize scales with ones
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        fp8_dtype = current_platform.fp8_dtype()
        w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
        w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
        w13_scale = torch.ones(
            layer.num_experts, device=w13.device, dtype=torch.float32
        )
        w2_scale = torch.ones(layer.num_experts, device=w2.device, dtype=torch.float32)
        layer.w13_input_scale = None
        layer.w2_input_scale = None

        # Quantize each local expert independently (one scale per expert).
        for expert in range(layer.local_num_experts):
            w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
                layer.w13_weight[expert, :, :]
            )
            w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant(
                layer.w2_weight[expert, :, :]
            )

        # On XPU, the last two dims of the quantized weights are transposed
        # before kernel setup.
        if current_platform.is_xpu():
            w13.data = w13.transpose(-1, -2).contiguous()
            w2.data = w2.transpose(-1, -2).contiguous()

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer,
            w13,
            w2,
            w13_scale,
            w2_scale,
            w13_input_scale=layer.w13_input_scale,
            w2_input_scale=layer.w2_input_scale,
        )

        # Prevent duplicate processing (e.g., during weight reload)
        layer._already_called_process_weights_after_loading = True


class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1052
1053
1054
    """

    def __init__(self, quant_config: Fp8Config):
1055
        super().__init__(quant_config)