# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py

from typing import Any, Optional

import regex as re
import torch
from torch.nn.parameter import Parameter

from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.petit_utils import (
    apply_petit_nvfp4_linear,
    prepare_nvfp4_layer_for_petit,
    verify_petit_nvfp4_supported,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
from vllm.platforms import current_platform

# Initialize logger for the module
logger = init_logger(__name__)


# Configuration class to support the NVFP4 quantized model
# generated by the ModelOpt quantization tool
class PetitNvFp4Config(QuantizationConfig):
    """Config class for Petit FP4.

    Describes an NVFP4 checkpoint (read from ``hf_quant_config.json``) that
    is executed with the Petit backend. Petit targets AMD GPUs; constructing
    this config on a CUDA platform raises ``ValueError``.
    """

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool = False,
        kv_cache_quant_algo: str | None = None,
        group_size: int | None = None,
        exclude_modules: list[str] | None = None,
    ) -> None:
        """Store the parsed quantization settings.

        Args:
            is_checkpoint_nvfp4_serialized: Whether the checkpoint already
                stores NVFP4 weights. ``PetitNvFp4LinearMethod`` only
                supports serialized checkpoints.
            kv_cache_quant_algo: KV-cache quantization algorithm name, or
                ``None`` (reported as ``"auto"`` by
                ``require_kv_cache_quant_algo``).
            group_size: Number of input elements that share one block scale,
                or ``None`` (treated as 16 by ``require_group_size``).
            exclude_modules: Module-name patterns (``*`` wildcard) whose
                linear layers stay unquantized, or ``None`` for no exclusions.

        Raises:
            ValueError: If running on a CUDA platform.
        """
        self._check_hardware_support()
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected nvfp4 checkpoint. Please note that the "
                "format is experimental and subject to change."
            )
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules

    def _check_hardware_support(self) -> None:
        """
        Verifies that the current hardware is supported by the Petit backend.
        This backend is specifically designed for AMD GPUs and is not
        supported on the CUDA platform.
        """
        # This check ensures the code is NOT running on an NVIDIA GPU.
        if current_platform.is_cuda():
            raise ValueError(
                "The 'petit' quantization backend is designed for AMD GPUs "
                "and is not supported on the CUDA platform. For NVIDIA GPUs, "
                "please use a different quantization method such as FP8, AWQ, "
                "or GPTQ."
            )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "petit_nvfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes accepted by this method."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # Petit supports the gfx90a and gfx942 GPUs
        return 90

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the checkpoint file(s) holding the quantization config."""
        return ["hf_quant_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "PetitNvFp4Config":
        """Build a config from the contents of ``hf_quant_config.json``.

        Reads the ``"quantization"`` section and validates ``quant_algo``,
        ``group_size``, ``kv_cache_quant_algo`` and ``exclude_modules``.

        Raises:
            ValueError: If a field is missing or has the wrong type.
                ``verify_petit_nvfp4_supported`` may additionally reject
                unsupported algorithm / group-size combinations.
        """
        qc = cls.get_from_keys(config, ["quantization"])

        quant_method_raw = qc.get("quant_algo")
        if not isinstance(quant_method_raw, str) or not quant_method_raw:
            raise ValueError("Missing or invalid 'quant_algo' in quantization config.")
        quant_method = quant_method_raw.upper()

        group_size_raw = qc.get("group_size")
        if not isinstance(group_size_raw, int):
            raise ValueError(
                "Missing or invalid 'group_size' (int) in hf_quant_config.json."
            )
        group_size = group_size_raw

        verify_petit_nvfp4_supported(quant_method, group_size)

        # A missing or null entry falls back to "auto".
        kv_cache_quant_algo_raw = qc.get("kv_cache_quant_algo") or "auto"
        if not isinstance(kv_cache_quant_algo_raw, str):
            raise ValueError("'kv_cache_quant_algo' must be a string if provided.")
        kv_cache_quant_algo = kv_cache_quant_algo_raw

        exclude_raw = qc.get("exclude_modules", [])
        if exclude_raw is None:
            exclude_modules: list[str] = []
        elif isinstance(exclude_raw, list) and all(
            isinstance(x, str) for x in exclude_raw
        ):
            exclude_modules = exclude_raw
        else:
            raise ValueError("'exclude_modules' must be a list[str] (or omitted).")

        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        return cls(
            is_checkpoint_nvfp4_serialized=is_checkpoint_nvfp4_serialized,
            kv_cache_quant_algo=kv_cache_quant_algo,
            group_size=group_size,
            exclude_modules=exclude_modules,
        )

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Claim ModelOpt/NVFP4 checkpoints for Petit when running on ROCm.

        Returns this method's name if the config's ``quant_algo`` (or
        ``quant_method``) is one of NVFP4 / MODELOPT_FP4 / MODELOPT and the
        platform is ROCm; otherwise ``None``. ``user_quant`` is not consulted.
        """
        if not current_platform.is_rocm():
            return None

        # Accept both the nested ({"quantization": {...}}) and flat layouts.
        qc = hf_quant_cfg.get("quantization", hf_quant_cfg)
        algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper()
        if algo in ("NVFP4", "MODELOPT_FP4", "MODELOPT"):
            return cls.get_name()  # "petit_nvfp4"
        return None

    @classmethod
    def is_petit_nvfp4_compatible(cls, quant_config: dict[str, Any]) -> bool:
        """Return True if the config's algorithm is exactly NVFP4."""
        qc = quant_config.get("quantization", quant_config)
        algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper()
        return algo == "NVFP4"

    def is_layer_excluded(self, prefix: str, exclude_modules: list[str]) -> bool:
        """Return True if ``prefix`` fully matches any exclude pattern.

        Patterns are module names in which ``*`` matches any (possibly empty)
        run of characters. Every other character is matched literally, so a
        name containing regex metacharacters (``[``, ``(``, ``+``, ``?`` ...)
        cannot change what the pattern matches.
        """
        for pattern in exclude_modules:
            # Escape the whole pattern, then turn the escaped "*" back into
            # the ".*" wildcard.
            regex_str = re.escape(pattern).replace(r"\*", ".*")
            if re.fullmatch(regex_str, prefix):
                return True
        return False

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Linear layers get ``PetitNvFp4LinearMethod`` unless ``prefix`` is
        excluded (then ``UnquantizedLinearMethod``); attention layers get
        ``PetitFp8KVCacheMethod``; any other layer gets ``None``.
        """
        exclude = self.require_exclude_modules()

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, exclude) or self.is_layer_excluded(
                prefix, exclude
            ):
                return UnquantizedLinearMethod()
            return PetitNvFp4LinearMethod(self)
        elif isinstance(layer, Attention):
            return PetitFp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> list[str]:
        """Return the names of activations that need scaling (none here)."""
        return []

    def require_group_size(self) -> int:
        """Return the configured group size, defaulting to 16 with a warning."""
        if self.group_size is None:
            logger.warning("group_size not set; defaulting to 16 for NVFP4.")
            return 16
        return self.group_size

    def require_kv_cache_quant_algo(self) -> str:
        """Return the KV-cache algorithm, or ``"auto"`` when unset/empty."""
        return self.kv_cache_quant_algo or "auto"

    def require_exclude_modules(self) -> list[str]:
        """Return a fresh copy of the exclude patterns (empty when unset)."""
        return list(self.exclude_modules or [])


class PetitFp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache quantize method handed out by ``PetitNvFp4Config``.

    Defines no behavior of its own: loading kv-cache scaling factors from
    FP8 checkpoints is inherited unchanged from ``BaseKVCacheMethod``.
    """

    def __init__(self, quant_config: PetitNvFp4Config):
        BaseKVCacheMethod.__init__(self, quant_config)


class PetitNvFp4LinearMethod(LinearMethodBase):
    """Linear method for NVFP4.

    Supports loading NVFP4 checkpoints with the following parameters, where
    N = output_size_per_partition, K = input_size_per_partition and
    P = number of fused output partitions (``len(output_partition_sizes)``):

    |Tensor Name     | datatype                    | shape                |
    |----------------|-----------------------------|----------------------|
    |input_scale     | torch.float32               | [P] (max -> scalar)  |
    |weight          | uint8, 2 x NVFP4 (SE2M1)    | [N, K // 2]          |
    |weight_scale    | FP8-E4M3                    | [N, K // group_size] |
    |weight_scale_2  | torch.float32               | [P] (max -> scalar)  |

    The weights are quantized per block of ``group_size`` elements (16 when
    the config does not specify one).

    Args:
        quant_config: The Petit NVFP4 quantization config.
    """

    def __init__(self, quant_config: PetitNvFp4Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the NVFP4 weight and scale parameters on ``layer``.

        Also records ``logical_widths``, ``input_size_per_partition`` and
        ``output_size_per_partition`` on the layer for use by ``apply``.

        Raises:
            ValueError: If the checkpoint is not NVFP4-serialized, or if the
                input partition size is not a multiple of 16 and of the
                quantization group size.
        """
        # Full sizes are not needed; params_dtype is unused because only
        # pre-quantized (serialized) checkpoints are supported.
        del input_size, output_size, params_dtype
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                "dynamic quantization is not supported."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )
        group_size = self.quant_config.require_group_size()
        # weight_scale has one column per group; a partial trailing group
        # would otherwise be silently dropped by the floor division below.
        if input_size_per_partition % group_size != 0:
            raise ValueError(
                f"Unsupported model: in features size ({input_size_per_partition}) "
                f"is not a multiple of the NVFP4 group size ({group_size})."
            )

        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 data is packed in one uint8 in the input dimension
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # One entry per fused output partition; reduced to a scalar in
        # process_weights_after_loading.
        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )

        layer.register_parameter("input_scale", input_scale)

        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Block scales are stored as FP8-E4M3 in serialized NVFP4 checkpoints
        # (non-serialized checkpoints were rejected above).
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // group_size,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Collapse per-partition scales and prepare the layer for Petit.

        ``input_scale`` and ``weight_scale_2`` hold one float32 value per
        fused output partition; each is reduced to its maximum so one scalar
        covers all partitions. ``alpha`` stores their product. After
        ``prepare_nvfp4_layer_for_petit`` runs, ``input_scale`` is deleted
        (``apply`` does not read it).
        """
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
        # NOTE(review): ``alpha`` is not read by ``apply``; presumably it is
        # consumed by ``prepare_nvfp4_layer_for_petit`` -- confirm before
        # removing.
        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )

        prepare_nvfp4_layer_for_petit(layer)
        del layer.input_scale

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the NVFP4 linear layer on ``x`` via the Petit kernel wrapper."""
        return apply_petit_nvfp4_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            weight_scale_2=layer.weight_scale_2,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )