"vscode:/vscode.git/clone" did not exist on "8f4b313c3790844d2d6ec9aeaa6dd0825c94752e"
fp8.py 21.7 KB
Newer Older
1
from typing import Any, Dict, List, Optional
2
3
4
5
6

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

7
import vllm.envs as envs
8
from vllm import _custom_ops as ops
9
from vllm.logger import init_logger
10
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
11
12
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
13
from vllm.model_executor.layers.quantization.base_config import (
14
    QuantizationConfig, QuantizeMethodBase)
15
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
16
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
17
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
18
19
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)
20
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
21
22
    all_close_1d, apply_fp8_linear, convert_to_channelwise,
    create_per_tensor_scale_param, cutlass_fp8_supported,
23
24
    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
    requantize_with_max_scale)
25
from vllm.model_executor.utils import set_weight_attrs
26
from vllm.platforms import current_platform
27
from vllm.utils import is_hip, print_warning_once
28

29
30
31
32
# Activation-scale schemes accepted by Fp8Config; any other value makes
# Fp8Config.__init__ raise ValueError.
ACTIVATION_SCHEMES = ["static", "dynamic"]

# Module-level logger, named after this module.
logger = init_logger(__name__)

33

34
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Holds the three settings the FP8 quantize methods in this module read:

    * ``is_checkpoint_fp8_serialized``: whether the checkpoint already stores
      FP8 weights (and scales) rather than FP16/BF16 weights that still have
      to be quantized after loading.
    * ``activation_scheme``: one of ``ACTIVATION_SCHEMES``.
    * ``ignored_layers``: layer names passed to ``is_layer_skipped``; linear
      layers that match are left unquantized.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[List[str]] = None,
    ) -> None:
        """Store the settings and validate the activation scheme.

        Args:
            is_checkpoint_fp8_serialized: True if the checkpoint is already
                serialized in FP8. A warning is logged in that case.
            activation_scheme: Must be one of ``ACTIVATION_SCHEMES``.
            ignored_layers: Layers to leave unquantized; ``None`` is stored
                as an empty list.

        Raises:
            ValueError: If ``activation_scheme`` is not supported.
        """
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
                           "format is experimental and subject to change.")
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []

    @classmethod
    def get_name(cls) -> str:
        """Return the name of this quantization method (``"fp8"``)."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes supported with this config."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (80).

        NOTE(review): presumably encoded as ``major * 10 + minor``, the same
        encoding Fp8LinearMethod uses when it falls back to the Marlin
        kernel below capability 89 -- confirm against the base class.
        """
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the extra config filenames to look for (none)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a quantization config dict.

        Args:
            config: Dict read for the keys ``"quant_method"``,
                ``"activation_scheme"`` and (optionally)
                ``"ignored_layers"``.

        Returns:
            A config whose ``is_checkpoint_fp8_serialized`` flag is True
            when ``"fp8"`` is contained in the ``quant_method`` value.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        # "ignored_layers" is looked up with a fallback of None, which
        # __init__ turns into an empty list.
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                   activation_scheme=activation_scheme,
                   ignored_layers=ignored_layers)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Args:
            layer: The module being constructed.
            prefix: The layer's name prefix, checked against
                ``self.ignored_layers``.

        Returns:
            * ``UnquantizedLinearMethod`` for a ``LinearBase`` layer that
              ``is_layer_skipped`` reports as skipped,
            * ``Fp8LinearMethod`` for any other ``LinearBase`` layer,
            * ``Fp8MoEMethod`` for a ``FusedMoE`` layer,
            * ``Fp8KVCacheMethod`` for an ``Attention`` layer,
            * ``None`` for every other layer type.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none for FP8)."""
        return []


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        # NOTE(review): this indexes the capability tuple directly, so it
        # assumes get_device_capability() never returns None -- confirm.
        capability = current_platform.get_device_capability()
        capability = capability[0] * 10 + capability[1]
        self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN
        # Disable marlin for rocm
        if is_hip():
            self.use_marlin = False

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight (and, for FP8 checkpoints, scales) on ``layer``.

        Args:
            layer: Module that receives the parameters and bookkeeping
                attributes (``logical_widths``, ``input_size_per_partition``,
                ``output_size_per_partition``, ``orig_dtype``).
            input_size_per_partition: Second dimension of the weight.
            output_partition_sizes: Widths of the logical shards fused into
                this layer; their sum is the first dimension of the weight.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Weight dtype when the checkpoint is not FP8.
            **extra_weight_attrs: Extra attributes attached to the created
                parameters.
        """
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)

        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=weight_dtype),
                           requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {
            **extra_weight_attrs,
            "input_dim": 1,
            "output_dim": 0,
        })

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            scale = create_per_tensor_scale_param(output_partition_sizes,
                                                  **extra_weight_attrs)
            layer.register_parameter("weight_scale", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = create_per_tensor_scale_param(output_partition_sizes,
                                                      **extra_weight_attrs)
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize ``layer.weight`` / ``layer.weight_scale`` once loaded.

        * Non-FP8 checkpoint: quantize the weight with
          ``ops.scaled_fp8_quant`` and set ``layer.input_scale`` to ``None``.
        * FP8 checkpoint: make the per-shard scales usable by the kernel
          (channelwise for Marlin, a single max scale otherwise) and, for
          the static activation scheme, collapse ``input_scale`` to its max.

        In both cases the weight is stored transposed. When Marlin is used,
        the layer is repacked and ``input_scale`` is removed.
        """
        # If checkpoint not serialized fp8, quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)

            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                assert weight_scale.numel() == 1
                weight_scale = convert_to_channelwise(
                    weight_scale.expand(len(layer.logical_widths)),
                    layer.logical_widths)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None

        # If checkpoint is fp8, handle that there are N scales for N
        # shards in a fused module
        else:
            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                weight = layer.weight
                weight_scale = convert_to_channelwise(layer.weight_scale,
                                                      layer.logical_widths)

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            else:
                # Dequant -> Quant with max scale so we can run per tensor.
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If rocm, use float8_e4m3fnuz.
                if is_hip():
                    weight, weight_scale, input_scale = \
                        normalize_e4m3fn_to_e4m3fnuz(
                            weight=weight,
                            weight_scale=weight_scale,
                            input_scale=layer.input_scale)
                    if input_scale is not None:
                        layer.input_scale = Parameter(input_scale,
                                                      requires_grad=False)

                weight_scale, weight = requantize_with_max_scale(
                    weight=weight,
                    weight_scale=weight_scale,
                    logical_widths=layer.logical_widths,
                )

            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer)
            # Activations not quantized for marlin.
            del layer.input_scale

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the FP8 linear layer on ``x``.

        Dispatches to ``apply_fp8_marlin_linear`` when ``self.use_marlin``
        is set, otherwise to ``apply_fp8_linear`` with per-tensor (not
        per-token) dynamic activation scaling.

        Args:
            layer: Layer prepared by ``process_weights_after_loading``.
            x: Input activations.
            bias: Optional bias forwarded to the kernel.

        Returns:
            The output tensor produced by the selected kernel.
        """
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
            use_per_token_if_dynamic=False)
266
267


268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register expert weights and scales on ``layer``.

        Creates ``w13_weight`` of shape
        ``(num_experts, 2 * intermediate_size, hidden_size)``, ``w2_weight``
        of shape ``(num_experts, hidden_size, intermediate_size)``,
        ``w13_weight_scale`` of shape ``(num_experts, 2)`` and
        ``w2_weight_scale`` of shape ``(num_experts,)``. For the static
        activation scheme it also creates per-expert ``w13_input_scale`` and
        ``w2_input_scale``; otherwise both are set to ``None``.

        Args:
            layer: Module that receives the parameters.
            num_experts: Number of experts.
            hidden_size: Hidden dimension of the weights.
            intermediate_size: Intermediate dimension of the weights.
            params_dtype: Weight dtype; replaced by ``torch.float8_e4m3fn``
                when the checkpoint is already serialized in FP8.
            **extra_weight_attrs: Extra attributes attached to the created
                parameters.

        Raises:
            ValueError: If the activation scheme is static but the
                checkpoint is not serialized in FP8.
        """
        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
                                                    hidden_size,
                                                    dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   hidden_size,
                                                   intermediate_size,
                                                   dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                         2,
                                                         dtype=torch.float32),
                                              requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                        dtype=torch.float32),
                                             requires_grad=False)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, {
                "is_fp8_scale": True,
                **extra_weight_attrs
            })
            set_weight_attrs(w2_weight_scale, {
                "is_fp8_scale": True,
                **extra_weight_attrs
            })

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8.")

            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, {
                "is_fp8_scale": True,
                **extra_weight_attrs
            })

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, {
                "is_fp8_scale": True,
                **extra_weight_attrs
            })
        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Bring the loaded expert weights and scales into kernel format.

        * Non-FP8 checkpoint: quantize every expert's ``w13`` and ``w2``
          weights with ``ops.scaled_fp8_quant``, producing one scale per
          expert for each.
        * FP8 checkpoint: reduce the input scales to a single value (static
          scheme only), normalize to ``float8_e4m3fnuz`` on ROCm, and
          requantize the two ``w13`` shards of every expert with the larger
          of their two scales so that ``w13_weight_scale`` holds one scale
          per expert.

        Raises:
            ValueError: If the activation scheme is static but an input
                scale is ``None``.
        """
        # If checkpoint is fp16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            # If rocm, use float8_e4m3fnuz as dtype
            fp8_dtype = torch.float8_e4m3fnuz \
                        if is_hip() else torch.float8_e4m3fn
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
                layer.num_experts,
                dtype=torch.float32,
                device=w13_weight.device),
                                                        requires_grad=False)
            for expert in range(layer.num_experts):
                w13_q, w13_scale = ops.scaled_fp8_quant(
                    layer.w13_weight.data[expert, :, :])
                w13_weight[expert, :, :] = w13_q
                layer.w13_weight_scale[expert] = w13_scale

                w2_q, w2_scale = ops.scaled_fp8_quant(
                    layer.w2_weight.data[expert, :, :])
                w2_weight[expert, :, :] = w2_q
                layer.w2_weight_scale[expert] = w2_scale
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            return

        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.

        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.quant_config.activation_scheme == "static":
            if (layer.w13_input_scale is None
                    or layer.w2_input_scale is None):
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None.")
            if (not all_close_1d(layer.w13_input_scale)
                    or not all_close_1d(layer.w2_input_scale)):
                print_warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer. ")
            layer.w13_input_scale = torch.nn.Parameter(
                layer.w13_input_scale.max(), requires_grad=False)
            layer.w2_input_scale = torch.nn.Parameter(
                layer.w2_input_scale.max(), requires_grad=False)

        # If rocm, normalize the weights and scales to e4m3fnuz
        if is_hip():
            # Normalize the weights and scales
            w13_weight, w13_weight_scale, w13_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight, layer.w13_weight_scale,
                    layer.w13_input_scale)
            w2_weight, w2_weight_scale, w2_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w2_weight, layer.w2_weight_scale,
                    layer.w2_input_scale)
            # Reset the parameter
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w13_weight_scale = torch.nn.Parameter(
                w13_weight_scale, requires_grad=False)
            if w13_input_scale is not None:
                layer.w13_input_scale = torch.nn.Parameter(
                    w13_input_scale, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                       requires_grad=False)
            if w2_input_scale is not None:
                layer.w2_input_scale = torch.nn.Parameter(
                    w2_input_scale, requires_grad=False)

        # Fp8 moe kernel needs single weight scale for w13 per expert.
        # We take the max then dequant and requant each expert.
        assert layer.w13_weight_scale is not None
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
        for expert_id in range(layer.num_experts):
            start = 0
            for shard_id in range(2):
                # Dequantize the shard with its own scale, then requantize
                # it with the expert-wide max scale.
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][start:start +
                                                shard_size, :],
                    layer.w13_weight_scale[expert_id][shard_id])
                requantized, _ = ops.scaled_fp8_quant(
                    dq_weight, max_w13_scales[expert_id])
                layer.w13_weight[expert_id][
                    start:start + shard_size, :] = requantized
                start += shard_size

        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                    requires_grad=False)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              router_logits: torch.Tensor,
              top_k: int,
              renormalize: bool,
              use_grouped_topk: bool,
              topk_group: Optional[int] = None,
              num_expert_group: Optional[int] = None) -> torch.Tensor:
        """Route ``x`` through the experts with the FP8 fused-MoE kernel.

        Expert selection is delegated to ``FusedMoE.select_experts``; the
        resulting weights and ids are passed to ``fused_experts`` together
        with the layer's FP8 weights and scales.

        Args:
            layer: Layer prepared by ``process_weights_after_loading``.
            x: Input hidden states.
            router_logits: Router scores used for expert selection.
            top_k: Forwarded to ``select_experts``.
            renormalize: Forwarded to ``select_experts``.
            use_grouped_topk: Forwarded to ``select_experts``.
            topk_group: Forwarded to ``select_experts``.
            num_expert_group: Forwarded to ``select_experts``.

        Returns:
            The tensor returned by ``fused_experts``. It is called with
            ``inplace=True`` -- presumably the result reuses ``x``'s
            storage; confirm against ``fused_experts``.
        """
        from vllm.model_executor.layers.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group)

        return fused_experts(x,
                             layer.w13_weight,
                             layer.w2_weight,
                             topk_weights=topk_weights,
                             topk_ids=topk_ids,
                             inplace=True,
                             use_fp8=True,
                             w1_scale=layer.w13_weight_scale,
                             w2_scale=layer.w2_weight_scale,
                             a1_scale=layer.w13_input_scale,
                             a2_scale=layer.w2_input_scale)
496
497


498
499
500
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
    """

    def __init__(self, quant_config: Fp8Config):
        # All behaviour comes from BaseKVCacheMethod; this subclass only
        # narrows the accepted config type to Fp8Config.
        super().__init__(quant_config)