test_gguf.py 5.42 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
4
5
6
7
"""
Tests GGUF models against the generations of their unquantized counterparts.
Note: To pass the test, a quantization level higher than Q4 should be used.
"""

import os
8
from typing import NamedTuple
9
10
11

import pytest
from huggingface_hub import hf_hub_download
12
from pytest import MarkDecorator
13
from transformers import AutoTokenizer
14
15
16

from tests.quantization.utils import is_quant_method_supported

17
from ....conftest import VllmRunner
18
from ....utils import multi_gpu_test
19
from ...utils import check_logprobs_close
20
21
22
23
24
25

# Set TOKENIZERS_PARALLELISM to "true" in this process's environment at
# import time, before any test in this module runs.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Value passed as ``max_model_len`` to every runner that
# check_model_outputs creates (both the GGUF and the unquantized one).
MAX_MODEL_LEN = 1024


26
27
28
29
class GGUFTestConfig(NamedTuple):
    """Pairs an unquantized model with the GGUF file that is tested against it.

    The GGUF file itself is fetched on demand through the ``gguf_model``
    property, so building a config does not download anything.
    """

    # Model identifier of the unquantized reference model; also used as the
    # tokenizer source for the GGUF run.
    original_model: str
    # Repository that hosts the GGUF file (first argument to hf_hub_download).
    gguf_repo: str
    # File name of the GGUF checkpoint inside ``gguf_repo``.
    gguf_filename: str
    # Pytest marks attached to this config when it is parametrized.
    # NOTE: a NamedTuple default is evaluated once, so this empty list is
    # shared by every config that does not pass ``marks``; treat it as
    # read-only and never mutate it in place.
    marks: list[MarkDecorator] = []

    @property
    def gguf_model(self) -> str:
        """Fetch the GGUF file via ``hf_hub_download`` and return its result.

        Each access calls ``hf_hub_download`` again; nothing is cached on the
        config object itself.
        """
        return hf_hub_download(self.gguf_repo, filename=self.gguf_filename)


# Llama-3.2-1B-Instruct against its IQ4_XS GGUF file. This is the only config
# defined here that carries a pytest mark (``quant_model``), and it is the
# config used by the distributed (tp_size=2) test.
LLAMA_CONFIG = GGUFTestConfig(
    original_model="meta-llama/Llama-3.2-1B-Instruct",
    gguf_repo="bartowski/Llama-3.2-1B-Instruct-GGUF",
    gguf_filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf",
    marks=[pytest.mark.quant_model],
)

# Qwen2.5-1.5B-Instruct against its q6_k GGUF file.
QWEN2_CONFIG = GGUFTestConfig(
    original_model="Qwen/Qwen2.5-1.5B-Instruct",
    gguf_repo="Qwen/Qwen2.5-1.5B-Instruct-GGUF",
    gguf_filename="qwen2.5-1.5b-instruct-q6_k.gguf",
)

# Phi-3.5-mini-instruct against its IQ4_XS GGUF file.
PHI3_CONFIG = GGUFTestConfig(
    original_model="microsoft/Phi-3.5-mini-instruct",
    gguf_repo="bartowski/Phi-3.5-mini-instruct-GGUF",
    gguf_filename="Phi-3.5-mini-instruct-IQ4_XS.gguf",
)

# gpt2-large against its Q4_K_M GGUF file.
GPT2_CONFIG = GGUFTestConfig(
    original_model="openai-community/gpt2-large",
    gguf_repo="QuantFactory/gpt2-large-GGUF",
    gguf_filename="gpt2-large.Q4_K_M.gguf",
)

# stablelm-3b-4e1t against its q4_k_m GGUF file.
STABLELM_CONFIG = GGUFTestConfig(
    original_model="stabilityai/stablelm-3b-4e1t",
    gguf_repo="afrideva/stablelm-3b-4e1t-GGUF",
    gguf_filename="stablelm-3b-4e1t.q4_k_m.gguf",
)

# starcoder2-3b against its Q6_K GGUF file. Defined but not active: it
# appears in MODELS only as a commented-out entry flagged "broken".
STARCODER_CONFIG = GGUFTestConfig(
    original_model="bigcode/starcoder2-3b",
    gguf_repo="QuantFactory/starcoder2-3b-GGUF",
    gguf_filename="starcoder2-3b.Q6_K.gguf",
)

74
75
76
77
78
79
80
# TinyDolphin-2.8-1.1b against its Q6_K GGUF file.
DOLPHIN_CONFIG = GGUFTestConfig(
    # Included to exercise a VocabParallelEmbedding sharding issue.
    original_model="cognitivecomputations/TinyDolphin-2.8-1.1b",
    gguf_repo="tsunemoto/TinyDolphin-2.8-1.1b-GGUF",
    gguf_filename="tinydolphin-2.8-1.1b.Q6_K.gguf",
)

81
# Configs exercised by the single-GPU ``test_models`` parametrization, in
# the order they are run.
MODELS = [
    LLAMA_CONFIG,
    QWEN2_CONFIG,
    PHI3_CONFIG,
    GPT2_CONFIG,
    STABLELM_CONFIG,
    DOLPHIN_CONFIG,
    # STARCODER_CONFIG,  # broken
]


88
def check_model_outputs(
    vllm_runner: type[VllmRunner],
    prompts: list[str],
    model: GGUFTestConfig,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tp_size: int,
) -> None:
    """Compare greedy logprobs of a GGUF model with its unquantized original.

    The prompts are wrapped with the original model's chat template when the
    tokenizer defines one. The GGUF model and the unquantized model are then
    run one after the other on the same prompts, and the two sets of outputs
    are handed to ``check_logprobs_close``.

    Args:
        vllm_runner: Runner class used as a context manager for both models.
        prompts: Raw prompt strings.
        model: Config naming the unquantized model and its GGUF file.
        dtype: dtype string forwarded to both runners.
        max_tokens: Number of tokens to generate per prompt.
        num_logprobs: Number of logprobs requested per generated token.
        tp_size: Tensor-parallel size for the GGUF run only; the unquantized
            run always uses 1.
    """
    tokenizer = AutoTokenizer.from_pretrained(model.original_model)
    if tokenizer.chat_template is not None:
        # Wrap every prompt as a single-turn user message so both models see
        # the same chat-formatted text.
        messages = [[{
            'role': 'user',
            'content': prompt
        }] for prompt in prompts]
        prompts = tokenizer.apply_chat_template(messages,
                                                tokenize=False,
                                                add_generation_prompt=True)

    # Run gguf model.
    # NOTE(review): both runs drop the last prompt (``prompts[:-1]``); the
    # reason is not visible in this file -- confirm before changing.
    with vllm_runner(model_name=model.gguf_model,
                     enforce_eager=True,
                     tokenizer_name=model.original_model,
                     dtype=dtype,
                     max_model_len=MAX_MODEL_LEN,
                     tensor_parallel_size=tp_size) as gguf_model:
        gguf_outputs = gguf_model.generate_greedy_logprobs(
            prompts[:-1], max_tokens, num_logprobs)

    # Run unquantized model.
    # This must run with tp=1, otherwise the test will get stuck at
    # nccl initialization.
    with vllm_runner(
            model_name=model.original_model,
            enforce_eager=True,  # faster tests
            dtype=dtype,
            max_model_len=MAX_MODEL_LEN,
            tensor_parallel_size=1) as original_model:
        original_outputs = original_model.generate_greedy_logprobs(
            prompts[:-1], max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=original_outputs,
        outputs_1_lst=gguf_outputs,
        name_0="original",
        name_1="gguf",
    )
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178


@pytest.mark.skipif(not is_quant_method_supported("gguf"),
                    reason="gguf is not supported on this GPU type.")
@pytest.mark.parametrize(
    "model",
    [pytest.param(cfg, marks=cfg.marks) for cfg in MODELS],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tp_size", [1])
def test_models(
    vllm_runner: type[VllmRunner],
    example_prompts: list[str],
    model: GGUFTestConfig,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tp_size: int,
) -> None:
    """Check each config in MODELS on a single GPU (tp_size=1).

    Every config is parametrized together with its own pytest marks. The
    whole test is skipped when ``is_quant_method_supported("gguf")`` is
    false.
    """
    check_model_outputs(
        vllm_runner=vllm_runner,
        prompts=example_prompts,
        model=model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tp_size=tp_size,
    )


@pytest.mark.skipif(not is_quant_method_supported("gguf"),
                    reason="gguf is not supported on this GPU type.")
@pytest.mark.parametrize("model", [LLAMA_CONFIG])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tp_size", [2])
@multi_gpu_test(num_gpus=2)
def test_distributed(
    vllm_runner: type[VllmRunner],
    example_prompts: list[str],
    model: GGUFTestConfig,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tp_size: int,
) -> None:
    """Check LLAMA_CONFIG with the GGUF model run at tensor-parallel size 2.

    Wrapped in ``multi_gpu_test(num_gpus=2)`` and skipped when
    ``is_quant_method_supported("gguf")`` is false. Only 8 tokens are
    generated per prompt here.
    """
    check_model_outputs(
        vllm_runner=vllm_runner,
        prompts=example_prompts,
        model=model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tp_size=tp_size,
    )