fp8.py 44.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Any, Optional
5
6
7

import torch
from torch.nn import Module
8
from torch.utils._python_dispatch import TorchDispatchMode
9

10
import vllm.envs as envs
11
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
12
from vllm import _custom_ops as ops
13
from vllm._aiter_ops import rocm_aiter_ops
14
from vllm.attention.layer import Attention
15
from vllm.distributed import get_tensor_model_parallel_world_size
16
from vllm.logger import init_logger
17
from vllm.model_executor.layers.batch_invariant import (
18
    vllm_is_batch_invariant,
19
)
bnellnm's avatar
bnellnm committed
20
from vllm.model_executor.layers.fused_moe import (
21
22
23
24
25
26
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported,
)
27
from vllm.model_executor.layers.fused_moe.config import (
28
29
30
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
31
32
33
34
35
36
37
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    Fp8MoeBackend,
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)
38
39
40
41
42
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
43
from vllm.model_executor.layers.quantization import QuantizationMethods
44
from vllm.model_executor.layers.quantization.base_config import (
45
46
47
    QuantizationConfig,
    QuantizeMethodBase,
)
48
49
50
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
    init_fp8_linear_kernel,
)
51
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
52
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
53
    apply_fi_trtllm_fp8_per_tensor_moe,
54
)
55
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
56
57
58
59
60
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
    maybe_post_process_fp8_weight_block,
61
    process_fp8_input_tensor_strategy_moe,
62
63
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
64
    process_fp8_weight_tensor_strategy_moe,
65
66
    validate_fp8_block_shape,
)
67
68
69
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
70
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
71
72
73
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
)
74
from vllm.model_executor.layers.quantization.utils.quant_utils import (
75
76
    GroupShape,
    is_layer_skipped,
77
    kFp8Dynamic128Sym,
78
79
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
80
    kFp8Static128BlockSym,
81
    kFp8StaticTensorSym,
82
)
83
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
84
85
86
87
88
89
90
91
92
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    normalize_e4m3fn_to_e4m3fnuz,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
93
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
94
from vllm.platforms import current_platform
95
96
97
from vllm.utils.deep_gemm import (
    is_deep_gemm_supported,
)
98

99
100
101
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

102
103
104
105
logger = init_logger(__name__)

# Activation quantization schemes accepted by ``Fp8Config``; any other value
# is rejected in ``Fp8Config.__init__``.
ACTIVATION_SCHEMES = [
    "static",
    "dynamic",
]

106

107
class Fp8Config(QuantizationConfig):
    """Quantization config for FP8.

    Records whether the checkpoint already stores FP8 weights, which
    activation scheme to use, which layers must stay unquantized, and the
    optional 2-D block size used for block-wise weight quantization.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        # A missing/empty list is normalized to [] so it can always be iterated.
        self.ignored_layers = ignored_layers or []

        if weight_block_size is not None:
            self._check_block_quant_args(
                weight_block_size,
                is_checkpoint_fp8_serialized,
                activation_scheme,
            )
        self.weight_block_size = weight_block_size

    @staticmethod
    def _check_block_quant_args(
        weight_block_size: list[int],
        is_checkpoint_fp8_serialized: bool,
        activation_scheme: str,
    ) -> None:
        """Raise ``ValueError`` if block-wise quantization is misconfigured."""
        if not is_checkpoint_fp8_serialized:
            raise ValueError(
                "The block-wise quantization only supports fp8-serialized "
                "checkpoint for now."
            )
        num_dims = len(weight_block_size)
        if num_dims != 2:
            raise ValueError(
                "The quantization block size of weight must have 2 "
                f"dimensions, but got {num_dims} dimensions"
            )
        if activation_scheme != "dynamic":
            raise ValueError(
                "The block-wise quantization only supports "
                "dynamic activation scheme for now, but got "
                f"{activation_scheme} activation scheme."
            )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required."""
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """No extra config files are needed for FP8."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Translate ``ignored_layers`` from HF names to vLLM module names."""
        if self.ignored_layers is None:
            return
        self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a checkpoint's quantization dict."""
        quant_method = cls.get_from_keys(config, ["quant_method"])
        serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        # Fall back to the alternative key when no ignored layers were given.
        ignored_layers = ignored_layers or cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
        return cls(
            is_checkpoint_fp8_serialized=serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def _is_skipped(self, prefix: str) -> bool:
        """Whether the layer at ``prefix`` is listed as unquantized."""
        return is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the XPU-specific quantize method for ``layer`` (or None)."""
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        # The XPU methods are handed a freshly constructed copy of this config.
        xpu_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if self._is_skipped(prefix):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(xpu_config)
        if isinstance(layer, FusedMoE):
            if self._is_skipped(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return XPUFp8MoEMethod(xpu_config, layer)
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer`` (or None if unsupported)."""
        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)

        serialized = self.is_checkpoint_fp8_serialized
        if isinstance(layer, LinearBase):
            if self._is_skipped(prefix):
                return UnquantizedLinearMethod()
            # FP8 checkpoints load directly; others are quantized online.
            linear_cls = Fp8LinearMethod if serialized else Fp8OnlineLinearMethod
            linear_method = linear_cls(self)
            linear_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return linear_method
        if isinstance(layer, FusedMoE):
            if self._is_skipped(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            moe_cls = Fp8MoEMethod if serialized else Fp8OnlineMoEMethod
            return moe_cls(self, layer)
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Map a checkpoint k/v/q output-scale or prob-scale param name to the
        attention-scale param name vLLM expects.

        :param name: param name
        :return: matching vLLM param name, or None if ``name`` is not a
            recognised cache-scale name
        """
        if name.endswith(".output_scale"):
            # Checked in order: k, then v, then q.
            for proj, target in (
                (".k_proj", ".attn.k_scale"),
                (".v_proj", ".attn.v_scale"),
                (".q_proj", ".attn.q_scale"),
            ):
                if proj in name:
                    return name.replace(f"{proj}.output_scale", target)
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        return None

274

275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
class CopyNumelCounter(TorchDispatchMode):
    """Dispatch mode that counts the elements written by ``aten.copy_``.

    While the mode is active, every ``copy_`` adds the element count of its
    destination tensor (first positional argument) to ``copied_numel``.
    Observing the final ``copy_`` keeps the count accurate even when the
    weight loader reshapes or slices (e.g. ``narrow``) before copying.
    """

    def __init__(self):
        super().__init__()
        # Running total of destination elements written by ``copy_``.
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        result = func(*args, **(kwargs or {}))
        is_copy = func == torch.ops.aten.copy_.default
        if is_copy:
            destination = args[0]
            self.copied_numel += destination.numel()
        return result


295
296
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Limitations:
    1. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.marlin_input_dtype = None
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        # Marlin is never used in batch-invariant mode.
        if vllm_is_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"

        if self.block_quant:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if self.act_q_static:
                activation_quant_key = kFp8StaticTensorSym
            elif cutlass_fp8_supported():
                activation_quant_key = kFp8DynamicTokenSym
            else:
                activation_quant_key = kFp8DynamicTensorSym

            self.fp8_linear = init_fp8_linear_kernel(
                activation_quant_key=activation_quant_key,
                weight_quant_key=kFp8StaticTensorSym,
                out_dtype=self.out_dtype,
                module_name=self.__class__.__name__,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 weight, its scale and (if static) input scale.

        Block-quantized layers register their scale as ``weight_scale_inv``;
        per-tensor layers register ``weight_scale``.
        """
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        weight = create_fp8_weight_parameter(
            output_size_per_partition, input_size_per_partition, weight_loader
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        if not self.block_quant:
            scale = create_fp8_scale_parameter(
                PerTensorScaleParameter,
                output_partition_sizes,
                input_size_per_partition,
                None,
                weight_loader,
            )
            set_weight_attrs(scale, {"scale_type": "weight_scale"})
            layer.register_parameter("weight_scale", scale)
        else:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            scale = create_fp8_scale_parameter(
                BlockQuantScaleParameter,
                output_partition_sizes,
                input_size_per_partition,
                self.weight_block_size,
                weight_loader,
            )
            set_weight_attrs(scale, {"scale_type": "weight_scale"})
            # The weight_scale_inv name is intentional for deepseekv3
            layer.register_parameter("weight_scale_inv", scale)

        # INPUT ACTIVATION SCALE
        if self.act_q_static:
            scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
            set_weight_attrs(scale, {"scale_type": "input_scale"})
            layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded FP8 checkpoint tensors into kernel format.

        Block-quantized weights go through ``process_fp8_weight_block_strategy``.
        Per-tensor weights are requantized to a single scale (unless Marlin
        is used) and stored transposed. Marlin layers are then repacked and
        their ``input_scale`` is removed.
        """
        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale_inv = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )

            # Update layer with new values
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
        else:
            # Per-tensor FP8 checkpoint: a fused module carries N scales for
            # its N logical shards. (Online quantization of fp16/bf16
            # checkpoints is handled by Fp8OnlineLinearMethod.)
            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
                    weight,
                    weight_scale,
                    layer.logical_widths,
                    getattr(layer, "input_scale", None),
                )
                if self.act_q_static:
                    assert input_scale is not None
                    input_scale = input_scale.max()
            weight = weight.t()

            # Update layer with new values.
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale", weight_scale.data)

        if input_scale is not None:
            replace_parameter(layer, "input_scale", input_scale)
        else:
            layer.input_scale = None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer)

    @staticmethod
    def _dequantize_weight_for_batch_invariant(
        layer: torch.nn.Module, dtype: torch.dtype
    ) -> torch.Tensor:
        """Dequantize a non-block FP8 weight to ``dtype``.

        After ``process_weights_after_loading`` the non-block weight is
        stored transposed, i.e. ``[K, N]`` (input x output features), so
        per-output-channel scales broadcast along the last dimension.

        Returns:
            The dequantized ``[K, N]`` weight in ``dtype``.
        """
        weight = layer.weight.to(dtype)
        scale = layer.weight_scale.to(dtype)
        if scale.numel() == 1:
            # Per-tensor: simple scalar multiplication
            return weight * scale
        if scale.dim() == 1:
            num_out = weight.shape[1]
            if scale.shape[0] == num_out:
                # One scale per output channel (a column of [K, N]).
                return weight * scale.unsqueeze(0)
            logical_widths = getattr(layer, "logical_widths", None)
            if (
                logical_widths
                and scale.shape[0] == len(logical_widths)
                and sum(logical_widths) == num_out
            ):
                # One scale per fused shard (e.g. QKV): expand each shard's
                # scale over that shard's output columns.
                widths = torch.tensor(
                    logical_widths, dtype=torch.long, device=scale.device
                )
                return weight * torch.repeat_interleave(scale, widths).unsqueeze(0)
            if scale.shape[0] == weight.shape[0]:
                # NOTE(review): legacy per-row (input-dim) scaling, kept for
                # backward compatibility; confirm whether any layout needs it.
                return weight * scale.unsqueeze(1)
        # Fallback: rely on standard broadcasting.
        return weight * scale

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear layer on ``x``.

        Dispatch order: batch-invariant path, Marlin (weight-only FP8),
        block-wise FP8, then the per-tensor FP8 kernel.
        """
        # If batch invariant mode is enabled, prefer the block FP8 path
        # (DeepGEMM where supported). Otherwise dequantize the weight to the
        # activation dtype so F.linear sees matching dtypes (bf16 or fp16).
        if vllm_is_batch_invariant():
            if self.block_quant:
                assert self.weight_block_size is not None
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale_inv,
                    input_scale=layer.input_scale,
                    bias=bias,
                )
            weight_dq = self._dequantize_weight_for_batch_invariant(layer, x.dtype)
            return torch.nn.functional.linear(x, weight_dq.t(), bias)

        if self.use_marlin:
            if self.block_quant:
                weight_scale = layer.weight_scale_inv
            else:
                weight_scale = layer.weight_scale

            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                input_dtype=self.marlin_input_dtype,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply_weights(layer, x, bias)
552
553


554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
class Fp8OnlineLinearMethod(Fp8LinearMethod):
    """Online version of Fp8LinearMethod: loads an fp16/bf16 checkpoint and
    quantizes the weights to FP8 during loading.

    The weight loader is wrapped so that, as soon as every element of
    ``layer.weight`` has been written, the layer is quantized immediately
    rather than waiting for the regular post-loading hook.
    """

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register a high-precision ``weight`` with a counting loader."""
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # WEIGHT
        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # track how many elements we have updated
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

            # load the current weight chunk
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # if we have loaded all of the elements, call
            # process_weights_after_loading
            target_loaded_numel = layer.weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Delete the bookkeeping
                del layer._loaded_numel
                # Prevent the usual `process_weights_after_loading` call from doing
                # anything
                layer._already_called_process_weights_after_loading = True

            return res

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=patched_weight_loader,
        )
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize ``layer.weight`` to FP8 with a dynamically computed scale.

        No-op if the patched loader already processed this layer.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # TODO(future): support block_quant in online quant path
        assert not self.block_quant

        layer.input_scale = None
        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
        # Stored transposed, matching Fp8LinearMethod's per-tensor layout.
        weight = qweight.t()

        # Update layer with new values.
        replace_parameter(layer, "weight", weight.data)
        replace_parameter(layer, "weight_scale", weight_scale.data)

        if self.use_marlin:
            size_k_first = True
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.input_scale


637
638
639
640
641
642
643
644
645
646
647
648
649
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8-serialized checkpoints.

    Supports loading FP8 checkpoints with either per-tensor or block-wise
    static weight scales, and dynamic or static (per-tensor only) activation
    scales. FP16/BF16 checkpoints that are quantized online are handled by
    the ``Fp8OnlineMoEMethod`` subclass instead; ``create_weights`` here
    asserts that the checkpoint is FP8-serialized.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer this method is attached to; its ``moe_config``
            is forwarded to the base class.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        # Scale parameter name: `weight_scale_inv` for block quant,
        # `weight_scale` for per-tensor quant.
        self.weight_scale_name = (
            "weight_scale_inv" if self.block_quant else "weight_scale"
        )

        # Set weight key and activation key for kernel compatibility
        if self.block_quant:
            weight_key = kFp8Static128BlockSym
            activation_key = kFp8Dynamic128Sym
        else:
            weight_key = kFp8StaticTensorSym
            activation_key = (
                kFp8StaticTensorSym
                if self.quant_config.activation_scheme == "static"
                else kFp8DynamicTensorSym
            )

        # Select Fp8 MoE backend
        self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
            config=self.moe,
            weight_key=weight_key,
            activation_key=activation_key,
            allow_vllm_cutlass=False,
        )

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register FP8 expert weights, weight scales and (for the static
        activation scheme) input scales on ``layer``.

        Args:
            layer: The MoE layer to register parameters on.
            num_experts: Number of experts to allocate weights for.
            hidden_size: Model hidden size.
            intermediate_size_per_partition: Expert intermediate size on this
                tensor-parallel rank.
            params_dtype: Original model dtype, recorded as
                ``layer.orig_dtype``; the weights themselves are allocated as
                ``torch.float8_e4m3fn``.
            **extra_weight_attrs: Attributes attached to every registered
                parameter via ``set_weight_attrs``.

        Raises:
            ValueError: With block quantization, if the intermediate size is
                not divisible by ``block_n`` (or by ``block_k`` when TP > 1).
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # Only FP8-serialized checkpoints are handled here.
        assert self.quant_config.is_checkpoint_fp8_serialized
        params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # For per-tensor quant, the scales are per expert and weight.
            # w13 gets two entries per expert: one for w1 and one for w3.
            w13_scale_data = torch.ones(num_experts, 2, dtype=torch.float32)
            w2_scale_data = torch.ones(num_experts, dtype=torch.float32)
        else:
            # For block quant, the scales are per block (typically 128x128).
            # Block counts are rounded up (ceil division).
            w13_scale_data = torch.ones(
                num_experts,
                2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                (hidden_size + block_k - 1) // block_k,
                dtype=torch.float32,
            )
            w2_scale_data = torch.ones(
                num_experts,
                (hidden_size + block_n - 1) // block_n,
                (intermediate_size_per_partition + block_k - 1) // block_k,
                dtype=torch.float32,
            )
        w13_weight_scale = torch.nn.Parameter(w13_scale_data, requires_grad=False)
        w2_weight_scale = torch.nn.Parameter(w2_scale_data, requires_grad=False)
        # Note: name is weight_scale for tensor, weight_scale_inv for block.
        layer.register_parameter(f"w13_{self.weight_scale_name}", w13_weight_scale)
        layer.register_parameter(f"w2_{self.weight_scale_name}", w2_weight_scale)

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            # Static activation scales are only supported with per-tensor quant.
            assert not self.block_quant
            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            # Dynamic activation quantization: no input scales.
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def _setup_kernel(
        self,
        layer: FusedMoE,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor | None,
        w2_input_scale: torch.Tensor | None,
    ) -> None:
        """Convert weights to the backend's runtime format, install them on
        ``layer`` and, when a modular kernel is used, build it.

        Args:
            layer: The MoE layer whose parameters are replaced.
            w13: FP8 fused gate/up expert weights.
            w2: FP8 down-projection expert weights.
            w13_scale: Weight scales for ``w13``.
            w2_scale: Weight scales for ``w2``.
            w13_input_scale: Static activation scale for ``w13``, or ``None``
                for dynamic activation quantization.
            w2_input_scale: Static activation scale for ``w2``, or ``None``.
        """
        # Shuffle weights to runtime format.
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
        replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)

        # Setup modular kernel for TP case and naive DP/EP case.
        # In non-naive DP/EP case, we will create a ModularKernelMethod.
        # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
        # in both cases.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        # get_fused_moe_quant_config returns None for the monolithic
        # FlashInfer TRTLLM backend, which needs no modular kernel.
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
                experts_cls=self.experts_cls,
                routing_tables=layer._maybe_init_expert_routing_tables(),
                shared_experts=layer.shared_experts,
            )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Post-process the loaded FP8 weights and set up the kernel.

        Converts to FNUZ on platforms that need it. It then collapses static
        input scales, and (per-tensor quant) the separate w1/w3 scales, into
        the single values the kernels expect. Finally it hands off to
        ``_setup_kernel``. This is a no-op if the step was already performed
        (``layer._already_called_process_weights_after_loading``).
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Allow for accessing weights and scales in standard way.
        w13 = layer.w13_weight
        w2 = layer.w2_weight
        w13_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        w13_input_scale = layer.w13_input_scale
        w2_input_scale = layer.w2_input_scale

        # MI300x and MI325x use FNUZ format for FP8. Convert if needed.
        if current_platform.is_fp8_fnuz():
            w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w13,
                w13_scale,
                w13_input_scale,
            )
            w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w2,
                w2_scale,
                w2_input_scale,
            )

        # Per tensor kernels require single activation scale. Use the max.
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            assert w13_input_scale is not None and w2_input_scale is not None
            w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
                w13_input_scale, w2_input_scale
            )
            replace_parameter(layer, "w13_input_scale", w13_input_scale)
            replace_parameter(layer, "w2_input_scale", w2_input_scale)

        # Per tensor kernels require single weight scale for w13 per expert, but
        # on disk there is a scale for w1 and w3. Use the max to requantize.
        if not self.block_quant:
            shard_size = layer.intermediate_size_per_partition
            w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
                w13, w13_scale, shard_size, layer.local_num_experts
            )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
        )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Not supported: kernel setup goes through ``_setup_kernel``.

        Raises:
            ValueError: Always.
        """
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Not supported: kernel setup goes through ``_setup_kernel``.

        Raises:
            ValueError: Always.
        """
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the quant config for the modular kernel from ``layer``'s
        scales, or return ``None`` for the FlashInfer TRTLLM backend."""
        # TRTLLM does not use Modular Kernel.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            return None

        w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        return make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=self.weight_block_size,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def allow_inplace(self) -> bool:
        return True

    @property
    def is_monolithic(self) -> bool:
        # Only the FlashInfer TRTLLM backend runs routing + experts in one op.
        return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the FlashInfer TRTLLM MoE, which takes raw ``router_logits``
        and performs routing itself.

        Raises:
            NotImplementedError: If EPLB is enabled on ``layer``.
        """
        assert self.is_monolithic
        assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM

        # TODO(rob): convert this to MK.
        if layer.enable_eplb:
            raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
        assert layer.activation == "silu", (
            f"Expected 'silu' activation but got {layer.activation}"
        )

        if self.block_quant:
            # Imported for its side effect; presumably this registers the
            # torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8 op used below.
            import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

            return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                routing_logits=router_logits,
                routing_bias=layer.e_score_correction_bias,
                x=x,
                w13_weight=layer.w13_weight,
                w13_weight_scale_inv=layer.w13_weight_scale_inv,
                w2_weight=layer.w2_weight,
                w2_weight_scale_inv=layer.w2_weight_scale_inv,
                global_num_experts=layer.global_num_experts,
                top_k=layer.top_k,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                intermediate_size=layer.intermediate_size_per_partition,
                expert_offset=layer.ep_rank * layer.local_num_experts,
                local_num_experts=layer.local_num_experts,
                block_shape=self.weight_block_size,
                routing_method_type=layer.routing_method_type,
                routed_scaling=layer.routed_scaling_factor,
            )
        else:
            return apply_fi_trtllm_fp8_per_tensor_moe(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=layer.e_score_correction_bias,
                global_num_experts=layer.global_num_experts,
                top_k=layer.top_k,
                num_expert_group=layer.num_expert_group,
                topk_group=layer.topk_group,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
            )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the modular kernel built in ``_setup_kernel`` on already
        computed top-k routing (``topk_weights`` / ``topk_ids``)."""
        assert self.moe_mk is not None
        assert not self.is_monolithic
        return self.moe_mk(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )


class Fp8OnlineMoEMethod(Fp8MoEMethod):
    """MoE method for online FP8 quantization.

    Loads FP16/BF16 expert weights from the checkpoint and quantizes them to
    FP8 as soon as the last weight chunk has been loaded, computing one scale
    per expert for each of w13 and w2. Activations are quantized dynamically.

    Args:
        quant_config: The quantization config. Must describe a non-FP8
            checkpoint with dynamic activation scaling and no block
            quantization.
        layer: The MoE layer this method is attached to.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(quant_config, layer)
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register high-precision expert weights plus per-expert scales.

        The checkpoint's ``weight_loader`` is wrapped so that, once every
        element of ``w13_weight`` and ``w2_weight`` has been copied in,
        ``process_weights_after_loading`` runs immediately (streaming
        quantization) and later calls to it become no-ops.

        Args:
            layer: The MoE layer to register parameters on.
            num_experts: Number of experts to allocate weights for.
            hidden_size: Model hidden size.
            intermediate_size_per_partition: Expert intermediate size on this
                tensor-parallel rank.
            params_dtype: Checkpoint dtype used to allocate the weights.
            **extra_weight_attrs: Attributes attached to every registered
                parameter; must contain ``weight_loader``.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # We are doing online quantization, patch the weight loader
        # to call `process_weights_after_loading` in a streaming fashion
        # as soon as the last weight chunk is loaded.
        weight_loader = extra_weight_attrs["weight_loader"]

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # add a counter to track how many elements we have updated
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

            # load the current weight chunk
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # if we have loaded all of the elements, call
            # process_weights_after_loading
            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Delete the bookkeeping
                del layer._loaded_numel
                # Prevent the usual `process_weights_after_loading` call
                # from doing anything
                layer._already_called_process_weights_after_loading = True

            return res

        # Build a new attrs mapping (a real copy, not an alias) so nothing
        # else holding the original dict observes the patched loader.
        extra_weight_attrs = {
            **extra_weight_attrs,
            "weight_loader": patched_weight_loader,
        }

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # One scale per expert for w13 and one per expert for w2. They are
        # filled in by `process_weights_after_loading` when the weights are
        # quantized.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # Dynamic activation quantization: no input scales.
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded high-precision expert weights to FP8 and set
        up the kernel. No-op if the patched weight loader already did this."""
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # The checkpoint is FP16/BF16: quantize each expert into newly
        # allocated FP8 tensors, writing each expert's scale into the
        # existing scale parameters.
        fp8_dtype = current_platform.fp8_dtype()
        w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
        w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale

        for expert in range(layer.local_num_experts):
            w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
                layer.w13_weight[expert, :, :]
            )
            w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant(
                layer.w2_weight[expert, :, :]
            )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer,
            w13,
            w2,
            w13_scale,
            w2_scale,
            layer.w13_input_scale,
            layer.w2_input_scale,
        )


class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1164
1165
1166
    """

    def __init__(self, quant_config: Fp8Config):
1167
        super().__init__(quant_config)