petit.py 11.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py

from typing import Any, Optional

import regex as re
import torch
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.petit_utils import (
    apply_petit_nvfp4_linear, prepare_nvfp4_layer_for_petit,
    verify_petit_nvfp4_supported)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)
from vllm.model_executor.parameter import (ModelWeightParameter,
                                           PerTensorScaleParameter)
from vllm.platforms import current_platform

# Initialize logger for the module
logger = init_logger(__name__)


# Configuration class to support the NVFP4 quantized model
# generated by the ModelOpt quantization tool
class PetitNvFp4Config(QuantizationConfig):
    """Config class for Petit FP4.

    Holds the settings parsed from an NVFP4 ``hf_quant_config.json`` and
    hands out the Petit linear / KV-cache methods for individual layers.
    Instantiation raises on the CUDA platform.
    """

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool = False,
        kv_cache_quant_algo: Optional[str] = None,
        group_size: Optional[int] = None,
        exclude_modules: Optional[list[str]] = None,
    ) -> None:
        """Store the quantization settings.

        Args:
            is_checkpoint_nvfp4_serialized: Whether the checkpoint stores
                pre-quantized NVFP4 weights (the linear method rejects
                checkpoints where this is False).
            kv_cache_quant_algo: Name of the KV-cache quantization algorithm,
                or None (``require_kv_cache_quant_algo`` maps None to
                ``"auto"``).
            group_size: Number of weights sharing one block scale; None makes
                ``require_group_size`` fall back to 16.
            exclude_modules: Module-name patterns (``*`` is a wildcard) whose
                linear layers are left unquantized.

        Raises:
            ValueError: If the current platform is CUDA.
        """
        self._check_hardware_support()
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning("Detected nvfp4 checkpoint. Please note that the "
                           "format is experimental and subject to change.")
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules

    def _check_hardware_support(self) -> None:
        """
        Verifies that the current hardware is supported by the Petit backend.
        This backend is specifically designed for AMD GPUs and is not
        supported on the CUDA platform.

        Raises:
            ValueError: If ``current_platform.is_cuda()`` is True.
        """
        # This check ensures the code is NOT running on an NVIDIA GPU.
        if current_platform.is_cuda():
            raise ValueError(
                "The 'petit' quantization backend is designed for AMD GPUs "
                "and is not supported on the CUDA platform. For NVIDIA GPUs, "
                "please use a different quantization method such as FP8, AWQ, "
                "or GPTQ.")

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "petit_nvfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes accepted by this method."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required."""
        # Petit supports the gfx90a and gfx942 GPUs
        return 90

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the file names searched for the quantization config."""
        return ["hf_quant_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "PetitNvFp4Config":
        """Build a config from the parsed ``hf_quant_config.json``.

        Reads the ``quantization`` section and validates:

        * ``quant_algo``: a non-empty str.
        * ``group_size``: an int.
        * ``kv_cache_quant_algo``: a str, ``"auto"`` when absent or falsy.
        * ``exclude_modules``: a list[str], empty when absent or None.

        ``quant_algo`` and ``group_size`` are additionally passed to
        ``verify_petit_nvfp4_supported``; whatever that helper raises
        propagates.

        Raises:
            ValueError: If a field is missing or has the wrong type.
        """
        qc = cls.get_from_keys(config, ["quantization"])

        quant_method_raw = qc.get("quant_algo")
        if not isinstance(quant_method_raw, str) or not quant_method_raw:
            raise ValueError(
                "Missing or invalid 'quant_algo' in quantization config.")
        quant_method = quant_method_raw.upper()

        group_size_raw = qc.get("group_size")
        if not isinstance(group_size_raw, int):
            raise ValueError(
                "Missing or invalid 'group_size' (int) in hf_quant_config.json."
            )
        group_size = group_size_raw

        verify_petit_nvfp4_supported(quant_method, group_size)

        kv_cache_quant_algo_raw = qc.get("kv_cache_quant_algo") or "auto"
        if not isinstance(kv_cache_quant_algo_raw, str):
            raise ValueError(
                "'kv_cache_quant_algo' must be a string if provided.")
        kv_cache_quant_algo = kv_cache_quant_algo_raw

        exclude_raw = qc.get("exclude_modules", [])
        if exclude_raw is None:
            exclude_modules: list[str] = []
        elif isinstance(exclude_raw, list) and all(
                isinstance(x, str) for x in exclude_raw):
            exclude_modules = exclude_raw
        else:
            raise ValueError(
                "'exclude_modules' must be a list[str] (or omitted).")

        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        return cls(
            is_checkpoint_nvfp4_serialized=is_checkpoint_nvfp4_serialized,
            kv_cache_quant_algo=kv_cache_quant_algo,
            group_size=group_size,
            exclude_modules=exclude_modules,
        )

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
        """On ROCm, claim NVFP4 / ModelOpt checkpoints for this backend.

        The algorithm name is taken from ``quant_algo`` or ``quant_method``,
        either nested under ``quantization`` or at the top level.

        Returns:
            ``"petit_nvfp4"`` when the algorithm is NVFP4, MODELOPT_FP4 or
            MODELOPT (case-insensitive) and the platform is ROCm, else None.
        """
        if not current_platform.is_rocm():
            return None

        qc = hf_quant_cfg.get("quantization", hf_quant_cfg)
        algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper()
        if algo in ("NVFP4", "MODELOPT_FP4", "MODELOPT"):
            return cls.get_name()  # "petit_nvfp4"
        return None

    @classmethod
    def is_petit_nvfp4_compatible(cls, quant_config: dict[str, Any]) -> bool:
        """Return True if the config's algorithm is exactly NVFP4."""
        qc = quant_config.get("quantization", quant_config)
        algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper()
        return algo == "NVFP4"

    def is_layer_excluded(self, prefix: str,
                          exclude_modules: list[str]) -> bool:
        """Return True if ``prefix`` fully matches any exclusion pattern.

        Patterns are literal module names in which ``*`` matches any
        (possibly empty) run of characters. Every other character, including
        regex metacharacters such as ``+``, ``?``, ``(`` or ``[``, is matched
        literally.
        """
        for pattern in exclude_modules:
            # Escape the whole pattern first so stray metacharacters can
            # neither alter the match nor raise on a malformed regex, then
            # turn the escaped glob star back into a wildcard.
            regex_str = re.escape(pattern).replace(r"\*", ".*")
            if re.fullmatch(regex_str, prefix):
                return True
        return False

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Linear layers are unquantized when skipped/excluded, otherwise they
        use ``PetitNvFp4LinearMethod``. Attention layers use
        ``PetitFp8KVCacheMethod``. Anything else returns None.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        exclude = self.require_exclude_modules()

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, exclude) or self.is_layer_excluded(
                    prefix, exclude):
                return UnquantizedLinearMethod()
            return PetitNvFp4LinearMethod(self)
        elif isinstance(layer, Attention):
            return PetitFp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> list[str]:
        """No activations are scaled by this method."""
        return []

    def require_group_size(self) -> int:
        """Return the group size, warning and defaulting to 16 if unset."""
        if self.group_size is None:
            logger.warning("group_size not set; defaulting to 16 for NVFP4.")
            return 16
        return self.group_size

    def require_kv_cache_quant_algo(self) -> str:
        """Return the KV-cache algorithm, ``"auto"`` when unset/empty."""
        return self.kv_cache_quant_algo or "auto"

    def require_exclude_modules(self) -> list[str]:
        """Return a fresh copy of the exclusion patterns (empty if unset)."""
        return list(self.exclude_modules or [])


class PetitFp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache quantize method handed to attention layers by the Petit config.

    Adds no behaviour of its own: loading of kv-cache scaling factors from
    FP8 checkpoints is inherited unchanged from ``BaseKVCacheMethod``.
    """

    def __init__(self, quant_config: PetitNvFp4Config):
        # Pure delegation; the base class keeps the config.
        super(PetitFp8KVCacheMethod, self).__init__(quant_config)


class PetitNvFp4LinearMethod(LinearMethodBase):
    """Linear method for NVFP4 checkpoints executed with the Petit kernels.

    Supports loading NVFP4 checkpoints with the following structure, where
    N = output_size_per_partition, K = input_size_per_partition and
    P = number of fused output partitions:

    |Tensor Name           | datatype             | shape at load time |
    |----------------------|----------------------|--------------------|
    |input_scale           | torch.float32        | [P]                |
    |weight                | NVFP4(SE2M1), packed | [N, K / 2] (uint8) |
    |weight_scale          | FP8-E4M3             | [N, K / group_size]|
    |weight_scale_2        | torch.float32        | [P]                |

    ``input_scale`` and ``weight_scale_2`` are reduced to scalars (their
    maximum) in ``process_weights_after_loading``. The weights are quantized
    per block of ``group_size`` elements (16 when the config leaves it
    unset).

    Args:
        quant_config: The Petit NVFP4 quantization config.
    """

    def __init__(self, quant_config: PetitNvFp4Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the (uninitialised) NVFP4 parameters on ``layer``.

        Also records ``logical_widths``, ``input_size_per_partition`` and
        ``output_size_per_partition`` on the layer.

        Raises:
            ValueError: If the checkpoint is not NVFP4-serialized, or if
                ``input_size_per_partition`` is not a multiple of 16 or of
                the configured group size.
        """
        # The unquantized sizes and params_dtype are not needed: every
        # tensor here has a fixed NVFP4/FP8/FP32 storage dtype.
        del input_size, output_size, params_dtype
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError("NVFP4 quantization was selected, "
                             "dynamic quantization is not supported.")

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        if input_size_per_partition % 16 != 0:
            raise ValueError("Unsupported model when in features size is "
                             "not multiple of 16")

        group_size = self.quant_config.require_group_size()
        # weight_scale has K // group_size columns; a non-divisible K would
        # silently drop the scale of the trailing partial block.
        if input_size_per_partition % group_size != 0:
            raise ValueError(
                f"Unsupported model: in features size "
                f"({input_size_per_partition}) is not a multiple of the "
                f"NVFP4 group size ({group_size}).")

        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 data is packed in one uint8 in the input dimension
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # One global scale per fused output partition; collapsed to a scalar
        # after loading.
        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )

        layer.register_parameter("input_scale", input_scale)

        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Per-block scales are stored as FP8-E4M3. Only serialized NVFP4
        # checkpoints reach this point (checked above).
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // group_size,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Collapse the per-partition global scales and repack for Petit.

        Steps:

        1. Replace ``input_scale`` and ``weight_scale_2`` with their maximum
           over the fused partitions.
        2. Set ``alpha`` to the product of those two scalars.
        3. Hand the layer to ``prepare_nvfp4_layer_for_petit``.
        4. Delete ``input_scale``, which ``apply`` does not read.
        """
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
        layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
                                requires_grad=False)

        prepare_nvfp4_layer_for_petit(layer)
        del layer.input_scale

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the Petit NVFP4 GEMM for ``x`` with the layer's weights."""
        return apply_petit_nvfp4_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            weight_scale_2=layer.weight_scale_2,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )