"tests/vscode:/vscode.git/clone" did not exist on "2ac85a4544cf9488037e61bf9ed7a87d0c3696bb"
fp8.py 45.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Any, Optional
5
6
7

import torch
from torch.nn import Module
8
from torch.utils._python_dispatch import TorchDispatchMode
9

10
import vllm.envs as envs
11
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
12
from vllm import _custom_ops as ops
13
from vllm._aiter_ops import rocm_aiter_ops
14
from vllm.attention.layer import Attention
15
from vllm.distributed import get_tensor_model_parallel_world_size
16
from vllm.logger import init_logger
17
from vllm.model_executor.layers.batch_invariant import (
18
    vllm_is_batch_invariant,
19
)
bnellnm's avatar
bnellnm committed
20
from vllm.model_executor.layers.fused_moe import (
21
22
23
24
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
25
    FusedMoERouter,
26
27
    FusedMoeWeightScaleSupported,
)
28
from vllm.model_executor.layers.fused_moe.config import (
29
    FusedMoEQuantConfig,
30
    RoutingMethodType,
31
32
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
33
34
35
36
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    Fp8MoeBackend,
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
37
    make_fp8_moe_kernel_for_mkm,
38
39
40
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)
41
42
43
44
45
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
46
from vllm.model_executor.layers.quantization import QuantizationMethods
47
from vllm.model_executor.layers.quantization.base_config import (
48
49
50
    QuantizationConfig,
    QuantizeMethodBase,
)
51
52
53
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
    init_fp8_linear_kernel,
)
54
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
55
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
56
    apply_fi_trtllm_fp8_per_tensor_moe,
57
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
58
)
59
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
60
61
62
63
64
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
    maybe_post_process_fp8_weight_block,
65
    process_fp8_input_tensor_strategy_moe,
66
67
    process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy,
68
    process_fp8_weight_tensor_strategy_moe,
69
70
    validate_fp8_block_shape,
)
71
72
73
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    get_marlin_input_dtype,
)
74
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
75
76
77
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
)
78
from vllm.model_executor.layers.quantization.utils.quant_utils import (
79
80
    GroupShape,
    is_layer_skipped,
81
    kFp8Dynamic128Sym,
82
83
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
84
    kFp8Static128BlockSym,
85
    kFp8StaticTensorSym,
86
)
87
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
88
89
90
91
92
93
94
95
96
    cutlass_block_fp8_supported,
    cutlass_fp8_supported,
    normalize_e4m3fn_to_e4m3fnuz,
)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
97
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
98
from vllm.platforms import current_platform
99
100
101
from vllm.utils.deep_gemm import (
    is_deep_gemm_supported,
)
102

103
104
105
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

106
107
108
109
# Activation-quantization schemes accepted by Fp8Config; any other value is
# rejected with a ValueError in Fp8Config.__init__.
ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)

110

111
class Fp8Config(QuantizationConfig):
    """Config class for FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
    ) -> None:
        """
        Args:
            is_checkpoint_fp8_serialized: True when the checkpoint already
                stores FP8 weights; False selects the online-quantization
                methods in `get_quant_method`.
            activation_scheme: One of ``ACTIVATION_SCHEMES``.
            ignored_layers: Layer names matched by ``is_layer_skipped``; those
                layers get unquantized methods. ``None`` means an empty list.
            weight_block_size: Optional 2-element block shape for block-wise
                weight quantization.

        Raises:
            ValueError: for an unknown activation scheme, or for a block size
                used with a non-FP8 checkpoint, with a length other than 2,
                or with a non-dynamic activation scheme.
        """
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []

        if weight_block_size is not None:
            self._validate_block_quant(
                weight_block_size, is_checkpoint_fp8_serialized, activation_scheme
            )
        self.weight_block_size = weight_block_size

    @staticmethod
    def _validate_block_quant(
        weight_block_size: list[int],
        is_checkpoint_fp8_serialized: bool,
        activation_scheme: str,
    ) -> None:
        """Raise ValueError for block-wise settings the FP8 path rejects."""
        if not is_checkpoint_fp8_serialized:
            raise ValueError(
                "The block-wise quantization only supports fp8-serialized "
                "checkpoint for now."
            )
        num_dims = len(weight_block_size)
        if num_dims != 2:
            raise ValueError(
                "The quantization block size of weight must have 2 "
                f"dimensions, but got {num_dims} dimensions"
            )
        if activation_scheme != "dynamic":
            raise ValueError(
                "The block-wise quantization only supports "
                "dynamic activation scheme for now, but got "
                f"{activation_scheme} activation scheme."
            )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Translate ignored layer names from HF naming to vLLM naming."""
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an Fp8Config from a checkpoint's quantization config dict."""
        quant_method = cls.get_from_keys(config, ["quant_method"])
        serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        if not ignored_layers:
            # Fall back to the alternative key for the skip list.
            ignored_layers = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None
            )
        return cls(
            is_checkpoint_fp8_serialized=serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def _is_layer_skipped(self, prefix: str) -> bool:
        """True when `prefix` is listed (directly or via fusion) as ignored."""
        return is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        )

    def get_xpu_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantization method for `layer` on XPU platforms."""
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod,
            XPUFp8MoEMethod,
        )

        xpu_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size,
        )

        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
                return UnquantizedLinearMethod()
            return XPUFp8LinearMethod(xpu_config)

        if isinstance(layer, FusedMoE):
            if self._is_layer_skipped(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return XPUFp8MoEMethod(xpu_config, layer)

        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)

        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantization method for `layer`, or None if unsupported."""
        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)

        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
                return UnquantizedLinearMethod()
            linear_cls = (
                Fp8LinearMethod
                if self.is_checkpoint_fp8_serialized
                else Fp8OnlineLinearMethod
            )
            linear_method = linear_cls(self)
            linear_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
            return linear_method

        if isinstance(layer, FusedMoE):
            if self._is_layer_skipped(prefix):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            moe_cls = (
                Fp8MoEMethod
                if self.is_checkpoint_fp8_serialized
                else Fp8OnlineMoEMethod
            )
            return moe_cls(self, layer)

        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)

        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale"):
            # Checked in k, v, q order; the first projection found wins.
            for proj, old_suffix, new_suffix in (
                (".k_proj", ".k_proj.output_scale", ".attn.k_scale"),
                (".v_proj", ".v_proj.output_scale", ".attn.v_scale"),
                (".q_proj", ".q_proj.output_scale", ".attn.q_scale"),
            ):
                if proj in name:
                    return name.replace(old_suffix, new_suffix)
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

278

279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
class CopyNumelCounter(TorchDispatchMode):
    """
    Dispatch mode that totals how many elements `copy_` writes.

    Every `aten.copy_.default` call seen while the mode is active adds the
    element count of its destination (first argument) to `copied_numel`.
    Useful for tracking weight loading, where the source may be arbitrarily
    transformed (such as with `narrow`) before the final copy.
    """

    def __init__(self):
        super().__init__()
        # Running total of destination elements written by copy_.
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        call_kwargs = {} if kwargs is None else kwargs
        result = func(*args, **call_kwargs)
        is_copy = func == torch.ops.aten.copy_.default
        if is_copy:
            destination = args[0]
            self.copied_numel += destination.numel()
        return result


299
300
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Limitations:
    1. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.marlin_input_dtype = None
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        # Batch-invariant mode has its own branch in apply(); never use Marlin.
        if vllm_is_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
        self.use_deep_gemm = is_deep_gemm_supported()

        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"

        if self.block_quant:
            # Block-wise weights only pair with dynamic activations
            # (Fp8Config raises for any other activation scheme).
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if self.act_q_static:
                activation_quant_key = kFp8StaticTensorSym
            elif cutlass_fp8_supported():
                activation_quant_key = kFp8DynamicTokenSym
            else:
                activation_quant_key = kFp8DynamicTensorSym

            self.fp8_linear = init_fp8_linear_kernel(
                activation_quant_key=activation_quant_key,
                weight_quant_key=kFp8StaticTensorSym,
                out_dtype=self.out_dtype,
                module_name=self.__class__.__name__,
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 weight, its scale(s) and, for static activation
        quantization, the input scale on `layer`."""
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        weight = create_fp8_weight_parameter(
            output_size_per_partition, input_size_per_partition, weight_loader
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        if not self.block_quant:
            scale = create_fp8_scale_parameter(
                PerTensorScaleParameter,
                output_partition_sizes,
                input_size_per_partition,
                None,
                weight_loader,
            )
            set_weight_attrs(scale, {"scale_type": "weight_scale"})
            layer.register_parameter("weight_scale", scale)
        else:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            scale = create_fp8_scale_parameter(
                BlockQuantScaleParameter,
                output_partition_sizes,
                input_size_per_partition,
                self.weight_block_size,
                weight_loader,
            )
            set_weight_attrs(scale, {"scale_type": "weight_scale"})
            # The weight_scale_inv name is intentional for deepseekv3
            layer.register_parameter("weight_scale_inv", scale)

        # INPUT ACTIVATION SCALE
        if self.act_q_static:
            scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
            set_weight_attrs(scale, {"scale_type": "input_scale"})
            layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded checkpoint tensors into the layout the selected
        kernel (block-FP8, per-tensor FP8, or Marlin) expects."""
        size_k_first = True
        input_scale = None
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert not self.act_q_static
            size_k_first = False

            weight, weight_scale_inv = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale_inv
            )

            # Update layer with new values
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)

        else:
            # Per-tensor FP8 checkpoint: a fused module may carry N weight
            # scales for its N logical shards.
            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
                    weight,
                    weight_scale,
                    layer.logical_widths,
                    getattr(layer, "input_scale", None),
                )
                if self.act_q_static:
                    assert input_scale is not None
                    input_scale = input_scale.max()
            weight = weight.t()

            # Update layer with new values.
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale", weight_scale.data)

        if input_scale is not None:
            replace_parameter(layer, "input_scale", input_scale)
        else:
            layer.input_scale = None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
            del layer.input_scale
            return

        if self.block_quant:
            maybe_post_process_fp8_weight_block(layer)

    @staticmethod
    def _dequantize_weight(
        weight: torch.Tensor, weight_scale: torch.Tensor, dtype: torch.dtype
    ) -> torch.Tensor:
        """Dequantize a non-block FP8 weight laid out as [K, N] into `dtype`.

        A single-element scale is applied to the whole tensor. A scale with
        one entry per output channel (N) is broadcast along the last axis;
        comparing against N first avoids applying it along K when K == N.
        Any other shape keeps the previous row/broadcast fallbacks.
        """
        dequant = weight.to(dtype)
        scale = weight_scale.to(dtype)
        if scale.numel() == 1:
            # Per-tensor: simple scalar multiplication
            return dequant * scale
        num_out = dequant.shape[1]
        if scale.numel() == num_out:
            # Per-output-channel scales.
            return dequant * scale.reshape(1, num_out)
        if scale.dim() == 1 and scale.shape[0] == dequant.shape[0]:
            # One scale per row of the [K, N] layout.
            return dequant * scale.unsqueeze(1)
        # Fallback: rely on standard broadcasting.
        return dequant * scale

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute `x @ W^T (+ bias)` with the FP8 weights stored on `layer`."""
        # If batch invariant mode is enabled, prefer the DeepGEMM FP8 path for
        # block quant; otherwise dequantize to the activation dtype and run a
        # plain GEMM.
        if vllm_is_batch_invariant():
            if self.block_quant:
                assert self.weight_block_size is not None
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,
                    weight_scale=layer.weight_scale_inv,
                    input_scale=layer.input_scale,
                    bias=bias,
                )
            # Dequantize to x.dtype (not a fixed bf16) so fp16 models do not
            # hit a dtype mismatch inside F.linear.
            dequant_weight = self._dequantize_weight(
                layer.weight, layer.weight_scale, x.dtype
            )
            return torch.nn.functional.linear(x, dequant_weight.t(), bias)

        if self.use_marlin:
            if self.block_quant:
                weight_scale = layer.weight_scale_inv
            else:
                weight_scale = layer.weight_scale

            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                input_dtype=self.marlin_input_dtype,
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply_weights(layer, x, bias)
556
557


558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
class Fp8OnlineLinearMethod(Fp8LinearMethod):
    """Online version of Fp8LinearMethod.

    Loads an fp16/bf16 checkpoint and quantizes the weights to FP8 as soon as
    every element of the layer's weight has been written by the loader.
    """

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register a high-precision weight whose loader quantizes it in place
        once the final chunk has been copied in."""
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # WEIGHT
        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # track how many elements we have updated
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

            # load the current weight chunk
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # if we have loaded all of the elements, call
            # process_weights_after_loading
            target_loaded_numel = layer.weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Delete the bookkeeping
                del layer._loaded_numel
                # Prevent the usual `process_weights_after_loading` call from doing
                # anything
                layer._already_called_process_weights_after_loading = True

            return res

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=patched_weight_loader,
        )
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded weight with `ops.scaled_fp8_quant(scale=None)`
        and store it transposed; a no-op if the patched loader already did so.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # TODO(future): support block_quant in online quant path
        assert not self.block_quant

        layer.input_scale = None
        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
        weight = qweight.t()

        # Update layer with new values.
        replace_parameter(layer, "weight", weight.data)
        replace_parameter(layer, "weight_scale", weight_scale.data)

        if self.use_marlin:
            size_k_first = True
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin; drop the unused input
            # scale, matching Fp8LinearMethod.process_weights_after_loading.
            del layer.input_scale


641
642
643
644
645
646
647
648
649
650
651
652
653
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8-serialized checkpoints.

    Supports loading FP8 checkpoints with static weight scales (per-tensor,
    or block-wise when ``weight_block_size`` is set in the config) and
    dynamic or static activation scales.

    Unquantized FP16/BF16 checkpoints are not handled by this class:
    ``create_weights`` asserts ``is_checkpoint_fp8_serialized``. Online
    quantization of such checkpoints is implemented by the
    ``Fp8OnlineMoEMethod`` subclass.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer; its ``moe_config`` is passed to the base class.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant: bool = self.weight_block_size is not None

        # Suffix of the weight-scale parameter names: block-quantized
        # checkpoints use `weight_scale_inv`, per-tensor ones `weight_scale`.
        self.weight_scale_name = (
            "weight_scale_inv" if self.block_quant else "weight_scale"
        )

        # Set weight key and activation key for kernel compatibility
        if self.block_quant:
            weight_key = kFp8Static128BlockSym
            activation_key = kFp8Dynamic128Sym
        else:
            weight_key = kFp8StaticTensorSym
            activation_key = (
                kFp8StaticTensorSym
                if self.quant_config.activation_scheme == "static"
                else kFp8DynamicTensorSym
            )

        # Select Fp8 MoE backend
        self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
            config=self.moe,
            weight_key=weight_key,
            activation_key=activation_key,
            allow_vllm_cutlass=False,
        )

        # Delay creation of the kernel until after process-weights.
        self.kernel: mk.FusedMoEModularKernel | None = None

    @property
    def topk_indices_dtype(self) -> torch.dtype | None:
        """Top-k index dtype requested by the kernel's prepare/finalize step.

        Returns None while the modular kernel has not been created yet.
        """
        if self.kernel is not None:
            return self.kernel.prepare_finalize.topk_indices_dtype()
        return None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register FP8 expert weights, weight scales and input scales.

        Registers ``w13_weight`` with shape
        ``(num_experts, 2 * intermediate_size_per_partition, hidden_size)``
        and ``w2_weight`` with shape
        ``(num_experts, hidden_size, intermediate_size_per_partition)``, both
        in ``torch.float8_e4m3fn``, plus the matching weight scales and, for
        the "static" activation scheme, per-expert input scales.

        Args:
            layer: Module the parameters are registered on.
            num_experts: Number of experts to allocate.
            hidden_size: Model hidden size.
            intermediate_size_per_partition: Expert intermediate size on
                this tensor-parallel rank.
            params_dtype: Original parameter dtype, recorded as
                ``layer.orig_dtype``; the weights themselves are FP8.
            **extra_weight_attrs: Attributes attached to every registered
                parameter via ``set_weight_attrs``.

        Raises:
            ValueError: Under block quantization, if the intermediate size
                is not divisible by ``block_n`` (or by ``block_k`` when
                tensor parallelism is used).
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        assert self.quant_config.is_checkpoint_fp8_serialized
        params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # For per-tensor quant, the scales are per expert and weight.
            w13_scale_data = torch.ones(num_experts, 2, dtype=torch.float32)
            w2_scale_data = torch.ones(num_experts, dtype=torch.float32)
        else:
            # For block quant, the scales are per block (typically 128x128).
            w13_scale_data = torch.ones(
                num_experts,
                2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                (hidden_size + block_k - 1) // block_k,
                dtype=torch.float32,
            )
            w2_scale_data = torch.ones(
                num_experts,
                (hidden_size + block_n - 1) // block_n,
                (intermediate_size_per_partition + block_k - 1) // block_k,
                dtype=torch.float32,
            )
        w13_weight_scale = torch.nn.Parameter(w13_scale_data, requires_grad=False)
        w2_weight_scale = torch.nn.Parameter(w2_scale_data, requires_grad=False)
        # Note: name is weight_scale for tensor, weight_scale_inv for block.
        layer.register_parameter(f"w13_{self.weight_scale_name}", w13_weight_scale)
        layer.register_parameter(f"w2_{self.weight_scale_name}", w2_weight_scale)

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant
            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)
        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def _setup_kernel(
        self,
        layer: Module,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor | None,
        w2_input_scale: torch.Tensor | None,
    ) -> None:
        """Convert weights to the backend's runtime format and build the kernel.

        The converted weights and weight scales replace the layer's
        parameters. The modular kernel is built only when a quant config
        exists and either no all2all kernels are used or the naive all2all
        path is used; otherwise ``self.kernel`` stays unset here.
        """
        # Shuffle weights to runtime format.
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
        replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)

        # Setup modular kernel for TP case and naive DP/EP case.
        # In non-naive DP/EP case, we will create a ModularKernelMethod.
        # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
        # in both cases.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config and (
            (not self.moe.moe_parallel_config.use_all2all_kernels)
            or self.moe.moe_parallel_config.use_naive_all2all_kernels
        ):
            assert self.experts_cls is not None
            self.kernel, self.use_inplace = make_fp8_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
                experts_cls=self.experts_cls,
            )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Post-process loaded FP8 weights and set up the kernel.

        Skipped when ``_already_called_process_weights_after_loading`` is set
        on the layer. Otherwise converts to FNUZ on platforms that need it,
        collapses static input scales to a single value via
        ``process_fp8_input_tensor_strategy_moe``, requantizes w13 to one
        scale per expert for per-tensor quant, and calls ``_setup_kernel``.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # Allow for accessing weights and scales in standard way.
        w13 = layer.w13_weight
        w2 = layer.w2_weight
        w13_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        w13_input_scale = layer.w13_input_scale
        w2_input_scale = layer.w2_input_scale

        # MI300x and MI325x use FNUZ format for FP8. Convert if needed.
        if current_platform.is_fp8_fnuz():
            w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w13,
                w13_scale,
                w13_input_scale,
            )
            w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                w2,
                w2_scale,
                w2_input_scale,
            )

        # Per tensor kernels require single activation scale. Use the max.
        if self.quant_config.activation_scheme == "static":
            assert not self.block_quant
            assert w13_input_scale is not None and w2_input_scale is not None
            w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
                w13_input_scale, w2_input_scale
            )
            replace_parameter(layer, "w13_input_scale", w13_input_scale)
            replace_parameter(layer, "w2_input_scale", w2_input_scale)

        # Per tensor kernels require single weight scale for w13 per expert, but
        # on disk there is a scale for w1 and w3. Use the max to requantize.
        if not self.block_quant:
            shard_size = layer.intermediate_size_per_partition
            w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
                w13, w13_scale, shard_size, layer.local_num_experts
            )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
        )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        """Return the prepare/finalize object for the selected backend.

        FlashInfer TRTLLM never uses one (returns None). FlashInfer CUTLASS
        returns None without all2all kernels and otherwise builds its own;
        every other backend defers to the base class.
        """
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            return None
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            # For no-EP case, don't use the MKM framework.
            if not self.moe.moe_parallel_config.use_all2all_kernels:
                return None

            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        return super().maybe_make_prepare_finalize(routing_tables)

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Build the experts implementation used with ``prepare_finalize``."""
        assert self.moe_quant_config is not None
        assert self.experts_cls is not None
        return make_fp8_moe_kernel_for_mkm(
            moe_config=self.moe,
            quant_config=self.moe_quant_config,
            experts_cls=self.experts_cls,
            prepare_finalize=prepare_finalize,
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config from the layer's current scales.

        Returns None for the FlashInfer TRTLLM backend.
        """
        # TRTLLM does not use Modular Kernel.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            return None

        w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        return make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=self.weight_block_size,
        )

    @property
    def supports_eplb(self) -> bool:
        return True

    @property
    def allow_inplace(self) -> bool:
        return True

    def apply(
        self,
        layer: FusedMoE,
        router: FusedMoERouter,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the FP8 MoE forward pass.

        The FlashInfer TRTLLM backend performs routing inside its own op
        (block-scale or per-tensor variant) and does not support EPLB. All
        other backends route with ``router.select_experts`` and then call
        the modular kernel built in ``_setup_kernel``.

        Raises:
            NotImplementedError: If EPLB is enabled with the TRTLLM backend.
        """
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            # TODO(rob): convert this to MK.
            if layer.enable_eplb:
                raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
            assert layer.activation == "silu", (
                f"Expected 'silu' activation but got {layer.activation}"
            )

            if self.block_quant:
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                e_score_correction_bias = (
                    layer.e_score_correction_bias.to(x.dtype)
                    if layer.e_score_correction_bias is not None
                    else None
                )
                routing_method_type = layer.routing_method_type
                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32)
                    if routing_method_type == RoutingMethodType.DeepSeekV3
                    else router_logits,
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.weight_block_size,
                    routing_method_type=routing_method_type,
                    routed_scaling=layer.routed_scaling_factor,
                )
            else:
                return apply_fi_trtllm_fp8_per_tensor_moe(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=layer.e_score_correction_bias,
                    global_num_experts=layer.global_num_experts,
                    top_k=layer.top_k,
                    num_expert_group=layer.num_expert_group,
                    topk_group=layer.topk_group,
                    apply_router_weight_on_input=layer.apply_router_weight_on_input,
                )

        topk_weights, topk_ids = router.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        assert self.kernel is not None
        result = self.kernel(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )

        return result
1049
1050


1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
class Fp8OnlineMoEMethod(Fp8MoEMethod):
    """MoE method for online FP8 quantization.

    Loads unquantized (e.g. FP16/BF16) checkpoints with dynamic activation
    scaling. Each expert's weights are quantized to FP8 with one scale per
    expert and weight (w13 / w2). Quantization is triggered by a patched
    weight loader as soon as the last weight chunk has been copied, rather
    than waiting for the regular post-load hook.

    Args:
        quant_config: The quantization config. Must not be FP8-serialized,
            must use the "dynamic" activation scheme and must not set
            ``weight_block_size``.
        layer: The MoE layer this method is attached to.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(quant_config, layer)
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register high-precision expert weights and per-expert FP8 scales.

        The weights are allocated in ``params_dtype`` and get a patched
        weight loader that counts copied elements. Once every element of
        ``w13_weight`` and ``w2_weight`` has been loaded, it calls
        ``process_weights_after_loading`` and marks the layer so the regular
        post-load call becomes a no-op.

        Args:
            layer: Module the parameters are registered on.
            num_experts: Number of experts to allocate.
            hidden_size: Model hidden size.
            intermediate_size_per_partition: Expert intermediate size on
                this tensor-parallel rank.
            params_dtype: Dtype of the unquantized checkpoint weights.
            **extra_weight_attrs: Attributes attached to every registered
                parameter; must contain ``"weight_loader"``.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # We are doing online quantization, patch the weight loader
        # to call `process_weights_after_loading` in a streaming fashion
        # as soon as the last weight chunk is loaded.
        weight_loader = extra_weight_attrs["weight_loader"]
        # Copy into a new dict so that installing the patched loader does not
        # modify any other object that might hold a reference to the old one.
        new_extra_weight_attrs = dict(extra_weight_attrs)

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # add a counter to track how many elements we have updated
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

            # load the current weight chunk
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # if we have loaded all of the elements, call
            # process_weights_after_loading
            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Delete the bookkeeping
                del layer._loaded_numel
                # Prevent the usual `process_weights_after_loading` call
                # from doing anything
                layer._already_called_process_weights_after_loading = True

            return res

        new_extra_weight_attrs["weight_loader"] = patched_weight_loader
        extra_weight_attrs = new_extra_weight_attrs

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # One scale per expert for w13 (shared by w1 and w3, since the whole
        # w13 slice of an expert is quantized at once) and one per expert for
        # w2. The values are filled in by `process_weights_after_loading`.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize every expert's weights to FP8 and set up the kernel.

        Skipped when ``_already_called_process_weights_after_loading`` is set
        on the layer (the patched loader already ran it).
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # The checkpoint is unquantized: quantize each expert into freshly
        # allocated FP8 tensors. The per-expert scales are written into the
        # existing scale parameters.
        fp8_dtype = current_platform.fp8_dtype()
        w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
        w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale

        for expert in range(layer.local_num_experts):
            w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
                layer.w13_weight[expert, :, :]
            )
            w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant(
                layer.w2_weight[expert, :, :]
            )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer,
            w13,
            w2,
            w13_scale,
            w2_scale,
            layer.w13_input_scale,
            layer.w2_input_scale,
        )
1189
1190


1191
1192
1193
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1194
1195
1196
    """

    def __init__(self, quant_config: Fp8Config):
1197
        super().__init__(quant_config)