"vscode:/vscode.git/clone" did not exist on "3e2c63a7ba8c52edd26d6c6838e1b4c0ea6cf412"
test_configs.py 3.01 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
"""Tests whether Marlin models can be loaded from the autogptq config.

Run `pytest tests/quantization/test_configs.py --forked`.
"""

import os
from dataclasses import dataclass
from typing import Optional

import pytest

from tests.utils import models_path_prefix
from vllm.config import ModelConfig
from vllm.platforms import current_platform


@dataclass()
class ModelPair:
    """Two model identifiers grouped together.

    NOTE(review): not referenced by the visible test code; judging by the
    field names it presumably pairs a Marlin-serialized checkpoint with its
    GPTQ counterpart — confirm before relying on it.
    """

    # Model identifier/path for the Marlin-format variant (per field name).
    model_marlin: str
    # Model identifier/path for the GPTQ-format variant (per field name).
    model_gptq: str


# Model Id // Quantization Arg // Expected Type
#
# Each entry is (model path, ``quantization`` argument passed to ModelConfig,
# expected ``ModelConfig.quantization`` value). "ERROR" is the expected
# outcome when constructing ModelConfig raises ValueError.

# Evaluated once at import; every expectation below branches on it.
_IS_CUDA = current_platform.is_cuda()


def _model_path(model_id: str) -> str:
    """Join *model_id* onto ``models_path_prefix`` with ``os.path.join``."""
    return os.path.join(models_path_prefix, model_id)


_LLAMA2_GPTQ = _model_path("TheBloke/Llama-2-7B-Chat-GPTQ")
_TINYLLAMA_GPTQ = _model_path("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit")
_MISTRAL_AWQ = _model_path("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ")

MODEL_ARG_EXPTYPES = [
    # AUTOGPTQ
    # compat: autogptq <=0.7.1 is_marlin_format: bool
    # Model Serialized in Exllama Format.
    (_LLAMA2_GPTQ, None, "gptq_marlin" if _IS_CUDA else "gptq"),
    (_LLAMA2_GPTQ, "marlin", "gptq_marlin" if _IS_CUDA else "ERROR"),
    (_LLAMA2_GPTQ, "gptq", "gptq"),
    (_LLAMA2_GPTQ, "awq", "ERROR"),
    # compat: autogptq >=0.8.0 use checkpoint_format: str
    # Model Serialized in Exllama Format.
    (_TINYLLAMA_GPTQ, None, "gptq_marlin" if _IS_CUDA else "gptq"),
    (_TINYLLAMA_GPTQ, "marlin", "gptq_marlin" if _IS_CUDA else "ERROR"),
    (_TINYLLAMA_GPTQ, "gptq", "gptq"),
    (_TINYLLAMA_GPTQ, "awq", "ERROR"),
    # AUTOAWQ
    (_MISTRAL_AWQ, None, "awq_marlin" if _IS_CUDA else "awq"),
    (_MISTRAL_AWQ, "awq", "awq"),
    (_MISTRAL_AWQ, "marlin", "awq_marlin" if _IS_CUDA else "ERROR"),
    (_MISTRAL_AWQ, "gptq", "ERROR"),
]


@pytest.mark.parametrize("model_arg_exptype", MODEL_ARG_EXPTYPES)
def test_auto_gptq(model_arg_exptype: tuple[str, Optional[str], str]) -> None:
    """Check the quantization method ModelConfig resolves for a checkpoint.

    Args:
        model_arg_exptype: ``(model_path, quantization_arg, expected_type)``.
            ``quantization_arg`` is forwarded as ModelConfig's
            ``quantization`` argument (``None`` means no override), and
            ``expected_type`` is the expected ``ModelConfig.quantization``
            value, or ``"ERROR"`` if construction should raise ValueError.
    """
    model_path, quantization_arg, expected_type = model_arg_exptype

    try:
        model_config = ModelConfig(model_path, quantization=quantization_arg)
    except ValueError:
        # Record the failure as the "ERROR" sentinel used in the expectations.
        found_quantization_type = "ERROR"
    else:
        found_quantization_type = model_config.quantization

    assert found_quantization_type == expected_type, (
        f"Expected quant_type == {expected_type} for {model_path}, "
        f"but found {found_quantization_type} "
        f"for --quantization={quantization_arg!r}"
    )