test_register_quantization_config.py 4.69 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
"""Tests for registering a custom quantization config.

See https://github.com/vllm-project/vllm/issues/11926 for more details.

Run `pytest tests/quantization/test_register_quantization_config.py`.
"""
9

10
import logging
11
from typing import Any
12
13
14
15
16

import pytest
import torch
import torch.nn.functional as F

17
18
19
20
from vllm.model_executor.layers.linear import (
    LinearBase,  # noqa: E501
    UnquantizedLinearMethod,
)
21
from vllm.model_executor.layers.quantization import (
22
23
24
25
    QuantizationMethods,
    get_quantization_config,
    register_quantization_config,
)
26
from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
27
28
    QuantizationConfig,
)
29
30
31
32
33
34
35
36
37
38


class FakeQuantLinearMethod(UnquantizedLinearMethod):
    """Fake quantization linear method for per-token dynamic quantization.

    Activations are quantized onto a signed ``num_bits`` integer grid using a
    symmetric scale computed per slice along the last dimension. They are
    immediately dequantized and then passed to ``F.linear`` together with the
    layer's (unquantized) weight.
    """

    def __init__(self, num_bits: int = 8) -> None:
        """Initialize the quantization method.

        Args:
            num_bits: Width of the signed integer grid. Must be at least 2;
                with 1 bit the signed grid has no positive level to scale to.

        Raises:
            ValueError: If ``num_bits`` is smaller than 2.
        """
        if num_bits < 2:
            raise ValueError(f"num_bits must be >= 2, got {num_bits}")
        super().__init__()
        self.num_bits = num_bits

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Perform fake quantization before the linear layer.

        Args:
            layer: Linear layer whose ``weight`` is used as-is.
            x: Input activations. Each slice along the last dimension (one
                token for the 2-D inputs linear layers typically receive)
                gets its own scale.
            bias: Optional bias forwarded to ``F.linear``.

        Returns:
            ``F.linear(dequant(quant(x)), layer.weight, bias)``.
        """
        qmin = -(2 ** (self.num_bits - 1))
        qmax = 2 ** (self.num_bits - 1) - 1

        # Symmetric scale: map each token's largest magnitude onto qmax so the
        # signed grid covers [-absmax, absmax]. A (max - min)-based scale with
        # no zero point would clip one-sided (e.g. all-positive) inputs.
        absmax = torch.amax(torch.abs(x), dim=-1, keepdim=True)
        # Floor the scale so all-zero tokens don't produce 0 / 0 = NaN.
        scales = torch.clamp(absmax / qmax, min=torch.finfo(x.dtype).tiny)

        # Fake quantize the input, then dequantize it back.
        quant_x = torch.clamp(torch.round(x / scales), qmin, qmax)
        dequant_x = quant_x * scales

        return F.linear(dequant_x, layer.weight, bias)


@register_quantization_config("custom_quant")
class CustomQuantConfig(QuantizationConfig):
    """Custom quantization config for per-token dynamic fake quantization.

    The decorator registers this class under the name ``"custom_quant"`` when
    the module is imported. Every ``LinearBase`` layer is assigned a
    ``FakeQuantLinearMethod``; all other layers get no quant method.
    """

    def __init__(self, num_bits: int = 8) -> None:
        """Initialize the quantization config.

        Args:
            num_bits: Bit width handed to each ``FakeQuantLinearMethod``.
        """
        super().__init__()
        self.num_bits = num_bits

    def get_name(self) -> QuantizationMethods:
        """Name of the quantization method (same as the registered name)."""
        return "custom_quant"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """List of supported activation dtypes."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum GPU capability to support the quantization method.

        Returns ``-1``; presumably this means no device is rejected by the
        capability check. Confirm against the caller.
        """
        return -1

    @staticmethod
    def get_config_filenames() -> list[str]:
        """List of filenames to search for in the model directory.

        This config lists none.
        """
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "CustomQuantConfig":
        """Create a config class from the model's quantization config.

        Args:
            config: Quantization config mapping. Only ``"num_bits"`` is read;
                it defaults to 8 when absent.

        Returns:
            A new instance of ``cls``.
        """
        # Construct via ``cls`` so subclasses get an instance of their own type.
        return cls(num_bits=config.get("num_bits", 8))

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> FakeQuantLinearMethod | None:
        """Get the quantize method to use for the quantized layer.

        Args:
            layer: Candidate layer.
            prefix: Layer name prefix (unused here).

        Returns:
            A ``FakeQuantLinearMethod`` for ``LinearBase`` layers, else ``None``.
        """
        if isinstance(layer, LinearBase):
            return FakeQuantLinearMethod(num_bits=self.num_bits)
        return None


104
def test_register_quantization_config(caplog_vllm):
    """Registering `custom_quant` succeeds, and registering it again warns."""
    # The decorator on CustomQuantConfig has already registered
    # `custom_quant` by the time this test runs.
    registered_cls = get_quantization_config("custom_quant")
    assert registered_cls == CustomQuantConfig

    # `custom_quant` already exists, so registering it a second time should
    # emit a warning.
    with caplog_vllm.at_level(logging.WARNING):
        register_quantization_config("custom_quant")(CustomQuantConfig)

    expected = "The quantization method 'custom_quant' already exists"
    warned = any(expected in message for message in caplog_vllm.messages)
    assert warned, "Expected a warning when re-registering custom_quant"

120

121
122
123
124
125
126
@pytest.mark.parametrize(
    argnames="model",
    argvalues=[
        "meta-llama/Llama-3.2-1B-Instruct",
    ],
)
def test_custom_quant(vllm_runner, model, monkeypatch):
    """Test inference with the custom quantization method."""
    # `LLM.apply_model` requires pickling a function.
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    with vllm_runner(
        model_name=model, quantization="custom_quant", enforce_eager=True
    ) as llm:

        def check_model(model_module):
            # Named distinctly from the outer `model` (the model id string)
            # so it does not shadow it.
            layer = model_module.model.layers[0]
            qkv_proj = layer.self_attn.qkv_proj

            # Check the quantization method is FakeQuantLinearMethod
            assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod)

        llm.apply_model(check_model)

        # `generate_greedy` takes a list of prompts; a bare string would be
        # iterated character by character, producing one prompt per char.
        prompts = ["Hello my name is"]
        outputs = llm.generate_greedy(prompts, max_tokens=20)
        assert len(outputs) == len(prompts)