fp8.py 49.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

bnellnm's avatar
bnellnm committed
4
from typing import TYPE_CHECKING, Any, Callable, Optional
5
6

import torch
7
import torch.nn.functional as F
8
9
10
from torch.nn import Module
from torch.nn.parameter import Parameter

11
import vllm.envs as envs
12
from vllm import _custom_ops as ops
13
from vllm.distributed import get_tensor_model_parallel_world_size
14
from vllm.logger import init_logger
bnellnm's avatar
bnellnm committed
15
from vllm.model_executor.layers.fused_moe import (
16
17
18
    FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
    FusedMoeWeightScaleSupported)
19
20
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
21
from vllm.model_executor.layers.quantization import QuantizationMethods
22
from vllm.model_executor.layers.quantization.base_config import (
23
    QuantizationConfig, QuantizeMethodBase)
24
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
25
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
26
27
    apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31)
28
29
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
30
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
31
32
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
    prepare_moe_fp8_layer_for_marlin)
33
from vllm.model_executor.layers.quantization.utils.quant_utils import (
34
    GroupShape, is_layer_skipped)
35
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
36
37
38
39
    Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
    cutlass_fp8_supported, maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
    requantize_with_max_scale)
40
41
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ModelWeightParameter,
42
                                           PerTensorScaleParameter)
43
from vllm.model_executor.utils import set_weight_attrs
44
from vllm.platforms import current_platform
45
from vllm.scalar_type import scalar_types
46
from vllm.utils import has_deep_gemm
47
48
from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
                                  is_deep_gemm_supported)
49
from vllm.utils.flashinfer import has_flashinfer_moe
50

51
52
53
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

54
55
56
57
logger = init_logger(__name__)

# Accepted values for Fp8Config's ``activation_scheme`` argument.
ACTIVATION_SCHEMES = ["static", "dynamic"]

58
59
60
61
62
63

def _is_col_major(x: torch.Tensor) -> bool:
    assert x.dim() == 3
    b, m, n = x.shape
    return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m

64

65
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Records:
    - whether the checkpoint is already FP8-serialized;
    - the activation scaling scheme (one of ``ACTIVATION_SCHEMES``);
    - the layers that must stay unquantized;
    - an optional 2-D weight block size used for block-wise quantization.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[list[str]] = None,
        weight_block_size: Optional[list[int]] = None,
    ) -> None:
        """Validate and store the FP8 quantization settings.

        Raises:
            ValueError: if ``activation_scheme`` is not supported, or if
                ``weight_block_size`` is given and any of these hold:
                - the checkpoint is not FP8-serialized;
                - the block size does not have exactly 2 entries;
                - the activation scheme is not dynamic.
        """
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        # A missing list is normalized to an empty one.
        self.ignored_layers = ignored_layers or []

        if weight_block_size is not None:
            # Block-wise quantization is limited to serialized FP8
            # checkpoints, 2-D blocks and dynamic activation scales.
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now.")
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions")
            if activation_scheme != "dynamic":
                raise ValueError("The block-wise quantization only supports "
                                 "dynamic activation scheme for now, but got "
                                 f"{activation_scheme} activation scheme.")
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the name this quantization method is registered under."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes accepted by this method."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (80)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return extra config filenames to look for; none are needed."""
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Translate ``ignored_layers`` names with the given weights mapper."""
        if self.ignored_layers is None:
            return
        self.ignored_layers = hf_to_vllm_mapper.apply_list(
            self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Create an ``Fp8Config`` from a quantization config dict.

        The checkpoint is treated as FP8-serialized when its
        ``quant_method`` contains ``"fp8"``. If ``ignored_layers`` is absent
        or empty, ``modules_to_not_convert`` is used instead.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        fp8_serialized = "fp8" in quant_method
        scheme = cls.get_from_keys(config, ["activation_scheme"])
        skipped = cls.get_from_keys_or(config, ["ignored_layers"], None)
        block_size = cls.get_from_keys_or(config, ["weight_block_size"],
                                          None)
        if not skipped:
            skipped = cls.get_from_keys_or(config,
                                           ["modules_to_not_convert"], None)
        return cls(
            is_checkpoint_fp8_serialized=fp8_serialized,
            activation_scheme=scheme,
            ignored_layers=skipped,
            weight_block_size=block_size,
        )

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Select the FP8 quantize method for ``layer``.

        Returns:
        - for linear layers matched by ``ignored_layers``: an
          ``UnquantizedLinearMethod``;
        - for other linear layers: an ``Fp8LinearMethod``;
        - for MoE layers: an ``Fp8MoEMethod``;
        - for attention layers: an ``Fp8KVCacheMethod``;
        - for anything else: ``None``.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            skip = is_layer_skipped(prefix=prefix,
                                    ignored_layers=self.ignored_layers,
                                    fused_mapping=self.packed_modules_mapping)
            if skip:
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        if isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self, layer.moe_config)
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> Optional[str]:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM, or None if
            the name matches none of the known scale formats
        """
        if name.endswith(".output_scale"):
            # Checked in k, v, q order; the first projection found wins.
            for proj, scale_suffix, vllm_suffix in (
                (".k_proj", ".k_proj.output_scale", ".attn.k_scale"),
                (".v_proj", ".v_proj.output_scale", ".attn.v_scale"),
                (".q_proj", ".q_proj.output_scale", ".attn.q_scale"),
            ):
                if proj in name:
                    return name.replace(scale_suffix, vllm_suffix)
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

173
174
175

class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading unquantized FP16/BF16 model checkpoints with
    dynamic activation scaling. Their weights are quantized to FP8, and the
    weight scaling factor computed, after the model weights are loaded.

    Limitations:
    1. Non-block-quantized weights are requantized to a single per-tensor
       scale, because torch._scaled_mm needs per-tensor weight scales.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        # AITER is only supported on ROCm and only for FP8_FNUZ
        # and at the moment are MI300 series
        self.use_aiter_and_is_supported = (current_platform.is_rocm()
                                           and envs.VLLM_ROCM_USE_AITER
                                           and envs.VLLM_ROCM_USE_AITER_LINEAR
                                           and current_platform.is_fp8_fnuz())

        self.block_quant = self.quant_config.weight_block_size is not None

        self.act_q_static = self.quant_config.activation_scheme == "static"
        # Use per-token quantization for better perf if dynamic and cutlass
        if not self.act_q_static and cutlass_fp8_supported():
            self.act_q_group_shape = GroupShape.PER_TOKEN
        else:
            self.act_q_group_shape = GroupShape.PER_TENSOR

        self.fp8_linear = Fp8LinearOp(
            act_quant_static=self.act_q_static,
            act_quant_group_shape=self.act_q_group_shape,
            cutlass_fp8_supported=cutlass_fp8_supported())

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and its scale parameters on ``layer``.

        The weight has shape (sum(output_partition_sizes),
        input_size_per_partition).
        - FP8-serialized checkpoints: the weight is float8_e4m3fn. A weight
          scale is registered, either per-tensor as ``weight_scale`` or
          per-block as ``weight_scale_inv``. For the static activation
          scheme an ``input_scale`` is registered too.
        - Other checkpoints: the weight uses ``params_dtype`` and scales are
          produced later, in ``process_weights_after_loading``.

        Raises:
            ValueError: under block quantization, if the partitioned
                input/output sizes are not divisible by the block size.
        """
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            tp_size = get_tensor_model_parallel_world_size()
            assert self.quant_config.weight_block_size is not None
            layer.weight_block_size = self.quant_config.weight_block_size
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # Required by row parallel
            if (tp_size > 1
                    and input_size // input_size_per_partition == tp_size
                    and input_size_per_partition % block_k != 0):
                raise ValueError(
                    f"Weight input_size_per_partition = "
                    f"{input_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")
            # Required by column parallel or enabling merged weights
            is_tp_split = (tp_size > 1 and
                           output_size // output_size_per_partition == tp_size)
            is_merged_gemm = len(output_partition_sizes) > 1
            if is_tp_split or is_merged_gemm:
                sizes_to_check = output_partition_sizes
                if not is_tp_split and is_merged_gemm:
                    # In case of merged matrices, we allow the last
                    # matrix to not be a multiple of block size
                    sizes_to_check = output_partition_sizes[:-1]
                for output_partition_size in sizes_to_check:
                    if output_partition_size % block_n != 0:
                        raise ValueError(
                            f"Weight output_partition_size = "
                            f"{output_partition_size} is not divisible by "
                            f"weight quantization block_n = {block_n}.")

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)

        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                # One scale per logical shard of a fused module.
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert self.quant_config.activation_scheme == "dynamic"
                # One scale per (block_n x block_k) tile, rounding up.
                scale = BlockQuantScaleParameter(
                    data=torch.empty(
                        (output_size_per_partition + block_n - 1) // block_n,
                        (input_size_per_partition + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(data=torch.empty(
                    len(output_partition_sizes), dtype=torch.float32),
                                                weight_loader=weight_loader)

                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        """Optionally over-allocate ``weight`` along its last dim (ROCm only).

        Padding applies only when all of these hold:
        - VLLM_ROCM_FP8_PADDING is set and the platform is ROCm;
        - the last dim is contiguous;
        - the row stride in bytes is a multiple of 512.

        The tensor is then padded by 256 bytes per row and sliced back. The
        logical shape is unchanged but the row stride grows.
        """
        # Pad the weight tensor. This is an optimization on ROCm platform, which
        # can benefit from tensors located far enough from one another in memory
        if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
                and weight.stride(-1) == 1
                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
            num_pad = 256 // weight.element_size()
            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
            torch.cuda.empty_cache()
        return weight

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded parameters into the form ``apply`` expects.

        There are three cases:
        - Block-quantized: the weight keeps its (output, input) layout and
          its ``weight_scale_inv``.
        - Not FP8-serialized: the weight is quantized here with a dynamic
          per-tensor scale.
        - FP8-serialized, per-tensor: the per-shard scales are folded into a
          single scale unless Marlin is used.

        In the last two cases the weight is transposed to (input, output).
        Marlin layers are then repacked, and block-quantized layers may be
        requantized for DeepGEMM E8M0.
        """
        size_k_first = True
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            # The block-quant weight is not transposed below, so K is not
            # the leading dim.
            size_k_first = False
            if current_platform.is_fp8_fnuz():
                weight, weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=layer.weight,
                        weight_scale=layer.weight_scale_inv)
            else:
                weight = layer.weight.data
                weight_scale_inv = layer.weight_scale_inv.data

            weight = self._maybe_pad_weight(weight)

            # Torch.compile cannot use Parameter subclasses.
            layer.weight = Parameter(weight, requires_grad=False)
            layer.weight_scale_inv = Parameter(weight_scale_inv,
                                               requires_grad=False)

        # If checkpoint not serialized fp8, quantize the weights.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None

        # If checkpoint is fp8, handle that there are N scales for N
        # shards in a fused module
        else:
            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                    requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                       requires_grad=False)

            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                # Dequant -> Quant with max scale so we can run per tensor.
                if current_platform.is_fp8_fnuz():
                    weight, weight_scale, input_scale = \
                        normalize_e4m3fn_to_e4m3fnuz(
                            weight=weight,
                            weight_scale=weight_scale,
                            input_scale=layer.input_scale)
                    if input_scale is not None:
                        layer.input_scale = Parameter(input_scale,
                                                      requires_grad=False)

                weight_scale, weight = requantize_with_max_scale(
                    weight=weight,
                    weight_scale=weight_scale,
                    logical_widths=layer.logical_widths,
                )

            weight = self._maybe_pad_weight(weight)
            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                # Collapse the per-shard input scales to their maximum.
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer, size_k_first)
            # Activations not quantized for marlin.
            del layer.input_scale

        # On B200, if E8M0 for DeepGemm is used, the weight and its block
        # scales must be requantized to the E8M0 scale format together.
        # This needs a weight block size, so it only applies to
        # block-quantized layers. Per-tensor layers have
        # weight_block_size=None and previously failed the assert below.
        if self.block_quant and is_blackwell_deep_gemm_e8m0_used():
            assert layer.weight_block_size is not None
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.weight.data,
                layer.weight_scale_inv.data if hasattr(
                    layer, "weight_scale_inv") else layer.weight_scale.data,
                block_sz,
            )

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the FP8 linear layer on ``x``.

        The kernel is chosen in this order:
        - Marlin weight-only path when ``use_marlin`` is set;
        - block-wise FP8 GEMM for block-quantized weights;
        - ``Fp8LinearOp`` otherwise.
        """
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        if self.block_quant:
            assert self.quant_config.weight_block_size is not None

            return torch.ops.vllm.apply_w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,
                block_size=self.quant_config.weight_block_size,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )

        return self.fp8_linear.apply(input=x,
                                     weight=layer.weight,
                                     weight_scale=layer.weight_scale,
                                     out_dtype=self.out_dtype,
                                     input_scale=layer.input_scale,
                                     bias=bias)
470
471


472
473
474
475
476
477
478
479
480
481
482
483
484
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

485
486
    def __init__(self, quant_config: Fp8Config, moe: FusedMoEConfig):
        """Store the config and decide which MoE kernels may be used.

        The kernel flags are ``flashinfer_moe_enabled``, ``use_marlin``,
        ``allow_deep_gemm`` and ``allow_cutlass_block_scaled_grouped_gemm``.
        They are derived from environment flags, the current platform and
        whether the checkpoint is block-quantized.
        """
        super().__init__(moe)
        self.quant_config = quant_config
        self.block_quant = quant_config.weight_block_size is not None

        # FlashInfer MoE: opt-in through the env flag, and the package must
        # be available.
        want_flashinfer = (envs.VLLM_USE_FLASHINFER_MOE_FP8
                           and has_flashinfer_moe())
        if want_flashinfer:
            logger.info_once(
                "Using FlashInfer MoE FP8 kernels for Fp8MoEMethod.")
        self.flashinfer_moe_enabled = bool(want_flashinfer)

        # Marlin gives weight-only FP8 on GPUs without FP8 hardware
        # (capability < 89), or when forced for testing. Never used on ROCm.
        lacks_fp8_hw = not current_platform.has_device_capability(89)
        self.use_marlin = lacks_fp8_hw or envs.VLLM_TEST_FORCE_FP8_MARLIN
        if current_platform.is_rocm():
            self.use_marlin = False

        # DeepGemm: opt-in through VLLM_USE_DEEP_GEMM. It additionally needs
        # the package, a block-quantized checkpoint and a supported device.
        self.allow_deep_gemm = False
        if envs.VLLM_USE_DEEP_GEMM:
            if not has_deep_gemm():
                logger.warning_once("Failed to import DeepGemm kernels.")
            elif not self.block_quant:
                logger.warning_once("Model is not block quantized. Not using "
                                    "DeepGemm kernels")
            elif not is_deep_gemm_supported():
                logger.warning_once(
                    "DeepGemm not supported on the current platform.")
            else:
                logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
                self.allow_deep_gemm = True

        # CutlassBlockScaledGroupedGemm: block-quantized models on CUDA
        # devices with capability 100 only.
        self.allow_cutlass_block_scaled_grouped_gemm = False
        if self.block_quant:
            if (current_platform.is_cuda()
                    and current_platform.is_device_capability(100)):
                logger.info_once(
                    "Using CutlassBlockScaledGroupedGemm kernels for "
                    "Fp8MoEMethod.")
                self.allow_cutlass_block_scaled_grouped_gemm = True
            else:
                logger.warning_once(
                    "CutlassBlockScaledGroupedGemm not supported on the "
                    "current platform.")
        else:
            logger.debug_once("Model is not block quantized. Not using "
                              "CutlassBlockScaledGroupedGemm kernels")

534
    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register the expert weights, weight scales and input scales.

        Shapes:
        - ``w13_weight``: (num_experts, 2 * intermediate_size_per_partition,
          hidden_size);
        - ``w2_weight``: (num_experts, hidden_size,
          intermediate_size_per_partition).

        Weights are float8_e4m3fn for FP8-serialized checkpoints and
        ``params_dtype`` otherwise. Weight scales are:
        - without block quantization: ``w13_weight_scale`` (num_experts, 2)
          and ``w2_weight_scale`` (num_experts,);
        - with block quantization: per-block ``*_weight_scale_inv`` tensors.

        Raises:
            ValueError: if the block size does not divide
                ``intermediate_size_per_partition`` as required, or if a
                static activation scheme is combined with a checkpoint that
                is not FP8-serialized.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        fp8_serialized = self.quant_config.is_checkpoint_fp8_serialized
        weight_dtype = torch.float8_e4m3fn if fp8_serialized else params_dtype

        def ceil_div(numer: int, denom: int) -> int:
            return (numer + denom - 1) // denom

        if self.block_quant:
            assert self.quant_config.weight_block_size is not None
            layer.weight_block_size = self.quant_config.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n = self.quant_config.weight_block_size[0]
            block_k = self.quant_config.weight_block_size[1]
            # NOTE: To keep the block-wise scales aligned, the output size of
            # both the gate and up weights must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}.")
            if (tp_size > 1
                    and intermediate_size_per_partition % block_k != 0):
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(num_experts,
                        2 * intermediate_size_per_partition,
                        hidden_size,
                        dtype=weight_dtype),
            requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(num_experts,
                        hidden_size,
                        intermediate_size_per_partition,
                        dtype=weight_dtype),
            requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if self.block_quant:
            # One scale per (block_n x block_k) tile, rounding up.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ceil_div(intermediate_size_per_partition, block_n),
                    ceil_div(hidden_size, block_k),
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    ceil_div(hidden_size, block_n),
                    ceil_div(intermediate_size_per_partition, block_k),
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"
        else:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, 2, dtype=torch.float32),
                requires_grad=False)
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32),
                requires_grad=False)
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        scale_layout = (FusedMoeWeightScaleSupported.BLOCK
                        if self.block_quant else
                        FusedMoeWeightScaleSupported.TENSOR)
        extra_weight_attrs.update({"quant_method": scale_layout.value})

        # Only an FP8 checkpoint carries weight scales to load. For an fp16
        # checkpoint they are produced in process_weights_after_loading().
        if fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme != "static":
            layer.w13_input_scale = None
            layer.w2_input_scale = None
            return

        if not fp8_serialized:
            raise ValueError(
                "Found static activation scheme for checkpoint that "
                "was not serialized fp8.")

        for scale_name in ("w13_input_scale", "w2_input_scale"):
            input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32),
                requires_grad=False)
            layer.register_parameter(scale_name, input_scale)
            set_weight_attrs(input_scale, extra_weight_attrs)
661
662

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded MoE weights into the form the kernels consume.

        Exactly one of three paths handles the loaded tensors:

        * block-quantized weights (``self.block_quant``): optionally
          normalized to e4m3fnuz, w1/w3 halves swapped for FlashInfer,
          shuffled for ROCm AITER and scale-aligned for DeepGEMM;
        * non-FP8 checkpoints (``is_checkpoint_fp8_serialized`` is False):
          every expert of w13 and w2 is quantized to FP8 here, one scale per
          expert and per weight;
        * per-tensor FP8 checkpoints: static input scales are collapsed to
          their max, and the two w13 shard scales of every expert are merged
          into one by dequantizing and requantizing with their max. When
          FlashInfer is enabled, the per-tensor FlashInfer preparation
          (scaling factors, w1/w3 swap, weight rotation) happens here.

        Afterwards the layer may additionally be repacked for Marlin and/or
        requantized to UE8M0 scales for Blackwell DeepGEMM.

        Args:
            layer: the MoE module whose parameters are replaced in place.
        """
        # Lazy import to avoid importing triton too early.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            is_rocm_aiter_moe_enabled, shuffle_weights)

        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

        def _shuffle_for_rocm_aiter() -> None:
            # Reshaping the weights is required by the AITER MoE kernel.
            shuffled_w13, shuffled_w2 = shuffle_weights(
                layer.w13_weight.data, layer.w2_weight.data)
            layer.w13_weight = Parameter(shuffled_w13, requires_grad=False)
            layer.w2_weight = Parameter(shuffled_w2, requires_grad=False)

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                w13_weight, w13_weight_scale_inv, w13_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale_inv,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale_inv, w2_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv,
                        layer.w2_input_scale)
            elif self.flashinfer_moe_enabled:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm.
                # Per-tensor FlashInfer preparation is done in the
                # per-tensor FP8 branch further down, not here.
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(
                    layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data

            # torch.compile() cannot use Parameter subclasses.
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
                                                   requires_grad=False)
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                                  requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                _shuffle_for_rocm_aiter()

            # DeepGemm scales need to be transposed and aligned. We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm and not is_blackwell_deep_gemm_e8m0_used():
                if _is_col_major(layer.w13_weight_scale_inv):
                    layer.w13_weight_scale_inv = \
                        get_col_major_tma_aligned_tensor(
                            layer.w13_weight_scale_inv).contiguous()
                if _is_col_major(layer.w2_weight_scale_inv):
                    layer.w2_weight_scale_inv = \
                        get_col_major_tma_aligned_tensor(
                            layer.w2_weight_scale_inv).contiguous()

        # If checkpoint is fp16, quantize in place.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = current_platform.fp8_dtype()
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = Parameter(
                torch.ones(layer.local_num_experts,
                           dtype=torch.float32,
                           device=w13_weight.device),
                requires_grad=False)
            for expert in range(layer.local_num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w13_weight.data[expert, :, :])
                w2_weight[expert, :, :], layer.w2_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w2_weight.data[expert, :, :])
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                _shuffle_for_rocm_aiter()

        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if (layer.w13_input_scale is None
                        or layer.w2_input_scale is None):
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None.")
                if (not all_close_1d(layer.w13_input_scale)
                        or not all_close_1d(layer.w2_input_scale)):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer.")
                layer.w13_input_scale = Parameter(
                    layer.w13_input_scale.max(), requires_grad=False)
                layer.w2_input_scale = Parameter(
                    layer.w2_input_scale.max(), requires_grad=False)
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale, w2_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale,
                        layer.w2_input_scale)
                # Reset the parameter
                layer.w13_weight = Parameter(w13_weight, requires_grad=False)
                layer.w13_weight_scale = Parameter(w13_weight_scale,
                                                   requires_grad=False)
                if w13_input_scale is not None:
                    layer.w13_input_scale = Parameter(w13_input_scale,
                                                      requires_grad=False)
                layer.w2_weight = Parameter(w2_weight, requires_grad=False)
                layer.w2_weight_scale = Parameter(w2_weight_scale,
                                                  requires_grad=False)
                if w2_input_scale is not None:
                    layer.w2_input_scale = Parameter(w2_input_scale,
                                                     requires_grad=False)

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id])
                    layer.w13_weight[expert_id][
                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                            dq_weight, max_w13_scales[expert_id])
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                _shuffle_for_rocm_aiter()

            layer.w13_weight_scale = Parameter(max_w13_scales,
                                               requires_grad=False)

            if self.flashinfer_moe_enabled:
                # Per-tensor FlashInfer preparation. It runs after w13 has a
                # single scale per expert, so the factors registered below
                # come from the final scales (assumes
                # register_moe_scaling_factors reads the layer's *_scale
                # parameters — confirm). This code used to sit in the
                # block-quant branch under an unreachable
                # `if not self.block_quant` guard.
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm.
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w2_weight = layer.w2_weight.data
                register_moe_scaling_factors(layer)
                rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
                layer.w13_weight = Parameter(w13_weight, requires_grad=False)
                layer.w2_weight = Parameter(w2_weight, requires_grad=False)

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

        if is_blackwell_deep_gemm_e8m0_used():
            assert layer.weight_block_size is not None
            # Re-quantise the expert weights so their scales are UE8M0.
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.w13_weight.data,
                layer.w13_weight_scale_inv.data,
                block_sz,
            )
            requant_weight_ue8m0_inplace(
                layer.w2_weight.data,
                layer.w2_weight_scale_inv.data,
                block_sz,
            )

            # Ensure column-major TMA alignment expected by DeepGEMM.
            if _is_col_major(layer.w13_weight_scale_inv):
                layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w13_weight_scale_inv).contiguous()
            if _is_col_major(layer.w2_weight_scale_inv):
                layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w2_weight_scale_inv).contiguous()

bnellnm's avatar
bnellnm committed
866
867
868
869
870
    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        moe: FusedMoEConfig,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Return the expert-compute implementation for ``prepare_finalize``.

        A batched activation format gets ``BatchedTritonOrDeepGemmExperts``,
        sized by the dispatcher's per-rank token cap. Any other format gets
        ``TritonOrDeepGemmExperts``. Both are built for FP8 W8A8 with this
        config's weight block shape and DeepGEMM setting. ``moe`` is not
        read here.

        Raises:
            AssertionError: if Marlin or ROCm AITER is active, or the batched
                dispatcher reports no per-rank token cap.
        """
        from vllm.model_executor.layers.fused_moe import (
            BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts)

        assert not (self.use_marlin or self.rocm_aiter_moe_enabled), (
            "Marlin and ROCm AITER are not supported with all2all yet.")

        block_shape = self.quant_config.weight_block_size
        uses_batched_format = (prepare_finalize.activation_format ==
                               FusedMoEActivationFormat.BatchedExperts)

        if not uses_batched_format:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__, block_shape, False)
            return TritonOrDeepGemmExperts(
                use_fp8_w8a8=True,
                block_shape=block_shape,
                allow_deep_gemm=self.allow_deep_gemm,
            )

        per_rank_token_cap = prepare_finalize.max_num_tokens_per_rank()
        assert per_rank_token_cap is not None
        logger.debug(
            "BatchedTritonOrDeepGemmExperts(%s): "
            "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
            self.__class__.__name__, per_rank_token_cap, block_shape, False)
        return BatchedTritonOrDeepGemmExperts(
            max_num_tokens=per_rank_token_cap,
            num_dispatchers=prepare_finalize.num_dispatchers(),
            use_fp8_w8a8=True,
            block_shape=block_shape,
            per_act_token_quant=False,
            allow_deep_gemm=self.allow_deep_gemm,
        )

906
907
908
909
910
911
912
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the FP8 fused-MoE computation for ``x`` on ``layer``.

        Unless FlashInfer is enabled, experts are chosen up front with
        ``FusedMoE.select_experts``. The FlashInfer paths receive the raw
        ``router_logits`` instead.

        The first matching backend handles the call: ROCm AITER, Marlin,
        FlashInfer (block-scale or per-tensor), ``self.fused_experts`` when
        set, and otherwise the generic ``fused_experts`` function.

        Returns:
            The tensor produced by the selected backend.
        """
        if enable_eplb:
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None
            assert isinstance(layer, FusedMoE)

        # FlashInfer routes from the logits itself; every other backend below
        # consumes precomputed top-k weights and ids.
        if not self.flashinfer_moe_enabled:
            topk_weights, topk_ids = FusedMoE.select_experts(
                hidden_states=x,
                router_logits=router_logits,
                use_grouped_topk=use_grouped_topk,
                top_k=top_k,
                renormalize=renormalize,
                topk_group=topk_group,
                num_expert_group=num_expert_group,
                custom_routing_function=custom_routing_function,
                scoring_func=scoring_func,
                e_score_correction_bias=e_score_correction_bias,
                indices_type=self.topk_indices_dtype,
                enable_eplb=enable_eplb,
                expert_map=expert_map,
                expert_load_view=expert_load_view,
                logical_to_physical_map=logical_to_physical_map,
                logical_replica_count=logical_replica_count,
            )

        def _weight_scales():
            # Block-quantized layers keep their weight scales under
            # *_weight_scale_inv; per-tensor layers use *_weight_scale.
            # Called lazily so only the backends that need them read them.
            if self.block_quant:
                return layer.w13_weight_scale_inv, layer.w2_weight_scale_inv
            return layer.w13_weight_scale, layer.w2_weight_scale

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts)
            w13_scale, w2_scale = _weight_scales()
            return rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                use_fp8_w8a8=True,
                apply_router_weight_on_input=apply_router_weight_on_input,
                w1_scale=w13_scale,
                w2_scale=w2_scale,
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
                block_shape=self.quant_config.weight_block_size,
                expert_map=expert_map)

        if self.use_marlin:
            assert activation == "silu", (
                f"{activation} not supported for Marlin MoE.")
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map)

        if self.flashinfer_moe_enabled:
            assert activation == 'silu'
            assert scoring_func == 'sigmoid'
            if self.block_quant:
                assert (renormalize and use_grouped_topk
                        and custom_routing_function is None)

                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32),
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.quant_config.weight_block_size,
                    routed_scaling=1.0,
                )
            assert (not renormalize and custom_routing_function is not None)
            return apply_flashinfer_per_tensor_scale_fp8(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=e_score_correction_bias,
                global_num_experts=global_num_experts,
                top_k=top_k,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                apply_router_weight_on_input=apply_router_weight_on_input)

        if self.fused_experts is not None:
            w13_scale, w2_scale = _weight_scales()
            return self.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                w1_scale=w13_scale,
                w2_scale=w2_scale,
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
            )

        from vllm.model_executor.layers.fused_moe import fused_experts
        w13_scale, w2_scale = _weight_scales()
        return fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            global_num_experts=global_num_experts,
            apply_router_weight_on_input=apply_router_weight_on_input,
            expert_map=expert_map,
            w1_scale=w13_scale,
            w2_scale=w2_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            use_fp8_w8a8=True,
            block_shape=self.quant_config.weight_block_size,
            allow_deep_gemm=self.allow_deep_gemm,
            allow_cutlass_block_scaled_grouped_gemm=(
                self.allow_cutlass_block_scaled_grouped_gemm))
1072
1073


1074
1075
1076
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
1077
1078
1079
    """

    def __init__(self, quant_config: Fp8Config):
1080
        super().__init__(quant_config)