# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
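"""Tests which quantization method ModelConfig selects for GPTQ and AWQ models.

Covers checkpoints serialized by AutoGPTQ (old and new config formats) and
AutoAWQ, both with and without an explicit ``quantization`` argument. This
includes whether the Marlin-backed methods (``gptq_marlin``/``awq_marlin``)
are chosen.

Run `pytest tests/quantization/test_configs.py --forked`.
"""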

from dataclasses import dataclass
from typing import Optional

import pytest

from vllm.config import ModelConfig
from vllm.platforms import current_platform


@dataclass()
class ModelPair:
    """Hold two checkpoint identifiers: a Marlin one and a GPTQ one.

    NOTE(review): not referenced by the test code visible in this file;
    presumably both ids refer to the same base model — confirm with callers.
    """

    # Checkpoint id (e.g. a hub repo name) for the Marlin-format model.
    model_marlin: str
    # Checkpoint id (e.g. a hub repo name) for the GPTQ-format model.
    model_gptq: str


# Parametrization table for ``test_auto_gptq``. Each entry is
# (model id, ``quantization`` argument passed to ModelConfig, expected
# ``ModelConfig.quantization`` value). The expected value "ERROR" means
# ModelConfig construction is expected to raise ValueError.
MODEL_ARG_EXPTYPES = [
    # AUTOGPTQ
    # compat: autogptq <=0.7.1 is_marlin_format: bool
    # Model Serialized in Exllama Format.
    (
        "TheBloke/Llama-2-7B-Chat-GPTQ",
        None,
        "gptq_marlin" if current_platform.is_cuda() else "gptq",
    ),
    (
        "TheBloke/Llama-2-7B-Chat-GPTQ",
        "marlin",
        "gptq_marlin" if current_platform.is_cuda() else "ERROR",
    ),
    ("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq", "gptq"),
    ("TheBloke/Llama-2-7B-Chat-GPTQ", "awq", "ERROR"),
    # compat: autogptq >=0.8.0 use checkpoint_format: str
    # Model Serialized in Exllama Format.
    (
        "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
        None,
        "gptq_marlin" if current_platform.is_cuda() else "gptq",
    ),
    (
        "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
        "marlin",
        "gptq_marlin" if current_platform.is_cuda() else "ERROR",
    ),
    ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq", "gptq"),
    ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "awq", "ERROR"),
    # AUTOAWQ
    (
        "TheBloke/OpenHermes-2.5-Mistral-7B-AWQ",
        None,
        "awq_marlin" if current_platform.is_cuda() else "awq",
    ),
    ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "awq", "awq"),
    (
        "TheBloke/OpenHermes-2.5-Mistral-7B-AWQ",
        "marlin",
        "awq_marlin" if current_platform.is_cuda() else "ERROR",
    ),
    ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "gptq", "ERROR"),
]


@pytest.mark.parametrize("model_arg_exptype", MODEL_ARG_EXPTYPES)
def test_auto_gptq(model_arg_exptype: tuple[str, Optional[str], str]) -> None:
    """Check the quantization method ModelConfig resolves for a checkpoint.

    ``model_arg_exptype`` is ``(model_path, quantization_arg, expected_type)``.
    ``quantization_arg`` is forwarded as ``ModelConfig(quantization=...)``;
    ``None`` means no explicit override. If constructing the config raises
    ``ValueError``, the observed type is recorded as ``"ERROR"``, so rejected
    combinations can be listed in the parameter table.
    """
    model_path, quantization_arg, expected_type = model_arg_exptype

    # Keep only the constructor inside the ``try`` so that a ValueError from
    # any other statement is not misreported as an expected rejection.
    try:
        model_config = ModelConfig(model_path, quantization=quantization_arg)
    except ValueError:
        found_quantization_type = "ERROR"
    else:
        found_quantization_type = model_config.quantization

    assert found_quantization_type == expected_type, (
        f"Expected quant_type == {expected_type} for {model_path}, "
        f"but found {found_quantization_type} "
        f"for --quantization={quantization_arg}"
    )