test_register_quantization_config.py 4.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
4
5
6
7
"""Tests for registering a custom quantization config.

See https://github.com/vllm-project/vllm/issues/11926 for more details.

Run `pytest tests/quantization/test_register_quantization_config.py`.
"""
8
from typing import Any, Optional
9

10
import os
11
12
13
14
15
16
17
18
19
20
import pytest
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.linear import LinearBase  # noqa: E501
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import (
    get_quantization_config, register_quantization_config)
from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
    QuantizationConfig)
21
from ..utils import models_path_prefix
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62


class FakeQuantLinearMethod(UnquantizedLinearMethod):
    """Linear method that fake-quantizes its input before the matmul.

    The input is quantized and immediately dequantized with a dynamically
    computed scale, so the linear layer itself still runs on floating-point
    values.
    """

    def __init__(self, num_bits: int = 8) -> None:
        """Store the bit width used for fake quantization.

        Args:
            num_bits: Number of bits of the simulated integer format.
        """
        super().__init__()
        self.num_bits = num_bits

    def apply(self,
              layer: "torch.nn.Module",
              x: "torch.Tensor",
              bias: Optional["torch.Tensor"] = None) -> "torch.Tensor":
        """Fake-quantize ``x`` and run it through the linear layer.

        Args:
            layer: Module whose ``weight`` attribute is used for the matmul.
            x: Input activations.
            bias: Optional bias forwarded to ``F.linear``.

        Returns:
            The result of ``F.linear`` on the quantize-dequantized input.
        """
        # The scale is derived from the value range reduced over the first
        # and last dimensions of the input (kept for broadcasting).
        reduce_dims = (0, -1)
        upper = torch.amax(x, dim=reduce_dims, keepdims=True)
        lower = torch.amin(x, dim=reduce_dims, keepdims=True)
        num_steps = 2**self.num_bits - 1
        step = (upper - lower) / num_steps

        # Round to the integer grid and clip to the signed num_bits range.
        half_range = 2**(self.num_bits - 1)
        q_min = -half_range
        q_max = half_range - 1
        rounded = torch.round(x / step)
        fake_quantized = torch.clamp(rounded, q_min, q_max) * step

        return F.linear(fake_quantized, layer.weight, bias)


@register_quantization_config("custom_quant")
class CustomQuantConfig(QuantizationConfig):
    """Custom quantization config for dynamic fake quantization.

    The class is registered under the name ``custom_quant``. Linear layers
    are given a ``FakeQuantLinearMethod``; for any other layer
    ``get_quant_method`` returns ``None``.
    """

    def __init__(self, num_bits: int = 8) -> None:
        """Initialize the quantization config.

        Args:
            num_bits: Bit width passed on to ``FakeQuantLinearMethod``.
        """
        # Let the base class set up whatever state it maintains, mirroring
        # what FakeQuantLinearMethod does for its own base class.
        super().__init__()
        self.num_bits = num_bits

    def get_name(self) -> str:
        """Name of the quantization method."""
        return "custom_quant"

    def get_supported_act_dtypes(self) -> list["torch.dtype"]:
        """List of supported activation dtypes."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum GPU capability to support the quantization method."""
        # NOTE(review): -1 presumably means "no minimum capability";
        # confirm against the base-class contract.
        return -1

    @staticmethod
    def get_config_filenames() -> list[str]:
        """List of filenames to search for in the model directory."""
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "CustomQuantConfig":
        """Create a config instance from the model's quantization config.

        Args:
            config: Quantization settings; ``num_bits`` defaults to 8 when
                the key is absent.
        """
        # Build through `cls` so subclasses get an instance of their own type.
        return cls(num_bits=config.get("num_bits", 8))

    def get_quant_method(self, layer: "torch.nn.Module",
                         prefix: str) -> Optional["FakeQuantLinearMethod"]:
        """Get the quantize method to use for the quantized layer.

        Args:
            layer: Layer being configured.
            prefix: Name prefix of the layer (unused here).

        Returns:
            A ``FakeQuantLinearMethod`` for ``LinearBase`` layers, otherwise
            ``None``.
        """
        if isinstance(layer, LinearBase):
            return FakeQuantLinearMethod(num_bits=self.num_bits)
        return None


def test_register_quantization_config():
    """Check `custom_quant` is registered and cannot be registered twice."""

    # Looking up `custom_quant` should return the decorated config class.
    registered_cls = get_quantization_config("custom_quant")
    assert registered_cls == CustomQuantConfig

    # The name `custom_quant` already exists, so registering it again
    # should raise an error.
    with pytest.raises(ValueError):
        decorator = register_quantization_config("custom_quant")
        decorator(CustomQuantConfig)


@pytest.mark.parametrize(
    argnames="model",
    argvalues=[
        os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct"),
    ])
def test_custom_quant(vllm_runner, model, monkeypatch):
    """Test inference with the custom quantization method.

    Loads the model with ``quantization="custom_quant"``, checks that the
    first decoder layer's ``qkv_proj`` uses ``FakeQuantLinearMethod``, and
    verifies that greedy generation produces a non-empty result.
    """
    # vllm_runner.apply_model() relies on V0 internals.
    monkeypatch.setenv("VLLM_USE_V1", "0")

    with vllm_runner(model_name=model,
                     quantization="custom_quant",
                     enforce_eager=True) as llm:

        # Use a separate name so the parametrized `model` path is not
        # shadowed by the loaded module.
        loaded_model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        first_layer = loaded_model.model.layers[0]
        qkv_proj = first_layer.self_attn.qkv_proj

        # Check the quantization method is FakeQuantLinearMethod
        assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output