fp8.py 38.7 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import importlib.util
4
from typing import Any, Callable, Dict, List, Optional
5
6

import torch
7
import torch.nn.functional as F
8
9
10
from torch.nn import Module
from torch.nn.parameter import Parameter

11
import vllm.envs as envs
12
from vllm import _custom_ops as ops
13
from vllm.distributed import get_tensor_model_parallel_world_size
14
from vllm.logger import init_logger
15
16
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
17
18
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
19
from vllm.model_executor.layers.quantization import QuantizationMethods
20
from vllm.model_executor.layers.quantization.base_config import (
21
    QuantizationConfig, QuantizeMethodBase)
22
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
23
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
24
25
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin)
26
27
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)
28
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
29
30
31
32
    Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
    cutlass_fp8_supported, maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
    requantize_with_max_scale)
33
34
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ModelWeightParameter,
35
                                           PerTensorScaleParameter)
36
from vllm.model_executor.utils import set_weight_attrs
37
from vllm.platforms import current_platform
38
from vllm.scalar_type import scalar_types
39

40
41
42
43
# Activation quantization schemes accepted by Fp8Config.
ACTIVATION_SCHEMES = [
    "static",
    "dynamic",
]

logger = init_logger(__name__)

44
45
46
47
48
49
50
51
# True when the optional `deep_gemm` package can be located. Only the module
# spec is looked up here; the package itself is imported lazily where used.
has_deep_gemm = (importlib.util.find_spec("deep_gemm")
                 is not None)


def _is_col_major(x: torch.Tensor) -> bool:
    assert x.dim() == 3
    b, m, n = x.shape
    return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m

52

53
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Records whether the checkpoint already stores FP8 weights, the
    activation scheme ("static" or "dynamic"), which layers to leave
    unquantized, and an optional 2-D ``weight_block_size`` used for
    block-wise weight scales.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[List[str]] = None,
        weight_block_size: Optional[List[int]] = None,
    ) -> None:
        """Validate and store the FP8 settings.

        Raises:
            ValueError: if ``activation_scheme`` is not one of
                ``ACTIVATION_SCHEMES``; or if ``weight_block_size`` is given
                for a checkpoint that is not FP8-serialized, does not have
                exactly two entries, or is combined with a non-dynamic
                activation scheme.
        """
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
                           "format is experimental and subject to change.")

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme

        self.ignored_layers = ignored_layers or []

        if weight_block_size is not None:
            # Block-wise scales are only read from FP8 checkpoints, are
            # 2-D (block_n, block_k), and need dynamic activation scales.
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now.")
            num_dims = len(weight_block_size)
            if num_dims != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {num_dims} dimensions")
            if activation_scheme != "dynamic":
                raise ValueError("The block-wise quantization only supports "
                                 "dynamic activation scheme for now, but got "
                                 f"{activation_scheme} activation scheme.")
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (80)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """No extra config files are needed for FP8."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a checkpoint quantization dict.

        ``quant_method`` and ``activation_scheme`` are required keys;
        ``ignored_layers`` and ``weight_block_size`` are optional. The
        checkpoint counts as FP8-serialized when "fp8" appears in
        ``quant_method``.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        fp8_serialized = ("fp8" in quant_method)
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"],
                                                 None)
        return cls(
            is_checkpoint_fp8_serialized=fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the FP8 method for ``layer``, or None if unsupported.

        Linear layers listed in ``ignored_layers`` (resolved through
        ``packed_modules_mapping``) fall back to the unquantized method.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            skipped = is_layer_skipped(prefix=prefix,
                                       ignored_layers=self.ignored_layers,
                                       fused_mapping=self.packed_modules_mapping)
            return UnquantizedLinearMethod() if skipped else Fp8LinearMethod(
                self)
        if isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> Optional[str]:
        """
        Map a compressed-tensors q/k/v output scale (or attention prob
        scale) param name to the equivalent param name expected by vLLM.

        :param name: param name
        :return: matching param name for KV cache scale in vLLM, or None
            when ``name`` matches none of the known patterns
        """
        if name.endswith(".output_scale"):
            # Checked in k, v, q order, first match wins.
            for proj, src, dst in (
                (".k_proj", ".k_proj.output_scale", ".attn.k_scale"),
                (".v_proj", ".v_proj.output_scale", ".attn.v_scale"),
                (".q_proj", ".q_proj.output_scale", ".attn.q_scale"),
            ):
                if proj in name:
                    return name.replace(src, dst)
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

153
154
155

class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading unquantized FP16/BF16 model checkpoints with dynamic
    activation scaling: the weights are quantized to FP8, and their scaling
    factor computed, in ``process_weights_after_loading``.

    Notes:
    1. On the non-block, non-Marlin path the per-shard weight scales of a
       fused layer are requantized to a single per-tensor scale, because
       torch._scaled_mm needs a per-tensor weight scale.
    2. Optional block-wise weight scales (``weight_block_size``) are
       supported only for FP8 checkpoints with dynamic activation scaling.
    3. Weights are created as float8_e4m3fn (see torch._scaled_mm dtype
       support: https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856);
       on platforms where ``current_platform.is_fp8_fnuz()`` they are
       normalized to e4m3fnuz after loading.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        self.block_quant = self.quant_config.weight_block_size is not None

        self.fp8_linear = Fp8LinearOp(
            # Default to using per_token quantization if cutlass is supported
            use_per_token_if_dynamic=cutlass_fp8_supported())

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``weight`` and, for FP8 checkpoints, its scale params.

        ``weight`` has shape (sum(output_partition_sizes),
        input_size_per_partition) and is float8_e4m3fn for FP8 checkpoints,
        else ``params_dtype``. For FP8 checkpoints this also registers
        ``weight_scale`` (one entry per output partition) or, with block
        quantization, ``weight_scale_inv`` (one entry per block_n x block_k
        tile), plus ``input_scale`` (static scheme) or a None placeholder.

        ``input_size``/``output_size`` are only used to detect row/column
        parallel sharding for the block-size divisibility checks.

        Raises:
            ValueError: if a sharded or merged dimension is not divisible by
                the corresponding block size.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            tp_size = get_tensor_model_parallel_world_size()
            assert self.quant_config.weight_block_size is not None
            layer.weight_block_size = self.quant_config.weight_block_size
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # Required by row parallel
            if (tp_size > 1
                    and input_size // input_size_per_partition == tp_size
                    and input_size_per_partition % block_k != 0):
                raise ValueError(
                    f"Weight input_size_per_partition = "
                    f"{input_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")
            # Required by column parallel or enabling merged weights
            if (tp_size > 1 and output_size // output_size_per_partition
                    == tp_size) or len(output_partition_sizes) > 1:
                for output_partition_size in output_partition_sizes:
                    if output_partition_size % block_n != 0:
                        raise ValueError(
                            f"Weight output_partition_size = "
                            f"{output_partition_size} is not divisible by "
                            f"weight quantization block_n = {block_n}.")

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)

        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                # Filled with float32 min as a placeholder; presumably
                # overwritten by the weight loader (TODO confirm).
                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert self.quant_config.activation_scheme == "dynamic"
                # One scale per (block_n, block_k) tile, rounding up so
                # partial tiles at the edges still get a scale.
                scale = BlockQuantScaleParameter(
                    data=torch.empty(
                        (output_size_per_partition + block_n - 1) // block_n,
                        (input_size_per_partition + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(data=torch.empty(
                    len(output_partition_sizes), dtype=torch.float32),
                                                weight_loader=weight_loader)

                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        """Return ``weight``, re-laid-out with padded rows on ROCm if enabled.

        Applies only when VLLM_ROCM_FP8_PADDING is set, the platform is ROCm,
        the last dim is contiguous, and the row stride in bytes is a multiple
        of 512. The last dim is padded by 256 bytes and sliced back, so the
        logical shape is unchanged but the row stride grows.
        """
        # Pad the weight tensor. This is an optimization on ROCm platform, which
        # can benefit from tensors located far enough from one another in memory
        if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
                and weight.stride(-1) == 1
                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
            num_pad = 256 // weight.element_size()
            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
            torch.cuda.empty_cache()
        return weight

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights into the layout the selected kernel uses.

        Exactly one of three paths runs:
        * block quant: normalize to e4m3fnuz on fnuz platforms, maybe pad,
          and re-wrap ``weight``/``weight_scale_inv`` as plain Parameters;
        * non-FP8 checkpoint: quantize ``weight`` with ``scaled_fp8_quant``
          and store it transposed with its scale; ``input_scale`` is None;
        * FP8 checkpoint: unless Marlin is used, requantize the per-shard
          weights with ``requantize_with_max_scale`` so a single scale
          applies; then maybe pad and store the weight transposed. A static
          ``input_scale`` is collapsed to its max.
        Afterwards, when Marlin is used, the layer is prepared for Marlin and
        ``input_scale`` is deleted since Marlin does not quantize activations.
        """
        # Passed to prepare_fp8_layer_for_marlin; False for the block-quant
        # path, whose weight is not transposed below.
        size_k_first = True
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            size_k_first = False
            if current_platform.is_fp8_fnuz():
                weight, weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=layer.weight,
                        weight_scale=layer.weight_scale_inv)
            else:
                weight = layer.weight.data
                weight_scale_inv = layer.weight_scale_inv.data

            weight = self._maybe_pad_weight(weight)

            # Torch.compile cannot use Parameter subclasses.
            layer.weight = Parameter(weight, requires_grad=False)
            layer.weight_scale_inv = Parameter(weight_scale_inv,
                                               requires_grad=False)

        # If checkpoint not serialized fp8, quantize the weights.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None

        # If checkpoint is fp8, handle that there are N scales for N
        # shards in a fused module
        else:
            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                    requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                       requires_grad=False)

            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                # Dequant -> Quant with max scale so we can run per tensor.
                if current_platform.is_fp8_fnuz():
                    weight, weight_scale, input_scale = \
                        normalize_e4m3fn_to_e4m3fnuz(
                            weight=weight,
                            weight_scale=weight_scale,
                            input_scale=layer.input_scale)
                    if input_scale is not None:
                        layer.input_scale = Parameter(input_scale,
                                                      requires_grad=False)

                weight_scale, weight = requantize_with_max_scale(
                    weight=weight,
                    weight_scale=weight_scale,
                    logical_widths=layer.logical_widths,
                )

            weight = self._maybe_pad_weight(weight)
            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                # Collapse per-shard input scales to a single scalar.
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer, size_k_first)
            # Activations not quantized for marlin.
            del layer.input_scale

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the FP8 linear output for ``x`` (plus optional ``bias``).

        Dispatches, in priority order, to the Marlin weight-only kernel, the
        block-wise w8a8 op (when ``weight_block_size`` is set), or
        ``Fp8LinearOp`` with ``self.out_dtype`` as the output dtype.
        """

        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        if self.block_quant:
            assert self.quant_config.weight_block_size is not None
            return torch.ops.vllm.apply_w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,
                block_size=self.quant_config.weight_block_size,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
            )

        return self.fp8_linear.apply(input=x,
                                     weight=layer.weight,
                                     weight_scale=layer.weight_scale,
                                     out_dtype=self.out_dtype,
                                     input_scale=layer.input_scale,
                                     bias=bias)
413
414


415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        """Record the config and decide which MoE kernel paths are usable.

        Args:
            quant_config: The FP8 quantization config for this model.
        """
        self.quant_config = quant_config
        self.block_quant = self.quant_config.weight_block_size is not None

        # Marlin provides weight-only FP8 for GPUs lacking FP8 hardware
        # support (has_device_capability(89) is false) or when forced via
        # VLLM_TEST_FORCE_FP8_MARLIN. It is never used on ROCm.
        wants_marlin = (not current_platform.has_device_capability(89)
                        or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        self.use_marlin = False if current_platform.is_rocm() else wants_marlin

        # DeepGemm is opt-in via VLLM_USE_DEEP_GEMM and additionally needs the
        # deep_gemm package plus a CUDA platform with has_device_capability(90).
        self.allow_deep_gemm = False
        if envs.VLLM_USE_DEEP_GEMM:
            if not has_deep_gemm:
                logger.warning_once("Failed to import DeepGemm kernels.")
            elif not (current_platform.is_cuda()
                      and current_platform.has_device_capability(90)):
                logger.warning_once(
                    "DeepGemm not supported on the current platform.")
            else:
                logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
                self.allow_deep_gemm = True
    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register expert weights plus their weight and input scales.

        Registers ``w13_weight`` with shape (num_experts,
        2 * intermediate_size_per_partition, hidden_size) and ``w2_weight``
        with shape (num_experts, hidden_size, intermediate_size_per_partition).
        For an FP8 checkpoint these are float8_e4m3fn; otherwise they keep
        ``params_dtype``.

        Weight scales are registered as ``w13_weight_scale`` /
        ``w2_weight_scale`` (per expert, with two w13 entries for w1 and w3)
        or, with block quantization, ``w13_weight_scale_inv`` /
        ``w2_weight_scale_inv`` holding one entry per (block_n, block_k) tile.
        Input scales are registered only for the static activation scheme;
        otherwise ``w13_input_scale``/``w2_input_scale`` are set to None.

        Raises:
            ValueError: if the intermediate size is not divisible by the
                block size where required, or if the static activation
                scheme is used with a checkpoint that is not FP8-serialized.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        # Captured before params_dtype may be switched to FP8 below.
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        fp8_serialized = self.quant_config.is_checkpoint_fp8_serialized
        if fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            block_size = self.quant_config.weight_block_size
            assert block_size is not None
            layer.weight_block_size = block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n = block_size[0]
            block_k = block_size[1]
            # Gate/up (w13) output rows must align with block_n so the
            # block-wise scales line up (column parallel / merged weights).
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}.")
            # Down (w2) input columns must align with block_k under row
            # parallelism.
            if (tp_size > 1
                    and intermediate_size_per_partition % block_k != 0):
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")

        # WEIGHTS. Their attributes are attached before "quant_method" is
        # added to extra_weight_attrs further down, so they never carry it.
        expert_weight_shapes = (
            ("w13_weight",
             (num_experts, 2 * intermediate_size_per_partition, hidden_size)),
            ("w2_weight",
             (num_experts, hidden_size, intermediate_size_per_partition)),
        )
        for param_name, shape in expert_weight_shapes:
            expert_weight = torch.nn.Parameter(torch.empty(*shape,
                                                           dtype=params_dtype),
                                               requires_grad=False)
            layer.register_parameter(param_name, expert_weight)
            set_weight_attrs(expert_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if self.block_quant:

            def _num_blocks(size: int, block: int) -> int:
                # Ceiling division: a partial edge tile still gets a scale.
                return (size + block - 1) // block

            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * _num_blocks(intermediate_size_per_partition, block_n),
                    _num_blocks(hidden_size, block_k),
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    _num_blocks(hidden_size, block_n),
                    _num_blocks(intermediate_size_per_partition, block_k),
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"
        else:
            # Two scales per expert for w1 and w3; they are combined into a
            # single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, 2, dtype=torch.float32),
                requires_grad=False)
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32),
                requires_grad=False)
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Tag scales with the quantization granularity so the loader knows
        # how to read them.
        scale_kind = (FusedMoeWeightScaleSupported.BLOCK
                      if self.block_quant else
                      FusedMoeWeightScaleSupported.TENSOR)
        extra_weight_attrs["quant_method"] = scale_kind.value

        # Only FP8 checkpoints ship weight scales; for other checkpoints the
        # scales are produced in process_weights_after_loading().
        if fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme != "static":
            layer.w13_input_scale = None
            layer.w2_input_scale = None
            return

        if not fp8_serialized:
            raise ValueError(
                "Found static activation scheme for checkpoint that "
                "was not serialized fp8.")
        for param_name in ("w13_input_scale", "w2_input_scale"):
            input_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                        dtype=torch.float32),
                                             requires_grad=False)
            layer.register_parameter(param_name, input_scale)
            set_weight_attrs(input_scale, extra_weight_attrs)

    def process_weights_after_loading(self, layer: Module) -> None:
582
583
        # Lazy import to avoid importing triton too early.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
584
            expand_weights, is_rocm_aiter_moe_enabled, shuffle_weights)
585

586
        # TODO (rob): refactor block quant into separate class.
587
        if self.block_quant:
588
            assert self.quant_config.activation_scheme == "dynamic"
589
            if current_platform.is_fp8_fnuz():
590
591
592
593
594
595
596
597
                w13_weight, w13_weight_scale_inv, w13_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale_inv,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale_inv, w2_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv,
                        layer.w2_input_scale)
598
599
600
601
602
603
604
605
606
607
608
609
610
            else:
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight
                w2_weight_scale_inv = layer.w2_weight_scale_inv

            # torch.compile() cannot use Parameter subclasses.
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
                                                   requires_grad=False)
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                                  requires_grad=False)
611
            if is_rocm_aiter_moe_enabled():
612
613
614
615
616
617
618
619
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)
620
621
622
623
624
625
626
627
628
629
630
631
632

            # DeepGemm scales need to be transposed and aligned.  We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm:
                # Lazy import to avoid CUDA initialization problems.
                import deep_gemm as dg
                if _is_col_major(layer.w13_weight_scale_inv):
                    layer.w13_weight_scale_inv = \
                        dg.get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous()
                if _is_col_major(layer.w2_weight_scale_inv):
                    layer.w2_weight_scale_inv = \
                        dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()

633
        # If checkpoint is fp16, quantize in place.
634
        elif not self.quant_config.is_checkpoint_fp8_serialized:
635
            fp8_dtype = current_platform.fp8_dtype()
636
            w13_weight = torch.empty_like(layer.w13_weight.data,
637
638
                                          dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
639
640
641

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
642
            layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
643
                layer.local_num_experts,
644
645
                dtype=torch.float32,
                device=w13_weight.device),
646
                                                        requires_grad=False)
647
            for expert in range(layer.local_num_experts):
648
                w13_weight[expert, :, :], layer.w13_weight_scale[
649
650
                    expert] = ops.scaled_fp8_quant(
                        layer.w13_weight.data[expert, :, :])
651
                w2_weight[expert, :, :], layer.w2_weight_scale[
652
653
654
655
656
657
                    expert] = ops.scaled_fp8_quant(
                        layer.w2_weight.data[expert, :, :])
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
            if is_rocm_aiter_moe_enabled():
                # reshaping weights is required for aiter moe kernel.
                w13_scales, w2_scales = expand_weights(
                    layer.w13_weight_scale.data,
                    layer.w2_weight_scale.data,
                    expansion_dims=[
                        layer.w13_weight.shape[1], layer.w2_weight.shape[1]
                    ])
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_scales.contiguous(), requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(
                    w2_scales.contiguous(), requires_grad=False)

                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)
678
679
680
681
682
683
684
        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
685
686
                if (layer.w13_input_scale is None
                        or layer.w2_input_scale is None):
687
688
689
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None.")
690
691
                if (not all_close_1d(layer.w13_input_scale)
                        or not all_close_1d(layer.w2_input_scale)):
692
                    logger.warning_once(
693
694
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
695
                        "for each layer.")
696
697
698
699
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False)
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False)
700
            if current_platform.is_fp8_fnuz():
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale, w2_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale,
                        layer.w2_input_scale)
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                      requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False)
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                     requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                           requires_grad=False)
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False)
725
726
727

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
728
            assert layer.w13_weight_scale is not None
729
            shard_size = layer.intermediate_size_per_partition
730
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
731
            for expert_id in range(layer.local_num_experts):
732
733
734
735
736
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
737
                        layer.w13_weight_scale[expert_id][shard_id])
738
                    layer.w13_weight[expert_id][
739
                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
740
741
742
                            dq_weight, max_w13_scales[expert_id])
                    start += shard_size

743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
            if is_rocm_aiter_moe_enabled():
                # reshaping weights is required for aiter moe kernel.
                expansion_dims = [
                    layer.w13_weight.shape[1], layer.w2_weight.shape[1]
                ]
                max_w13_scales, w2_scales = expand_weights(
                    max_w13_scales,
                    layer.w2_weight_scale.data,
                    expansion_dims=expansion_dims)
                layer.w2_weight_scale = torch.nn.Parameter(
                    w2_scales.contiguous(), requires_grad=False)

                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

763
764
            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                        requires_grad=False)
765
766
767
768
769
770

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale
771

772
773
774
775
776
777
778
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:
        """Run the FP8 mixture-of-experts forward pass for ``layer``.

        Experts are first selected from ``router_logits`` with
        ``FusedMoE.select_experts``. The tokens in ``x`` are then dispatched
        to one of two kernels:
        - the Marlin FP8 kernel (``torch.ops.vllm.fused_marlin_moe``) when
          ``self.use_marlin`` is set;
        - ``fused_experts`` with ``use_fp8_w8a8=True`` otherwise.

        Args:
            layer: The MoE layer whose ``w13_weight``/``w2_weight`` and the
                associated weight/input scale tensors are read here.
            x: Input hidden states.
            router_logits: Router scores used for expert selection.
            top_k: Number of experts selected per token.
            renormalize: Forwarded to expert selection.
            use_grouped_topk: Forwarded to ``FusedMoE.select_experts``.
            topk_group: Forwarded to ``FusedMoE.select_experts``.
            num_expert_group: Forwarded to ``FusedMoE.select_experts``.
            global_num_experts: Forwarded to the expert kernel.
            expert_map: Forwarded to the expert kernel.
            custom_routing_function: Forwarded to
                ``FusedMoE.select_experts``.
            scoring_func: Forwarded to ``FusedMoE.select_experts``.
            e_score_correction_bias: Forwarded to
                ``FusedMoE.select_experts``.
            apply_router_weight_on_input: Forwarded to ``fused_experts``.
                Not supported on the Marlin path.
            activation: Activation name forwarded to ``fused_experts``. The
                Marlin path only accepts the default ``"silu"``.

        Returns:
            The output tensor produced by the selected expert kernel.

        Raises:
            NotImplementedError: If the Marlin path is active and a
                non-default ``activation`` or
                ``apply_router_weight_on_input=True`` is requested.
        """
        # The Marlin kernel call below receives neither ``activation`` nor
        # ``apply_router_weight_on_input``. Previously these options were
        # silently ignored on that path, producing a result that differs from
        # what the caller asked for, so reject them explicitly. The check runs
        # before routing so no work is wasted on an unsupported config.
        # NOTE(review): assumes fused_marlin_moe's built-in activation is SiLU
        # (this method's default) -- confirm against the kernel.
        if self.use_marlin:
            if activation != "silu":
                raise NotImplementedError(
                    f"{activation} not supported for Marlin MoE.")
            if apply_router_weight_on_input:
                raise NotImplementedError(
                    "Apply router weight on input not supported for "
                    "Marlin MoE.")

        # Local import, kept local as in the original.
        from vllm.model_executor.layers.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )

        if self.use_marlin:
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                global_num_experts=global_num_experts,
                expert_map=expert_map)

        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            # Presumably allows the output to reuse x's storage -- confirm
            # in fused_experts.
            inplace=True,
            activation=activation,
            use_fp8_w8a8=True,
            global_num_experts=global_num_experts,
            apply_router_weight_on_input=apply_router_weight_on_input,
            expert_map=expert_map,
            # Block-quantized layers carry *_weight_scale_inv; otherwise the
            # *_weight_scale tensors are used.
            w1_scale=(layer.w13_weight_scale_inv
                      if self.block_quant else layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale_inv
                      if self.block_quant else layer.w2_weight_scale),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.quant_config.weight_block_size,
            allow_deep_gemm=self.allow_deep_gemm,
        )
840
841


842
843
844
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
845
846
847
    """

    def __init__(self, quant_config: Fp8Config):
848
        super().__init__(quant_config)