# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test GGUF models against the generations of their unquantized counterparts.

Note: to pass the test, a quantization level higher than Q4 should be used.
"""

import os
9
from typing import NamedTuple
10
11
12

import pytest
from huggingface_hub import hf_hub_download
13
from pytest import MarkDecorator
14
from transformers import AutoTokenizer
15
16
17

from tests.quantization.utils import is_quant_method_supported

18
19
20
from ...conftest import VllmRunner
from ...utils import multi_gpu_test
from ..utils import check_logprobs_close
# Set explicitly for the HF `tokenizers` library at import time; presumably
# to keep its parallelism enabled and avoid fork-related warnings — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Context length given to both the GGUF and the unquantized vLLM runners.
MAX_MODEL_LEN = 1024


class GGUFTestConfig(NamedTuple):
    """An unquantized Hugging Face model paired with a GGUF checkpoint of it.

    The tests run both variants through vLLM and compare their greedy
    logprobs.
    """

    # Hub repo id of the unquantized reference model. It is also used as
    # the tokenizer for the GGUF run.
    original_model: str
    # Hub repo id hosting the GGUF file.
    gguf_repo: str
    # File name of the GGUF checkpoint inside ``gguf_repo``.
    gguf_filename: str
    # Extra pytest marks applied to this config's parametrized test case.
    # The default is an immutable empty tuple. A ``[]`` default would be a
    # single list object shared by every instance that omits ``marks``.
    marks: tuple[MarkDecorator, ...] = ()

    @property
    def gguf_model(self) -> str:
        """Return the local path of the GGUF file.

        The file is fetched via ``hf_hub_download`` on every access; the
        Hugging Face cache makes repeated calls cheap.
        """
        return hf_hub_download(self.gguf_repo, filename=self.gguf_filename)


# Reference-model / GGUF-checkpoint pairs exercised by the tests below.
LLAMA_CONFIG = GGUFTestConfig(
    original_model="meta-llama/Llama-3.2-1B-Instruct",
    gguf_repo="bartowski/Llama-3.2-1B-Instruct-GGUF",
    gguf_filename="Llama-3.2-1B-Instruct-Q6_K.gguf",
)

QWEN2_CONFIG = GGUFTestConfig(
    original_model="Qwen/Qwen2.5-1.5B-Instruct",
    gguf_repo="Qwen/Qwen2.5-1.5B-Instruct-GGUF",
    gguf_filename="qwen2.5-1.5b-instruct-q6_k.gguf",
)

PHI3_CONFIG = GGUFTestConfig(
    original_model="microsoft/Phi-3.5-mini-instruct",
    gguf_repo="bartowski/Phi-3.5-mini-instruct-GGUF",
    gguf_filename="Phi-3.5-mini-instruct-IQ4_XS.gguf",
)

GPT2_CONFIG = GGUFTestConfig(
    original_model="openai-community/gpt2-large",
    gguf_repo="QuantFactory/gpt2-large-GGUF",
    gguf_filename="gpt2-large.Q4_K_M.gguf",
)

STABLELM_CONFIG = GGUFTestConfig(
    original_model="stabilityai/stablelm-3b-4e1t",
    gguf_repo="afrideva/stablelm-3b-4e1t-GGUF",
    gguf_filename="stablelm-3b-4e1t.q4_k_m.gguf",
)

STARCODER_CONFIG = GGUFTestConfig(
    original_model="bigcode/starcoder2-3b",
    gguf_repo="QuantFactory/starcoder2-3b-GGUF",
    gguf_filename="starcoder2-3b.Q6_K.gguf",
)

DOLPHIN_CONFIG = GGUFTestConfig(
    # Regression test for a VocabParallelEmbedding sharding issue.
    original_model="cognitivecomputations/TinyDolphin-2.8-1.1b",
    gguf_repo="tsunemoto/TinyDolphin-2.8-1.1b-GGUF",
    gguf_filename="tinydolphin-2.8-1.1b.Q6_K.gguf",
)

# Configs run by `test_models`; disabled entries note why they are skipped.
MODELS = [
    LLAMA_CONFIG,
    QWEN2_CONFIG,
    PHI3_CONFIG,
    GPT2_CONFIG,
    # STABLELM_CONFIG,  # enable this once v1 supports head_size=80
    DOLPHIN_CONFIG,
    # STARCODER_CONFIG,  # broken
]


def check_model_outputs(
    vllm_runner: type[VllmRunner],
    prompts: list[str],
    model: GGUFTestConfig,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tp_size: int,
) -> None:
    """Compare greedy logprobs of a GGUF checkpoint against its original model.

    Prompt preparation:
        If the original model's tokenizer defines a chat template, each
        prompt is wrapped as a single user message and rendered with that
        template first.

    Runs:
        The GGUF file is run with ``tensor_parallel_size=tp_size``, using
        the original model's tokenizer. The unquantized model is run with
        ``tensor_parallel_size=1``. Both runs use ``enforce_eager=True`` and
        ``MAX_MODEL_LEN``.

    Comparison:
        The two sets of outputs are compared with ``check_logprobs_close``.

    Args:
        vllm_runner: Context-manager factory that builds a vLLM runner.
        prompts: Raw prompts; every prompt except the last one is used.
        model: The original/GGUF model pair under test.
        dtype: Model dtype passed to both runners (e.g. ``"half"``).
        max_tokens: Maximum number of tokens generated per prompt.
        num_logprobs: Number of logprobs requested per generated token.
        tp_size: Tensor-parallel size for the GGUF run only.
    """
    tokenizer = AutoTokenizer.from_pretrained(model.original_model)
    if tokenizer.chat_template is not None:
        messages = [[{
            'role': 'user',
            'content': prompt
        }] for prompt in prompts]
        prompts = tokenizer.apply_chat_template(messages,
                                                tokenize=False,
                                                add_generation_prompt=True)

    # NOTE(review): the last prompt is deliberately excluded from both runs.
    # The reason is not visible in this file; confirm before changing.
    test_prompts = prompts[:-1]

    # Run the GGUF model.
    with vllm_runner(model_name=model.gguf_model,
                     enforce_eager=True,
                     tokenizer_name=model.original_model,
                     dtype=dtype,
                     max_model_len=MAX_MODEL_LEN,
                     tensor_parallel_size=tp_size) as gguf_model:
        gguf_outputs = gguf_model.generate_greedy_logprobs(
            test_prompts, max_tokens, num_logprobs)

    # Run the unquantized model. It must use tp=1; otherwise the test hangs
    # during NCCL initialization.
    with vllm_runner(
            model_name=model.original_model,
            enforce_eager=True,  # faster tests
            dtype=dtype,
            max_model_len=MAX_MODEL_LEN,
            tensor_parallel_size=1) as original_model:
        original_outputs = original_model.generate_greedy_logprobs(
            test_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=original_outputs,
        outputs_1_lst=gguf_outputs,
        name_0="original",
        name_1="gguf",
    )


@pytest.mark.skipif(not is_quant_method_supported("gguf"),
                    reason="gguf is not supported on this GPU type.")
@pytest.mark.parametrize(
    "model",
    [pytest.param(cfg, marks=cfg.marks) for cfg in MODELS],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tp_size", [1])
def test_models(
    vllm_runner: type[VllmRunner],
    example_prompts: list[str],
    model: GGUFTestConfig,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tp_size: int,
) -> None:
    """Compare every config in MODELS against its original model.

    The GGUF run uses tensor_parallel_size=1. Each config's own ``marks``
    are attached to its parametrized case.
    """
    check_model_outputs(
        vllm_runner=vllm_runner,
        prompts=example_prompts,
        model=model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tp_size=tp_size,
    )


@pytest.mark.skipif(not is_quant_method_supported("gguf"),
                    reason="gguf is not supported on this GPU type.")
@pytest.mark.parametrize("model", [LLAMA_CONFIG])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tp_size", [2])
@multi_gpu_test(num_gpus=2)
def test_distributed(
    vllm_runner: type[VllmRunner],
    example_prompts: list[str],
    model: GGUFTestConfig,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tp_size: int,
) -> None:
    """Compare the Llama GGUF model with the GGUF run sharded over 2 GPUs.

    Only the GGUF run uses tensor_parallel_size=2; the reference model is
    still run with tp=1 inside ``check_model_outputs``.
    """
    check_model_outputs(
        vllm_runner=vllm_runner,
        prompts=example_prompts,
        model=model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tp_size=tp_size,
    )