# tests/quantization/test_configs.py
"""Tests whether Marlin models can be loaded from the autogptq config.

Run `pytest tests/quantization/test_configs.py --forked`.
"""

import os
from dataclasses import dataclass
from typing import Optional, Tuple

import pytest

from vllm.config import ModelConfig

from ..utils import models_path_prefix


@dataclass
class ModelPair:
    """A Marlin-format / GPTQ-format pair of model identifiers.

    NOTE(review): not referenced by the test code visible in this file;
    presumably both fields name the same base model — confirm with callers.
    """

    # Identifier (path or hub id) of the Marlin-serialized checkpoint.
    model_marlin: str
    # Identifier (path or hub id) of the GPTQ-serialized checkpoint.
    model_gptq: str


def _model_path(name: str) -> str:
    """Return the path of checkpoint *name* under ``models_path_prefix``."""
    return os.path.join(models_path_prefix, name)


# Each entry is (model path, value passed as ``quantization=``, expected
# ``ModelConfig.quantization``). The sentinel "ERROR" marks combinations for
# which constructing the ModelConfig is expected to raise ValueError.
#
# NOTE(review): the commented-out entries cover the Marlin kernel-selection
# paths and are disabled in this copy of the test, presumably because Marlin
# is unavailable in the target environment. Confirm before re-enabling.
MODEL_ARG_EXPTYPES = [
    # AUTOGPTQ
    # compat: autogptq <=0.7.1 is_marlin_format: bool
    # Model Serialized in Marlin Format should always use Marlin kernel.
    # (_model_path("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin"), None, "marlin"),
    # (_model_path("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin"), "marlin", "marlin"),
    # (_model_path("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin"), "gptq", "marlin"),
    (_model_path("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin"), "awq", "ERROR"),
    # Model Serialized in Exllama Format.
    # (_model_path("TheBloke/Llama-2-7B-Chat-GPTQ"), None, "gptq_marlin"),
    # (_model_path("TheBloke/Llama-2-7B-Chat-GPTQ"), "marlin", "gptq_marlin"),
    # (_model_path("TheBloke/Llama-2-7B-Chat-GPTQ"), "gptq", "gptq"),
    (_model_path("TheBloke/Llama-2-7B-Chat-GPTQ"), "awq", "ERROR"),

    # compat: autogptq >=0.8.0 use checkpoint_format: str
    # Model Serialized in Marlin Format should always use Marlin kernel.
    # (_model_path("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit"), None, "marlin"),
    # (_model_path("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit"), "marlin", "marlin"),
    # (_model_path("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit"), "gptq", "marlin"),
    (_model_path("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit"), "awq", "ERROR"),
    # Model Serialized in Exllama Format.
    # (_model_path("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"), None, "gptq_marlin"),
    # (_model_path("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"), "marlin", "gptq_marlin"),
    (_model_path("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"), "gptq", "gptq"),
    (_model_path("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"), "awq", "ERROR"),

    # AUTOAWQ
    # (_model_path("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ"), None, "awq_marlin"),
    (_model_path("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ"), "awq", "awq"),
    # (_model_path("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ"), "marlin", "awq_marlin"),
    (_model_path("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ"), "gptq", "ERROR"),
]


@pytest.mark.parametrize("model_arg_exptype", MODEL_ARG_EXPTYPES)
def test_auto_gptq(
        model_arg_exptype: Tuple[str, Optional[str], str]) -> None:
    """Check which quantization method ModelConfig resolves for a checkpoint.

    Args:
        model_arg_exptype: ``(model_path, quantization_arg, expected_type)``.
            ``quantization_arg`` is forwarded as ``quantization=`` (``None``
            when no explicit method is requested). ``expected_type`` is the
            expected ``ModelConfig.quantization`` value, or ``"ERROR"`` when
            constructing the config should raise ``ValueError``.
    """
    model_path, quantization_arg, expected_type = model_arg_exptype

    # Keep the try body to the constructor only, so that just a construction
    # failure is mapped to the "ERROR" sentinel.
    try:
        model_config = ModelConfig(model_path,
                                   model_path,
                                   tokenizer_mode="auto",
                                   trust_remote_code=False,
                                   seed=0,
                                   dtype="float16",
                                   revision=None,
                                   quantization=quantization_arg)
    except ValueError:
        found_quantization_type = "ERROR"
    else:
        found_quantization_type = model_config.quantization

    assert found_quantization_type == expected_type, (
        f"Expected quant_type == {expected_type} for {model_path}, "
        f"but found {found_quantization_type} "
        f"with --quantization {quantization_arg}")