fp8.py 54.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
5
6

import torch
7
import torch.nn.functional as F
8
9
10
from torch.nn import Module
from torch.nn.parameter import Parameter

11
import vllm.envs as envs
12
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
13
from vllm import _custom_ops as ops
14
from vllm.distributed import get_tensor_model_parallel_world_size
15
from vllm.logger import init_logger
bnellnm's avatar
bnellnm committed
16
from vllm.model_executor.layers.fused_moe import (
17
    FusedMoE, FusedMoEActivationFormat, FusedMoEMethodBase,
18
19
    FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported)
20
21
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
22
23
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
24
from vllm.model_executor.layers.quantization import QuantizationMethods
25
from vllm.model_executor.layers.quantization.base_config import (
26
    QuantizationConfig, QuantizeMethodBase)
27
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
28
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
29
30
31
32
33
    FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8,
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
    flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend,
    register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
    select_cutlass_fp8_gemm_impl, swap_w13_to_w31)
34
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
35
36
    get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace,
    should_use_deepgemm_for_fp8_linear)
37
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
38
39
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin)
40
from vllm.model_executor.layers.quantization.utils.quant_utils import (
41
    GroupShape, is_layer_skipped)
42
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
43
44
45
46
    Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
    cutlass_fp8_supported, maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
    requantize_with_max_scale)
47
48
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ModelWeightParameter,
49
                                           PerTensorScaleParameter)
50
from vllm.model_executor.utils import set_weight_attrs
51
from vllm.platforms import current_platform
52
from vllm.scalar_type import scalar_types
53
from vllm.utils import has_deep_gemm
54
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported
55
from vllm.utils.flashinfer import has_flashinfer_moe
56

57
58
59
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

60
61
62
63
# Activation-quantization schemes accepted by Fp8Config; any other value
# makes Fp8Config.__init__ raise ValueError.
ACTIVATION_SCHEMES: list[str] = ["static", "dynamic"]

# Module-scoped logger (named after this module).
logger = init_logger(__name__)

64
65
66
67
68
69

def _is_col_major(x: torch.Tensor) -> bool:
    assert x.dim() == 3
    b, m, n = x.shape
    return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m

70

71
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Records whether the checkpoint already stores FP8 weights, which
    activation quantization scheme to use (one of ``ACTIVATION_SCHEMES``),
    which layers must stay unquantized, and an optional two-element
    ``weight_block_size`` that switches weights to block-wise scales.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[list[str]] = None,
        weight_block_size: Optional[list[int]] = None,
    ) -> None:
        """Validate and store the FP8 settings.

        Raises:
            ValueError: if ``activation_scheme`` is not in
                ``ACTIVATION_SCHEMES``; or if ``weight_block_size`` is given
                together with a non-FP8 checkpoint, with a length other than
                2, or with a non-"dynamic" activation scheme.
        """
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        # A missing (None) or empty value is normalized to an empty list.
        self.ignored_layers = ignored_layers or []

        if weight_block_size is not None:
            # Pick the first violated block-wise constraint, if any.
            if not is_checkpoint_fp8_serialized:
                problem = ("The block-wise quantization only supports "
                           "fp8-serialized checkpoint for now.")
            elif len(weight_block_size) != 2:
                problem = (
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions")
            elif activation_scheme != "dynamic":
                problem = ("The block-wise quantization only supports "
                           "dynamic activation scheme for now, but got "
                           f"{activation_scheme} activation scheme.")
            else:
                problem = None
            if problem is not None:
                raise ValueError(problem)
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the name this quantization method is registered under."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes accepted by this config."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability value required (80)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the extra config files to look for (none for FP8)."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Translate ``ignored_layers`` through ``hf_to_vllm_mapper``."""
        if self.ignored_layers is None:
            return
        self.ignored_layers = hf_to_vllm_mapper.apply_list(
            self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a checkpoint quantization-config dict.

        The checkpoint counts as FP8-serialized when ``quant_method``
        contains ``"fp8"``. When ``ignored_layers`` is absent or empty,
        ``modules_to_not_convert`` is used instead.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        serialized = "fp8" in quant_method
        scheme = cls.get_from_keys(config, ["activation_scheme"])
        skipped = cls.get_from_keys_or(config, ["ignored_layers"], None)
        block_size = cls.get_from_keys_or(config, ["weight_block_size"],
                                          None)
        # Fall back to the alternative key name for the skip list.
        skipped = skipped or cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None)
        return cls(
            is_checkpoint_fp8_serialized=serialized,
            activation_scheme=scheme,
            ignored_layers=skipped,
            weight_block_size=block_size,
        )

    def get_xpu_quant_method(self, layer: torch.nn.Module,
                             prefix: str) -> Optional["QuantizeMethodBase"]:
        """Select the XPU (IPEX) quantization method for ``layer``.

        Returns None for layer types this config does not handle.
        """
        from vllm.attention.layer import Attention
        from vllm.model_executor.layers.quantization.ipex_quant import (
            XPUFp8LinearMethod, XPUFp8MoEMethod)
        fp8_config = Fp8Config(
            is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
            activation_scheme=self.activation_scheme,
            ignored_layers=self.ignored_layers,
            weight_block_size=self.weight_block_size)

        if isinstance(layer, LinearBase):
            skip = is_layer_skipped(prefix=prefix,
                                    ignored_layers=self.ignored_layers,
                                    fused_mapping=self.packed_modules_mapping)
            return (UnquantizedLinearMethod()
                    if skip else XPUFp8LinearMethod(fp8_config))
        if isinstance(layer, FusedMoE):
            return XPUFp8MoEMethod(fp8_config, layer)
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Select the quantization method for ``layer``.

        On XPU this defers to ``get_xpu_quant_method``. Linear layers listed
        in ``ignored_layers`` stay unquantized. Returns None for layer types
        this config does not handle.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if current_platform.is_xpu():
            return self.get_xpu_quant_method(layer, prefix)
        if isinstance(layer, LinearBase):
            skip = is_layer_skipped(prefix=prefix,
                                    ignored_layers=self.ignored_layers,
                                    fused_mapping=self.packed_modules_mapping)
            return UnquantizedLinearMethod() if skip else Fp8LinearMethod(
                self)
        if isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self, layer)
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> Optional[str]:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM, or None
        """
        if name.endswith(".output_scale"):
            # (projection marker, checkpoint suffix, vLLM suffix); checked
            # in k, v, q order.
            renames = (
                (".k_proj", ".k_proj.output_scale", ".attn.k_scale"),
                (".v_proj", ".v_proj.output_scale", ".attn.v_scale"),
                (".q_proj", ".q_proj.output_scale", ".attn.q_scale"),
            )
            for marker, old_suffix, new_suffix in renames:
                if marker in name:
                    return name.replace(old_suffix, new_suffix)
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # No known KV-cache scale pattern matched.
        return None

204
205
206

class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading FP16/BF16 checkpoints; their weights are quantized
    to FP8 after loading and activations use dynamic scaling.

    Weight scale granularity:
    - per-tensor: one scale per logical output partition is loaded; on the
      non-Marlin path the shards are requantized to a single (max) scale;
    - block-wise: enabled by ``weight_block_size``; requires an
      FP8-serialized checkpoint and the dynamic activation scheme.

    Activations are quantized per-token when the scheme is dynamic and
    CUTLASS FP8 is supported, otherwise per-tensor. Devices without
    capability 89 (or runs with VLLM_TEST_FORCE_FP8_MARLIN) use the Marlin
    weight-only kernel instead; Marlin is never used on ROCm. Serialized FP8
    weights are float8_e4m3fn and are converted after loading on platforms
    that use e4m3fnuz.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        # Output dtype is fixed to the default dtype at construction time.
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        # AITER is only supported on ROCm, only for FP8_FNUZ, and at the
        # moment only on MI300-series GPUs.
        self.use_aiter_and_is_supported = (current_platform.is_rocm()
                                           and envs.VLLM_ROCM_USE_AITER
                                           and envs.VLLM_ROCM_USE_AITER_LINEAR
                                           and current_platform.is_fp8_fnuz())

        self.block_quant = self.quant_config.weight_block_size is not None

        self.act_q_static = self.quant_config.activation_scheme == "static"
        # Use per-token quantization for better perf if dynamic and cutlass
        if not self.act_q_static and cutlass_fp8_supported():
            self.act_q_group_shape = GroupShape.PER_TOKEN
        else:
            self.act_q_group_shape = GroupShape.PER_TENSOR

        self.fp8_linear = Fp8LinearOp(
            act_quant_static=self.act_q_static,
            act_quant_group_shape=self.act_q_group_shape)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and scale parameters on ``layer``.

        ``weight`` has shape (sum(output_partition_sizes),
        input_size_per_partition). Its dtype is float8_e4m3fn for
        FP8-serialized checkpoints and ``params_dtype`` otherwise.

        Scale parameters are registered only for FP8-serialized checkpoints:
        - ``weight_scale`` holds one float32 scale per logical output
          partition;
        - for block-wise quantization, ``weight_scale_inv`` instead holds a
          ceil(N / block_n) x ceil(K / block_k) grid;
        - ``input_scale`` holds one scale per logical partition for the
          static scheme and is registered as None for the dynamic scheme.

        Raises:
            ValueError: for block-wise quantization, when a tensor-parallel
                or merged-GEMM partition is not a multiple of the block size.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            tp_size = getattr(layer, "tp_size",
                              get_tensor_model_parallel_world_size())
            assert self.quant_config.weight_block_size is not None
            layer.weight_block_size = self.quant_config.weight_block_size
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # Required by row parallel
            if (tp_size > 1
                    and input_size // input_size_per_partition == tp_size
                    and input_size_per_partition % block_k != 0):
                raise ValueError(
                    f"Weight input_size_per_partition = "
                    f"{input_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")
            # Required by column parallel or enabling merged weights
            is_tp_split = (tp_size > 1 and
                           output_size // output_size_per_partition == tp_size)
            is_merged_gemm = len(output_partition_sizes) > 1
            if is_tp_split or is_merged_gemm:
                sizes_to_check = output_partition_sizes
                if not is_tp_split and is_merged_gemm:
                    # In case of merged matrices, we allow the last
                    # matrix to not be a multiple of block size
                    sizes_to_check = output_partition_sizes[:-1]
                for output_partition_size in sizes_to_check:
                    if output_partition_size % block_n != 0:
                        raise ValueError(
                            f"Weight output_partition_size = "
                            f"{output_partition_size} is not divisible by "
                            f"weight quantization block_n = {block_n}.")

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)

        weight = ModelWeightParameter(
            data=torch.empty(output_size_per_partition,
                             input_size_per_partition,
                             dtype=weight_dtype),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # FP8-serialized checkpoints ship their scales, so register them for
        # loading. Otherwise the scales are computed in
        # process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            # Scales are pre-filled with float32 min; presumably a
            # "not loaded yet" sentinel -- TODO confirm.
            if not self.block_quant:
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert self.quant_config.activation_scheme == "dynamic"
                scale = BlockQuantScaleParameter(
                    data=torch.empty(
                        (output_size_per_partition + block_n - 1) // block_n,
                        (input_size_per_partition + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        """Optionally pad ``weight`` on ROCm, returning an equal-shape view.

        Applies only when VLLM_ROCM_FP8_PADDING is set on ROCm and the row
        stride in bytes is a multiple of 512. The last dim is padded by 256
        bytes and then sliced back, so only the underlying row stride
        changes.
        """
        # Pad the weight tensor. This is an optimization on ROCm platform,
        # which can benefit from tensors located far enough from one another
        # in memory.
        if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
                and weight.stride(-1) == 1
                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
            num_pad = 256 // weight.element_size()
            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
            torch.cuda.empty_cache()
        return weight

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded weights into the layout the selected kernel needs.

        First dispatches on the weight format:
        - block-wise;
        - unquantized checkpoint, quantized to FP8 here;
        - FP8-serialized per-tensor.

        Then, as applicable:
        - repacks for Marlin;
        - requantizes block weights/scales in place when DeepGEMM's E8M0
          scale format is in use;
        - transposes block scales for the SM90 CUTLASS kernel.
        """
        # TODO(rob): refactor block quant into separate class.
        # size_k_first tells Marlin the stored dimension order: the
        # per-tensor paths store weight.t() (input dim first), while the
        # block path keeps the loaded (output, input) order.
        if self.block_quant:
            self._process_block_quant_weights(layer)
            size_k_first = False
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            self._quantize_unserialized_weights(layer)
            size_k_first = True
        else:
            self._process_serialized_per_tensor_weights(layer)
            size_k_first = True

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer, size_k_first)
            # Activations not quantized for marlin.
            del layer.input_scale

        # On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
        # requantize the weight and input to the specific scale
        # at the same time.
        if is_deep_gemm_e8m0_used() and self.block_quant:
            assert layer.weight_block_size is not None
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.weight.data,
                layer.weight_scale_inv.data if hasattr(
                    layer, "weight_scale_inv") else layer.weight_scale.data,
                block_sz,
            )

        # SM90 Block FP8 CUTLASS requires row-major weight scales
        if (self.block_quant and current_platform.is_device_capability(90)
                and self.cutlass_block_fp8_supported
                and not should_use_deepgemm_for_fp8_linear(
                    torch.bfloat16, layer.weight)):
            layer.weight_scale_inv = Parameter(
                layer.weight_scale_inv.data.T.contiguous(),
                requires_grad=False)

    def _process_block_quant_weights(self, layer: Module) -> None:
        """Block-wise path: convert to fnuz if needed, pad, re-register."""
        assert self.quant_config.activation_scheme == "dynamic"
        if current_platform.is_fp8_fnuz():
            weight, weight_scale_inv, _ = normalize_e4m3fn_to_e4m3fnuz(
                weight=layer.weight, weight_scale=layer.weight_scale_inv)
        else:
            weight = layer.weight.data
            weight_scale_inv = layer.weight_scale_inv.data

        weight = self._maybe_pad_weight(weight)

        # Torch.compile cannot use Parameter subclasses.
        layer.weight = Parameter(weight, requires_grad=False)
        layer.weight_scale_inv = Parameter(weight_scale_inv,
                                           requires_grad=False)

    def _quantize_unserialized_weights(self, layer: Module) -> None:
        """Quantize a high-precision checkpoint weight to FP8 in place."""
        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)

        # Update the layer with the new values.
        layer.weight = Parameter(qweight.t(), requires_grad=False)
        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
        # layer.input_scale is None indicates dynamic quant and scale is
        # computed from input.
        layer.input_scale = None

    def _process_serialized_per_tensor_weights(self, layer: Module) -> None:
        """FP8-serialized per-tensor path: fold N shard scales for N shards.

        Without Marlin, the logical shards of a fused module are requantized
        to a single max scale so a per-tensor kernel can be used.
        """
        layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                requires_grad=False)
        if self.quant_config.activation_scheme == "static":
            layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                   requires_grad=False)

        weight = layer.weight
        weight_scale = layer.weight_scale

        # If using w8a8, torch._scaled_mm needs per tensor, so
        # requantize the logical shards as a single weight.
        if not self.use_marlin:
            # Dequant -> Quant with max scale so we can run per tensor.
            if current_platform.is_fp8_fnuz():
                weight, weight_scale, input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=weight,
                        weight_scale=weight_scale,
                        input_scale=layer.input_scale)
                if input_scale is not None:
                    layer.input_scale = Parameter(input_scale,
                                                  requires_grad=False)

            weight_scale, weight = requantize_with_max_scale(
                weight=weight,
                weight_scale=weight_scale,
                logical_widths=layer.logical_widths,
            )

        weight = self._maybe_pad_weight(weight)
        # Update layer with new values.
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
        if self.quant_config.activation_scheme == "static":
            # Collapse per-shard input scales to a single max scale.
            layer.input_scale = Parameter(layer.input_scale.max(),
                                          requires_grad=False)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the FP8 linear layer on ``x``, adding ``bias`` if given.

        Dispatches to the first matching kernel:
        1. Marlin (weight-only) when ``use_marlin``;
        2. the block-wise W8A8 op when block quantization is enabled;
        3. ``Fp8LinearOp`` otherwise.
        """
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        if self.block_quant:
            assert self.quant_config.weight_block_size is not None

            return torch.ops.vllm.apply_w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,
                block_size=self.quant_config.weight_block_size,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )

        return self.fp8_linear.apply(input=x,
                                     weight=layer.weight,
                                     weight_scale=layer.weight_scale,
                                     out_dtype=self.out_dtype,
                                     input_scale=layer.input_scale,
                                     bias=bias)
512
513


514
515
516
517
518
519
520
521
522
523
524
525
526
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.

    Supports loading FP8 checkpoints with static weight scales and
    dynamic/static activation scales, either per-tensor or block-wise
    (``quant_config.weight_block_size``).

    Also supports loading FP16/BF16 checkpoints, which are quantized to FP8
    per expert after the weights are loaded; only dynamic activation scaling
    is allowed for those checkpoints.

    Args:
        quant_config: The quantization config.
        layer: The MoE layer this method is attached to; its ``moe_config``
            is forwarded to the base class.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.block_quant = self.quant_config.weight_block_size is not None

        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
        self.fused_experts: Optional[
            mk.FusedMoEModularKernel] = None  # type: ignore
        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            )

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        # Check for DeepGemm support.
        self.allow_deep_gemm = False
        if envs.VLLM_USE_DEEP_GEMM:
            if not has_deep_gemm():
                logger.warning_once("Failed to import DeepGemm kernels.")
            elif not self.block_quant:
                logger.warning_once("Model is not block quantized. Not using "
                                    "DeepGemm kernels")
            elif is_deep_gemm_supported():
                logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
                self.allow_deep_gemm = True
            else:
                logger.warning_once(
                    "DeepGemm not supported on the current platform.")

        # Check for CutlassBlockScaledGroupedGemm support.
        self.allow_cutlass_block_scaled_grouped_gemm = False
        if not self.block_quant:
            logger.debug_once("Model is not block quantized. Not using "
                              "CutlassBlockScaledGroupedGemm kernels")
        elif (current_platform.is_cuda()
              and current_platform.is_device_capability(100)):
            logger.info_once(
                "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
            )
            self.allow_cutlass_block_scaled_grouped_gemm = True
        else:
            logger.warning_once(
                "CutlassBlockScaledGroupedGemm not supported on the current "
                "platform.")

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register the expert weights and their scales on ``layer``.

        Registers ``w13_weight`` with shape (num_experts,
        2 * intermediate_size_per_partition, hidden_size) and ``w2_weight``
        with shape (num_experts, hidden_size, intermediate_size_per_partition).
        It also registers the matching weight scales:
          * per-tensor: ``w13_weight_scale`` (num_experts, 2) and
            ``w2_weight_scale`` (num_experts,);
          * block-wise: ``w13_weight_scale_inv`` / ``w2_weight_scale_inv``,
            with one entry per (block_n, block_k) tile.
        Per-expert input scales are registered only for the static
        activation scheme. Otherwise they are set to None.

        Args:
            layer: Module that receives the parameters.
            num_experts: Number of experts to allocate.
            hidden_size: Model hidden size.
            intermediate_size_per_partition: Expert intermediate size on this
                tensor-parallel rank.
            params_dtype: Weight dtype. It is replaced by
                ``torch.float8_e4m3fn`` for FP8-serialized checkpoints.
            **extra_weight_attrs: Attributes attached via
                ``set_weight_attrs``. ``quant_method`` is added to them.

        Raises:
            ValueError: If the intermediate size is not divisible by the
                quantization block size, or if a static activation scheme is
                used with a checkpoint that is not FP8-serialized.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn
        if self.block_quant:
            assert self.quant_config.weight_block_size is not None
            layer.weight_block_size = self.quant_config.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}.")
            if (tp_size > 1
                    and intermediate_size_per_partition % block_k != 0):
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, 2, dtype=torch.float32),
                                                  requires_grad=False)
            w2_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            # One scale per (block_n x block_k) tile, rounding partial tiles
            # up (ceil division).
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) //
                         block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        scale_kind = (FusedMoeWeightScaleSupported.BLOCK
                      if self.block_quant else
                      FusedMoeWeightScaleSupported.TENSOR)
        extra_weight_attrs.update({"quant_method": scale_kind.value})
        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8.")

            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded weights into the layout the kernels expect.

        Three paths:
          * block-quantized: optional fnuz normalization, or a FlashInfer
            w13 -> w31 swap, then an optional AITER shuffle and DeepGEMM
            scale alignment;
          * FP16/BF16 checkpoint: quantize each expert to FP8 in place;
          * FP8 per-tensor checkpoint: collapse the input scales with a max,
            requantize w13 to a single scale per expert, and optionally
            prepare the weights for AITER or FlashInfer.
        Afterwards, it repacks the weights for Marlin if enabled and
        requantizes block scales to UE8M0 when DeepGEMM E8M0 is in use.

        Sets ``self.rocm_aiter_moe_enabled``.
        """
        # Lazy import to avoid importing triton too early.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            is_rocm_aiter_moe_enabled, shuffle_weights)

        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                # Input scales are None here (dynamic scheme); ignore them.
                w13_weight, w13_weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale_inv,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv,
                        layer.w2_input_scale)
            elif self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(
                    layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data

            # torch.compile() cannot use Parameter subclasses.
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
                                                   requires_grad=False)
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                                  requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

            # DeepGemm scales need to be transposed and aligned. We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
                if _is_col_major(layer.w13_weight_scale_inv):
                    layer.w13_weight_scale_inv = \
                        get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv)
                if _is_col_major(layer.w2_weight_scale_inv):
                    layer.w2_weight_scale_inv = \
                        get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv)

        # If checkpoint is fp16, quantize in place.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = current_platform.fp8_dtype()
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
                layer.local_num_experts,
                dtype=torch.float32,
                device=w13_weight.device),
                                                        requires_grad=False)
            for expert in range(layer.local_num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w13_weight.data[expert, :, :])
                w2_weight[expert, :, :], layer.w2_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w2_weight.data[expert, :, :])
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)
        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if (layer.w13_input_scale is None
                        or layer.w2_input_scale is None):
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None.")
                if (not all_close_1d(layer.w13_input_scale)
                        or not all_close_1d(layer.w2_input_scale)):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer.")
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False)
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False)
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale, w2_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale,
                        layer.w2_input_scale)
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                      requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False)
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                     requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                           requires_grad=False)
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False)

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                # Shard 0 and shard 1 are the two halves of w13, each
                # shard_size rows, with their own loaded scale.
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id])
                    layer.w13_weight[expert_id][
                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                            dq_weight, max_w13_scales[expert_id])
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                        requires_grad=False)

            if self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                assert not self.block_quant
                register_moe_scaling_factors(layer)
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                if self.flashinfer_moe_backend == \
                    FlashinferMoeBackend.TENSORRT_LLM:
                    # Bind w2 explicitly: the local ``w2_weight`` above is
                    # only assigned on fnuz platforms, so relying on it
                    # raised UnboundLocalError here. Write the result back
                    # the same way as w13.
                    w2_weight = layer.w2_weight.data
                    rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
                    layer.w2_weight.data = w2_weight.data
                layer.w13_weight.data = w13_weight.data

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

        if is_deep_gemm_e8m0_used() and self.block_quant:
            assert layer.weight_block_size is not None
            # Re-quantise the expert weights so their scales are UE8M0.
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.w13_weight.data,
                layer.w13_weight_scale_inv.data,
                block_sz,
            )
            requant_weight_ue8m0_inplace(
                layer.w2_weight.data,
                layer.w2_weight_scale_inv.data,
                block_sz,
            )

            # Ensure column-major TMA alignment expected by DeepGEMM.
            if _is_col_major(layer.w13_weight_scale_inv):
                layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w13_weight_scale_inv)
            if _is_col_major(layer.w2_weight_scale_inv):
                layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w2_weight_scale_inv)

    def maybe_make_prepare_finalize(
            self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
        """Return the prepare/finalize object for modular kernels.

        Returns None for ROCm AITER, Marlin, and the FlashInfer TRT-LLM
        backend. Builds the FlashInfer CUTLASS prepare/finalize when that
        backend is selected, and otherwise defers to the base class.
        """
        if (self.rocm_aiter_moe_enabled or self.use_marlin
                or self.flashinfer_moe_backend
                == FlashinferMoeBackend.TENSORRT_LLM):
            return None
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            prepare_finalize = (
                build_flashinfer_fp8_cutlass_moe_prepare_finalize(self.moe))
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        else:
            return super().maybe_make_prepare_finalize()

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Pick the experts implementation for ``prepare_finalize``.

        Uses batched Triton/DeepGEMM experts for the batched activation
        format, FlashInfer CUTLASS experts when that backend is active, and
        Triton/DeepGEMM experts otherwise.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts)

        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
            "Marlin and ROCm AITER are not supported with all2all yet.")

        assert self.moe_quant_config is not None

        if (prepare_finalize.activation_format ==
                FusedMoEActivationFormat.BatchedExperts):
            max_num_tokens_per_rank = (
                prepare_finalize.max_num_tokens_per_rank())
            assert max_num_tokens_per_rank is not None
            logger.debug(
                "BatchedTritonOrDeepGemmExperts(%s): "
                "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                self.__class__.__name__, max_num_tokens_per_rank,
                self.quant_config.weight_block_size, False)
            return BatchedTritonOrDeepGemmExperts(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        else:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__, self.quant_config.weight_block_size,
                False)
            return TritonOrDeepGemmExperts(
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
            )

    def get_fused_moe_quant_config(
            self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
        """Build the FP8 W8A8 quant config, or return None for Marlin.

        Block-quantized layers supply ``*_weight_scale_inv``. Per-tensor
        layers supply ``*_weight_scale``.
        """
        if self.use_marlin:
            return None

        return fp8_w8a8_moe_quant_config(
            w1_scale=(layer.w13_weight_scale_inv
                      if self.block_quant else layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale_inv
                      if self.block_quant else layer.w2_weight_scale),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.quant_config.weight_block_size,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Run the FP8 fused-MoE forward pass.

        The FlashInfer TRT-LLM backend (when no modular kernel is installed)
        takes the router logits directly. All other paths first call
        ``FusedMoE.select_experts`` and then dispatch, in order, to ROCm
        AITER, Marlin, the modular kernel (``self.fused_experts``),
        FlashInfer CUTLASS, or the default ``fused_experts``.
        """
        if enable_eplb:
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None
            assert isinstance(layer, FusedMoE)

        if (self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
                and self.fused_experts is None):
            assert activation == 'silu', (
                f"Expected 'silu' activation but got {activation}")
            assert scoring_func == 'sigmoid', (
                f"Expected 'sigmoid' scoring func but got {scoring_func}")
            if self.block_quant:
                import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

                assert (renormalize and use_grouped_topk
                        and custom_routing_function is None)

                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32),
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.quant_config.weight_block_size,
                    routed_scaling=routed_scaling_factor,
                )
            else:
                assert (not renormalize
                        and custom_routing_function is not None)
                return apply_flashinfer_per_tensor_scale_fp8(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=e_score_correction_bias,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    apply_router_weight_on_input=apply_router_weight_on_input)

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
        )

        #
        # Note: the order of checks is important since self.fused_experts
        # can override fused_experts or cutlass but not rocm or marlin.
        #
        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts)
            assert self.fused_experts is None
            return rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                quant_config=self.moe_quant_config)
        elif self.use_marlin:
            assert activation == "silu", (
                f"{activation} not supported for Marlin MoE.")
            assert self.fused_experts is None
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                workspace=layer.workspace)
        elif self.fused_experts:
            return self.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            # block_quant is a bool, so the previous `is None` check could
            # never pass; this path only supports per-tensor scales.
            assert not self.block_quant
            assert (not renormalize and custom_routing_function is not None)
            assert activation == 'silu', (
                f"Expected 'silu' activation but got {activation}")
            assert scoring_func == 'sigmoid', (
                f"Expected 'sigmoid' scoring func but got {scoring_func}")

            return flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            from vllm.model_executor.layers.fused_moe import fused_experts
            return fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                quant_config=self.moe_quant_config,
                allow_deep_gemm=self.allow_deep_gemm,
                allow_cutlass_block_scaled_grouped_gemm=(
                    self.allow_cutlass_block_scaled_grouped_gemm))
1173
1174


1175
1176
1177
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1178
1179
1180
    """

    def __init__(self, quant_config: Fp8Config):
1181
        super().__init__(quant_config)