test_autogptq_marlin_configs.py 1.82 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""Tests whether Marlin models are correctly detected from the AutoGPTQ config.

Run `pytest tests/quantization/test_autogptq_marlin_configs.py --forked`.
"""

from dataclasses import dataclass
from typing import Tuple

import pytest

from vllm.config import ModelConfig


@dataclass()
class ModelPair:
    """Container holding one Marlin model id and one GPTQ model id."""

    # Model id (e.g. a Hugging Face repo name) of the Marlin variant.
    model_marlin: str
    # Model id of the GPTQ variant.
    model_gptq: str


# Each entry is (model id, quantization method ModelConfig should resolve to).
MODELS_QUANT_TYPE = [
    # compat: autogptq <=0.7.1 configs use `is_marlin_format: bool`
    ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin"),
    ("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq"),
    # compat: autogptq >=0.8.0 configs use `checkpoint_format: str`
    ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin"),
    ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq"),
]


@pytest.mark.parametrize("model_quant_type", MODELS_QUANT_TYPE)
def test_auto_gptq(model_quant_type: Tuple[str, str]) -> None:
    """Check that ModelConfig resolves the expected quantization method.

    Each model is configured twice: once with ``quantization=None`` (case 1)
    and once with ``quantization="gptq"`` (case 2). In both cases the
    resulting ``ModelConfig.quantization`` must equal the expected value, so
    e.g. the Marlin checkpoints in ``MODELS_QUANT_TYPE`` must resolve to
    ``"marlin"`` even when ``"gptq"`` is passed explicitly.

    Args:
        model_quant_type: ``(model_path, expected_quant_type)`` pair taken
            from ``MODELS_QUANT_TYPE``.
    """
    model_path, quant_type = model_quant_type

    # (quantization argument passed to ModelConfig, suffix of the failure
    # message identifying which case failed)
    cases = (
        (None, "for no --quantization None case"),  # case 1
        ("gptq", "for --quantization gptq case"),  # case 2
    )

    for quantization, case_desc in cases:
        model_config = ModelConfig(
            model_path,
            model_path,
            tokenizer_mode="auto",
            trust_remote_code=False,
            seed=0,
            dtype="float16",
            revision=None,
            quantization=quantization,
        )

        assert model_config.quantization == quant_type, (
            f"Expected quant_type == {quant_type} for {model_path}, "
            f"but found {model_config.quantization} "
            f"{case_desc}")