# SPDX-License-Identifier: Apache-2.0
"""
Tests GGUF-quantized models against the generations of their unquantized
counterparts.
Note: to pass the test, use Q4 or higher-bit quantization.
"""

import os
from collections.abc import Sequence
from typing import NamedTuple

import pytest
from huggingface_hub import hf_hub_download
from pytest import MarkDecorator
from transformers import AutoTokenizer

from tests.quantization.utils import is_quant_method_supported

from ...conftest import VllmRunner
from ...utils import multi_gpu_test
from ..utils import check_logprobs_close

# Configure Hugging Face tokenizers parallelism explicitly for this module.
os.environ.update(TOKENIZERS_PARALLELISM="true")

# Context length given to both the GGUF and the unquantized vLLM runner.
MAX_MODEL_LEN = 1_024


class GGUFTestConfig(NamedTuple):
    """Pair an unquantized Hugging Face model with a GGUF file derived from it.

    ``marks`` are forwarded to ``pytest.param`` when a config is used to
    parametrize a test.
    """

    # Hub repo id of the unquantized reference model; its tokenizer is also
    # used for the GGUF run.
    original_model: str
    # Hub repo id that hosts the GGUF file.
    gguf_repo: str
    # Name of the GGUF file inside ``gguf_repo``.
    gguf_filename: str
    # Extra pytest marks for this config's test case. An immutable empty
    # tuple is used so instances never share a single mutable list.
    marks: Sequence[MarkDecorator] = ()

    @property
    def gguf_model(self) -> str:
        """Return the local path of the GGUF file, fetched via ``hf_hub_download``."""
        return hf_hub_download(self.gguf_repo, filename=self.gguf_filename)


# Test matrix entries. Positional argument order follows the GGUFTestConfig
# fields: (original_model, gguf_repo, gguf_filename).
LLAMA_CONFIG = GGUFTestConfig(
    "meta-llama/Llama-3.2-1B-Instruct",
    "bartowski/Llama-3.2-1B-Instruct-GGUF",
    "Llama-3.2-1B-Instruct-IQ4_XS.gguf",
)

QWEN2_CONFIG = GGUFTestConfig(
    "Qwen/Qwen2.5-1.5B-Instruct",
    "Qwen/Qwen2.5-1.5B-Instruct-GGUF",
    "qwen2.5-1.5b-instruct-q6_k.gguf",
)

PHI3_CONFIG = GGUFTestConfig(
    "microsoft/Phi-3.5-mini-instruct",
    "bartowski/Phi-3.5-mini-instruct-GGUF",
    "Phi-3.5-mini-instruct-IQ4_XS.gguf",
)

GPT2_CONFIG = GGUFTestConfig(
    "openai-community/gpt2-large",
    "QuantFactory/gpt2-large-GGUF",
    "gpt2-large.Q4_K_M.gguf",
)

STABLELM_CONFIG = GGUFTestConfig(
    "stabilityai/stablelm-3b-4e1t",
    "afrideva/stablelm-3b-4e1t-GGUF",
    "stablelm-3b-4e1t.q4_k_m.gguf",
)

# Defined but not listed in MODELS (marked broken there).
STARCODER_CONFIG = GGUFTestConfig(
    "bigcode/starcoder2-3b",
    "QuantFactory/starcoder2-3b-GGUF",
    "starcoder2-3b.Q6_K.gguf",
)

# Exercises the VocabParallelEmbedding sharding issue.
DOLPHIN_CONFIG = GGUFTestConfig(
    "cognitivecomputations/TinyDolphin-2.8-1.1b",
    "tsunemoto/TinyDolphin-2.8-1.1b-GGUF",
    "tinydolphin-2.8-1.1b.Q6_K.gguf",
)

# Configs parametrized into ``test_models``.
MODELS = [
    LLAMA_CONFIG,
    QWEN2_CONFIG,
    PHI3_CONFIG,
    GPT2_CONFIG,
    STABLELM_CONFIG,
    DOLPHIN_CONFIG,
    # STARCODER_CONFIG is intentionally excluded: it is currently broken.
]


def check_model_outputs(
    vllm_runner: type[VllmRunner],
    prompts: list[str],
    model: GGUFTestConfig,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tp_size: int,
) -> None:
    """Check that a GGUF checkpoint's logprobs stay close to its original model.

    Both models are run through ``generate_greedy_logprobs`` and the results
    are compared with ``check_logprobs_close``.

    Args:
        vllm_runner: Runner class used to start each vLLM engine.
        prompts: Raw prompts. If the original model's tokenizer has a chat
            template, each prompt is wrapped as a single user message and
            rendered through that template first.
        model: Config naming the original model and the GGUF file to test.
        dtype: dtype passed to both engines.
        max_tokens: Generation length passed to ``generate_greedy_logprobs``.
        num_logprobs: Number of logprobs passed to
            ``generate_greedy_logprobs``.
        tp_size: Tensor-parallel size for the GGUF run. The unquantized run
            always uses ``tensor_parallel_size=1``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model.original_model)
    if tokenizer.chat_template is not None:
        messages = [[{
            'role': 'user',
            'content': prompt
        }] for prompt in prompts]
        prompts = tokenizer.apply_chat_template(messages,
                                                tokenize=False,
                                                add_generation_prompt=True)

    # Both runs use every prompt except the last one.
    # NOTE(review): the reason for dropping the final prompt is not visible
    # here — confirm before changing it.
    test_prompts = prompts[:-1]

    # Run gguf model.
    with vllm_runner(model_name=model.gguf_model,
                     enforce_eager=True,
                     tokenizer_name=model.original_model,
                     dtype=dtype,
                     max_model_len=MAX_MODEL_LEN,
                     tensor_parallel_size=tp_size) as gguf_model:
        gguf_outputs = gguf_model.generate_greedy_logprobs(
            test_prompts, max_tokens, num_logprobs)

    # Run unquantized model.
    # Must run with tp=1; otherwise the test gets stuck during NCCL
    # initialization.
    with vllm_runner(
            model_name=model.original_model,
            enforce_eager=True,  # faster tests
            dtype=dtype,
            max_model_len=MAX_MODEL_LEN,
            tensor_parallel_size=1) as original_model:
        original_outputs = original_model.generate_greedy_logprobs(
            test_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=original_outputs,
        outputs_1_lst=gguf_outputs,
        name_0="original",
        name_1="gguf",
    )


@pytest.mark.skipif(not is_quant_method_supported("gguf"),
                    reason="gguf is not supported on this GPU type.")
@pytest.mark.parametrize(
    "model",
    [pytest.param(cfg, marks=cfg.marks) for cfg in MODELS],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tp_size", [1])
def test_models(
    vllm_runner: type[VllmRunner],
    example_prompts: list[str],
    model: GGUFTestConfig,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tp_size: int,
) -> None:
    """Compare every config in MODELS against its original model at tp=1."""
    check_model_outputs(
        vllm_runner=vllm_runner,
        prompts=example_prompts,
        model=model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tp_size=tp_size,
    )


@pytest.mark.skipif(not is_quant_method_supported("gguf"),
                    reason="gguf is not supported on this GPU type.")
@pytest.mark.parametrize("model", [LLAMA_CONFIG])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tp_size", [2])
@multi_gpu_test(num_gpus=2)
def test_distributed(
    vllm_runner: type[VllmRunner],
    example_prompts: list[str],
    model: GGUFTestConfig,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tp_size: int,
) -> None:
    """Run the GGUF model with 2-way tensor parallelism and compare it to the original."""
    check_model_outputs(
        vllm_runner=vllm_runner,
        prompts=example_prompts,
        model=model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tp_size=tp_size,
    )