# fp8.py: FP8 quantization config plus linear, fused-MoE and KV-cache methods.
1
from typing import Any, Callable, Dict, List, Optional
2
3
4
5
6

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

7
import vllm.envs as envs
8
from vllm import _custom_ops as ops
9
from vllm.logger import init_logger
10
11
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
12
13
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
14
from vllm.model_executor.layers.quantization.base_config import (
15
    QuantizationConfig, QuantizeMethodBase)
16
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
17
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
18
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
19
20
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)
21
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
22
    all_close_1d, apply_fp8_linear, convert_to_channelwise,
23
    cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
24
    requantize_with_max_scale)
25
26
from vllm.model_executor.parameter import (ModelWeightParameter,
                                           PerTensorScaleParameter)
27
from vllm.model_executor.utils import set_weight_attrs
28
from vllm.platforms import current_platform
29
from vllm.utils import is_hip, print_warning_once
# Activation-quantization schemes accepted by Fp8Config:
#   "static"  - per-tensor input (activation) scales are loaded from the
#               checkpoint; only valid for fp8-serialized checkpoints.
#   "dynamic" - no input scale is loaded from the checkpoint
#               (``input_scale`` is left as None).
ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)


class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Args:
        is_checkpoint_fp8_serialized: True when the checkpoint already stores
            FP8 weights together with their scales. False when the checkpoint
            holds FP16/BF16 weights that are quantized to FP8 after loading.
        activation_scheme: One of ``ACTIVATION_SCHEMES`` ("static" or
            "dynamic").
        ignored_layers: Layer names that must stay unquantized. They are
            matched against a layer's prefix via ``is_layer_skipped``.
            Defaults to an empty list.

    Raises:
        ValueError: If ``activation_scheme`` is not in ``ACTIVATION_SCHEMES``.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[List[str]] = None,
    ) -> None:
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
                           "format is experimental and subject to change.")
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        # Normalize None to an empty list so membership checks never fail.
        self.ignored_layers = ignored_layers or []

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum GPU compute capability (as major*10+minor).

        Devices below capability 89 are still served, via the FP8 Marlin
        weight-only path selected in ``Fp8LinearMethod.__init__``.
        """
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """FP8 needs no extra config files beyond the model config."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        """Build an Fp8Config from a checkpoint's quantization config dict.

        Reads ``quant_method`` (required), ``activation_scheme`` (required)
        and ``ignored_layers`` (optional).
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        # Any quant_method mentioning "fp8" means the weights on disk are
        # already FP8.
        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                   activation_scheme=activation_scheme,
                   ignored_layers=ignored_layers)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Returns:
            - ``UnquantizedLinearMethod`` for linear layers listed in
              ``ignored_layers``.
            - ``Fp8LinearMethod`` for other linear layers.
            - ``Fp8MoEMethod`` for ``FusedMoE`` layers.
            - ``Fp8KVCacheMethod`` for ``Attention`` layers.
            - ``None`` for everything else.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """FP8 has no activation layers that need post-scaling."""
        return []


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading unquantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weights are quantized to FP8 and the weight
    scaling factor is computed after the model weights are loaded.

    On devices with compute capability below 89, or when
    ``VLLM_TEST_FORCE_FP8_MARLIN`` is set, the FP8 Marlin kernel is used for
    weight-only (W8A16) FP8; activations are then not quantized. Marlin is
    never used on ROCm.

    Limitations:
    1. Only supports per-tensor quantization due to torch._scaled_mm support.
    2. Only supports the float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        capability = current_platform.get_device_capability()
        capability = capability[0] * 10 + capability[1]
        self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN
        # Disable marlin for rocm
        if is_hip():
            self.use_marlin = False

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the (possibly FP8) weight and its scale parameters.

        The weight is registered with shape
        ``(sum(output_partition_sizes), input_size_per_partition)``. Its dtype
        is float8_e4m3fn for fp8-serialized checkpoints and ``params_dtype``
        otherwise.

        For fp8-serialized checkpoints, the method also registers
        ``weight_scale``, with one float32 entry per logical output
        partition. When the activation scheme is "static" it registers a
        matching ``input_scale``; otherwise ``input_scale`` is registered as
        None. For non-serialized checkpoints no scales are registered here;
        they are produced in ``process_weights_after_loading``.
        """
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Remember the per-shard output widths of fused modules (e.g. QKV);
        # they are needed to requantize/expand scales per shard later.
        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)

        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                                            weight_loader=weight_loader)

            # NOTE(review): pre-filled with the most negative float32,
            # presumably so a shard whose scale was never loaded is easy to
            # spot and never wins a later max(); confirm intent.
            scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("weight_scale", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(data=torch.empty(
                    len(output_partition_sizes), dtype=torch.float32),
                                                weight_loader=weight_loader)

                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded weight and scales into the kernel's layout.

        Non-serialized checkpoint:
            The weight is quantized to FP8 with a single per-tensor scale.
            On the Marlin path that scale is expanded to channelwise.

        FP8-serialized checkpoint:
            On the Marlin path, the per-shard weight scales are expanded to
            channelwise. Otherwise (on ROCm, after converting e4m3fn to
            e4m3fnuz) the logical shards are requantized with their maximum
            scale, so that one per-tensor scale covers the whole fused
            weight. Static input scales are collapsed to their maximum.

        In both cases the stored weight is transposed. On the Marlin path the
        layer is then prepared for Marlin and ``input_scale`` is removed.
        """
        layer.weight = torch.nn.Parameter(layer.weight.data,
                                          requires_grad=False)
        # If checkpoint not serialized fp8, quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)

            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                assert weight_scale.numel() == 1
                weight_scale = convert_to_channelwise(
                    weight_scale.expand(len(layer.logical_widths)),
                    layer.logical_widths)

            # Update the layer with the new values.
            # NOTE(review): weight is stored transposed, presumably the
            # layout apply_fp8_linear / the Marlin prep expect; confirm.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None

        # If checkpoint is fp8, handle that there are N scales for N
        # shards in a fused module
        else:
            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                    requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                       requires_grad=False)
            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                weight = layer.weight
                weight_scale = convert_to_channelwise(layer.weight_scale,
                                                      layer.logical_widths)

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            else:
                # Dequant -> Quant with max scale so we can run per tensor.
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If rocm, use float8_e4m3fnuz.
                if is_hip():
                    weight, weight_scale, input_scale = \
                        normalize_e4m3fn_to_e4m3fnuz(
                            weight=weight,
                            weight_scale=weight_scale,
                            input_scale=layer.input_scale)
                    if input_scale is not None:
                        layer.input_scale = Parameter(input_scale,
                                                      requires_grad=False)

                weight_scale, weight = requantize_with_max_scale(
                    weight=weight,
                    weight_scale=weight_scale,
                    logical_widths=layer.logical_widths,
                )

            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                # The kernel takes a single activation scale; use the
                # largest one across shards.
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer)
            # Activations not quantized for marlin.
            del layer.input_scale

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the FP8 linear layer on ``x``.

        Dispatches to the FP8 Marlin kernel when ``use_marlin`` is set, and
        otherwise to ``apply_fp8_linear`` (CUTLASS when supported).
        """
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
            use_per_token_if_dynamic=False)


class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading unquantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weights are quantized to FP8 and the weight
    scaling factors are computed after the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register expert weights, weight scales and (static) input scales.

        Registered parameters:
            - w13_weight: ``(num_experts, 2 * intermediate_size, hidden_size)``.
              w1 and w3 are stacked along dim 1.
            - w2_weight: ``(num_experts, hidden_size, intermediate_size)``.
            - w13_weight_scale: ``(num_experts, 2)``, one scale each for w1
              and w3.
            - w2_weight_scale: ``(num_experts,)``.
            - w13_input_scale / w2_input_scale: ``(num_experts,)``, only for
              the "static" activation scheme; otherwise set to None.

        Raises:
            ValueError: If the activation scheme is "static" but the
                checkpoint is not fp8-serialized.
        """
        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
                                                    hidden_size,
                                                    dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   hidden_size,
                                                   intermediate_size,
                                                   dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                         2,
                                                         dtype=torch.float32),
                                              requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                        dtype=torch.float32),
                                             requires_grad=False)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8.")

            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Bring loaded expert weights/scales into the fused-MoE kernel form.

        Non-serialized checkpoint:
            Each expert's w13 and w2 weights are quantized to FP8, on ROCm to
            float8_e4m3fnuz. This yields one scale per expert, so
            ``w13_weight_scale`` is replaced by a ``(num_experts,)`` tensor.

        FP8-serialized checkpoint:
            Static input scales are collapsed to their maximum across
            experts, with a warning if they differ. On ROCm the weights and
            scales are normalized to e4m3fnuz. Each expert's w1/w3 shards are
            then requantized with the larger of their two scales, leaving one
            w13 scale per expert.

        Raises:
            ValueError: If the static activation scheme is configured but the
                input scales are None.
        """
        # If checkpoint is fp16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            # If rocm, use float8_e4m3fnuz as dtype
            fp8_dtype = torch.float8_e4m3fnuz \
                        if is_hip() else torch.float8_e4m3fn
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
                layer.num_experts,
                dtype=torch.float32,
                device=w13_weight.device),
                                                        requires_grad=False)
            for expert in range(layer.num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w13_weight.data[expert, :, :])
                w2_weight[expert, :, :], layer.w2_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w2_weight.data[expert, :, :])
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            return

        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if (layer.w13_input_scale is None
                        or layer.w2_input_scale is None):
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None.")
                if (not all_close_1d(layer.w13_input_scale)
                        or not all_close_1d(layer.w2_input_scale)):
                    print_warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer. ")
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False)
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False)
            # If rocm, normalize the weights and scales to e4m3fnuz
            if is_hip():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale, w2_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale,
                        layer.w2_input_scale)
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                      requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False)
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                     requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                           requires_grad=False)
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False)

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.num_experts):
                start = 0
                # Shard 0 is w1, shard 1 is w3 (rows stacked along dim 1).
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id])
                    layer.w13_weight[expert_id][
                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                            dq_weight, max_w13_scales[expert_id])
                    start += shard_size

            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                        requires_grad=False)
            return

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:
        """Route ``x`` to the top-k experts and run the FP8 fused-MoE kernel.

        Expert selection is delegated to ``FusedMoE.select_experts``. The
        expert computation runs through ``fused_experts`` with
        ``use_fp8_w8a8=True`` and ``inplace=True``, using the per-expert
        weight scales and the (possibly None) input scales.
        """
        from vllm.model_executor.layers.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function)

        return fused_experts(x,
                             layer.w13_weight,
                             layer.w2_weight,
                             topk_weights=topk_weights,
                             topk_ids=topk_ids,
                             inplace=True,
                             use_fp8_w8a8=True,
                             w1_scale=layer.w13_weight_scale,
                             w2_scale=layer.w2_weight_scale,
                             a1_scale=layer.w13_input_scale,
                             a2_scale=layer.w2_input_scale)


class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.

    Args:
        quant_config: The FP8 quantization config, forwarded unchanged to
            ``BaseKVCacheMethod``.
    """

    def __init__(self, quant_config: Fp8Config):
        super().__init__(quant_config)