test_bitsandbytes.py 9.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Tests whether bitsandbytes computation is enabled correctly.

Run `pytest tests/quantization/test_bitsandbytes.py`.
"""
7

8
import pytest
Marc Sun's avatar
Marc Sun committed
9
from packaging.version import Version
10
from transformers import BitsAndBytesConfig
Marc Sun's avatar
Marc Sun committed
11
from transformers import __version__ as TRANSFORMERS_VERSION
12

13
from tests.quantization.utils import is_quant_method_supported
14
from vllm.platforms import current_platform
15

16
from ...utils import compare_two_settings, multi_gpu_test
17
from ..utils import check_embeddings_close, check_logprobs_close
youkaichao's avatar
youkaichao committed
18

19
20
21
22
23
24
25
# On ROCm, mark the whole module as skipped when running on gfx9 hardware.
# `pytestmark` is the module-level name pytest reads to apply a marker to
# every test collected from this file.
if current_platform.is_rocm():
    # Imported inside the branch so the ROCm helper is only loaded on ROCm
    # (presumably it is not importable/meaningful elsewhere -- TODO confirm).
    from vllm.platforms.rocm import on_gfx9

    pytestmark = pytest.mark.skipif(
        on_gfx9(),
        reason="bitsandbytes not supported on gfx9 (warp size 64 limitation)",
    )
26

27
# Parametrization tables. Every entry is a ``(model_name, description)`` pair
# that is fed to ``pytest.mark.parametrize("model_name, description", ...)``
# by the tests below.

# Checkpoints that are quantized to 4-bit while being loaded ("inflight").
models_4bit_to_test = [
    ("facebook/opt-125m", "quantize opt model inflight"),
    (
        "mistralai/Mistral-7B-Instruct-v0.3",
        "quantize inflight model with both HF and Mistral format weights",
    ),
]

# Embedding (pooling) checkpoints quantized to 4-bit while being loaded.
models_4bit_to_embedding_test = [
    ("intfloat/e5-mistral-7b-instruct", "quantize embedding model inflight"),
]

# Mixture-of-experts checkpoints quantized to 4-bit while being loaded.
models_4bit_to_moe_test = [
    ("allenai/OLMoE-1B-7B-0125-Instruct", "quantize moe model inflight"),
]

# Checkpoints that are already stored in 4-bit bitsandbytes form.
# NOTE: the identifier is misspelled ("qaunt"); it is kept as-is because the
# parametrize decorator of test_load_pre_quant_4bit_bnb_model refers to it.
models_pre_qaunt_4bit_to_test = [
    (
        "PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed",
        "read pre-quantized 4-bit FP4 model",
    ),
    ("poedator/opt-125m-bnb-4bit", "read pre-quantized 4-bit NF4 opt model"),
]

# Checkpoints that are already stored in 8-bit bitsandbytes form.
models_pre_quant_8bit_to_test = [
    ("meta-llama/Llama-Guard-3-8B-INT8", "read pre-quantized llama 8-bit model"),
    ("yec019/fbopt-350m-8bit", "read pre-quantized 8-bit opt model"),
]


57
58
59
60
@pytest.mark.skipif(
    not is_quant_method_supported("bitsandbytes"),
    reason="bitsandbytes is not supported on this GPU type.",
)
@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
def test_load_4bit_bnb_model(
    hf_runner, vllm_runner, example_prompts, model_name, description
) -> None:
    """Compare vLLM and HF greedy output with 4-bit quantization at load time.

    vLLM is started with ``quantization="bitsandbytes"`` (``pre_quant=False``)
    and the HF runner receives a ``BitsAndBytesConfig(load_in_4bit=True)``.
    Only the first example prompt is compared. ``description`` labels the
    parametrized case and is not used in the body.
    """
    hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(load_in_4bit=True))
    validate_generated_texts(
        hf_runner,
        vllm_runner,
        example_prompts[:1],
        model_name,
        pre_quant=False,
        hf_model_kwargs=hf_model_kwargs,
    )
69
70


71
72
73
74
75
76
77
78
79
80
81
@pytest.mark.skipif(
    not is_quant_method_supported("bitsandbytes"),
    reason="bitsandbytes is not supported on this GPU type.",
)
@pytest.mark.parametrize("model_name, description", models_pre_qaunt_4bit_to_test)
def test_load_pre_quant_4bit_bnb_model(
    hf_runner, vllm_runner, example_prompts, model_name, description
) -> None:
    """Compare vLLM and HF greedy output for pre-quantized 4-bit checkpoints.

    ``pre_quant=True`` makes the shared validator start vLLM with
    ``quantization=None``; no extra HF model kwargs are supplied.
    ``description`` labels the parametrized case and is not used in the body.
    """
    # Only the first example prompt is compared.
    first_prompt_only = example_prompts[:1]
    validate_generated_texts(
        hf_runner,
        vllm_runner,
        first_prompt_only,
        model_name,
        pre_quant=True,
    )
82

83

84
85
86
87
88
89
90
91
92
93
94
@pytest.mark.skipif(
    not is_quant_method_supported("bitsandbytes"),
    reason="bitsandbytes is not supported on this GPU type.",
)
@pytest.mark.parametrize("model_name, description", models_pre_quant_8bit_to_test)
def test_load_8bit_bnb_model(
    hf_runner, vllm_runner, example_prompts, model_name, description
) -> None:
    """Compare vLLM and HF greedy output for pre-quantized 8-bit checkpoints.

    ``pre_quant=True`` makes the shared validator start vLLM with
    ``quantization=None``; no extra HF model kwargs are supplied.
    ``description`` labels the parametrized case and is not used in the body.
    """
    # Only the first example prompt is compared.
    first_prompt_only = example_prompts[:1]
    validate_generated_texts(
        hf_runner,
        vllm_runner,
        first_prompt_only,
        model_name,
        pre_quant=True,
    )
95
96


97
98
99
100
@pytest.mark.skipif(
    not is_quant_method_supported("bitsandbytes"),
    reason="bitsandbytes is not supported on this GPU type.",
)
@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
@multi_gpu_test(num_gpus=2)
def test_load_tp_4bit_bnb_model(
    hf_runner, vllm_runner, example_prompts, model_name, description
) -> None:
    """Compare vLLM (tensor-parallel size 2) and HF with 4-bit quantization.

    Same comparison as the single-GPU in-flight 4-bit test, except that vLLM
    is started with ``tensor_parallel_size=2`` via ``vllm_tp_size``.
    Only the first example prompt is compared. ``description`` labels the
    parametrized case and is not used in the body.
    """
    hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(load_in_4bit=True))
    validate_generated_texts(
        hf_runner,
        vllm_runner,
        example_prompts[:1],
        model_name,
        pre_quant=False,
        hf_model_kwargs=hf_model_kwargs,
        vllm_tp_size=2,
    )


@pytest.mark.skipif(
    not is_quant_method_supported("bitsandbytes"),
    reason="bitsandbytes is not supported on this GPU type.",
)
@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
@multi_gpu_test(num_gpus=2)
def test_load_pp_4bit_bnb_model(model_name, description) -> None:
    """Compare a baseline launch against one with pipeline-parallel size 2.

    Both argument lists enable bitsandbytes quantization; the second one only
    adds ``--pipeline-parallel-size 2``. The comparison itself is delegated
    to ``compare_two_settings`` (presumably it runs the model under each
    argument list and checks that the results agree -- see its definition).
    ``description`` labels the parametrized case and is not used in the body.
    """
    # Arguments shared by both launches.
    common_args = [
        "--disable-log-stats",
        "--dtype",
        "bfloat16",
        "--enable-prefix-caching",
        "--quantization",
        "bitsandbytes",
        "--gpu-memory-utilization",
        "0.7",
    ]
    # Same launch, plus pipeline parallelism across two stages.
    pp_args = [
        *common_args,
        "--pipeline-parallel-size",
        "2",
    ]
    compare_two_settings(model_name, common_args, pp_args)


Marc Sun's avatar
Marc Sun committed
143
144
145
146
147
148
@pytest.mark.skipif(
    Version(TRANSFORMERS_VERSION) >= Version("5.0.0"),
    reason="Need to add support for quantizing MoE experts with bnb"
    " in transformers v5. See"
    " https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1849",
)
@pytest.mark.skipif(
    not is_quant_method_supported("bitsandbytes"),
    reason="bitsandbytes is not supported on this GPU type.",
)
@pytest.mark.parametrize("model_name, description", models_4bit_to_moe_test)
def test_4bit_bnb_moe_model(
    hf_runner, vllm_runner, example_prompts, model_name, description
) -> None:
    """Check that vLLM and HF logprobs are close for a 4-bit quantized MoE model.

    vLLM is started with ``quantization="bitsandbytes"``; the HF runner uses an
    NF4 ``BitsAndBytesConfig`` with double quantization. Both generate 32
    greedy tokens with 5 logprobs per position for every example prompt, and
    the two result lists are compared with ``check_logprobs_close``.
    ``description`` labels the parametrized case and is not used in the body.
    """
    hf_model_kwargs = dict(
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
    )

    # The vLLM run comes first and its runner is closed before HF starts.
    with vllm_runner(
        model_name,
        quantization="bitsandbytes",
        enforce_eager=False,
        default_torch_num_threads=1,
    ) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens=32, num_logprobs=5
        )

    with hf_runner(
        model_name, model_kwargs=hf_model_kwargs, default_torch_num_threads=1
    ) as hf_model:
        transformers_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens=32, num_logprobs=5
        )

    check_logprobs_close(
        outputs_0_lst=transformers_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="transformers",
        name_1="vllm",
    )


188
189
190
191
192
@pytest.mark.skipif(
    not is_quant_method_supported("bitsandbytes"),
    reason="bitsandbytes is not supported on this GPU type.",
)
@pytest.mark.parametrize("model_name, description", models_4bit_to_embedding_test)
@pytest.mark.parametrize("dtype", ["half"])
def test_4bit_bnb_embedding_model(
    model_name,
    description,
    hf_runner,
    vllm_runner,
    example_prompts,
    dtype: str,
) -> None:
    """Check that vLLM and HF embeddings are close for a 4-bit quantized model.

    vLLM runs the model with ``runner="pooling"`` and
    ``quantization="bitsandbytes"``; the HF side loads it as a sentence
    transformer with ``BitsAndBytesConfig(load_in_4bit=True)``. The embeddings
    are compared with ``check_embeddings_close`` using ``tol=5e-2``.
    ``description`` labels the parametrized case and is not used in the body.
    """
    # The example_prompts has ending "\n", for example:
    # "Write a short story about a robot that dreams for the first time.\n"
    # sentence_transformers will strip the input texts, see:
    # https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
    # This makes the input_ids different between hf_model and vllm_model.
    # So we need to strip the input texts to avoid test failing.
    example_prompts = [str(s).strip() for s in example_prompts]

    # Inflight 4bit quantization
    with vllm_runner(
        model_name,
        runner="pooling",
        dtype=dtype,
        gpu_memory_utilization=0.5,
        quantization="bitsandbytes",
        default_torch_num_threads=1,
    ) as vllm_model:
        vllm_outputs = vllm_model.embed(example_prompts)

    hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(load_in_4bit=True))
    with hf_runner(
        model_name,
        dtype=dtype,
        model_kwargs=hf_model_kwargs,
        is_sentence_transformer=True,
        default_torch_num_threads=1,
    ) as hf_model:
        hf_outputs = hf_model.encode(example_prompts)

    check_embeddings_close(
        embeddings_0_lst=hf_outputs,
        embeddings_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
        tol=5e-2,
    )


240
241
242
243
244
245
246
247
248
249
250
251
def log_generated_texts(prompts, outputs, runner_name):
    """Build one log record per generation result.

    Args:
        prompts: Indexable collection of prompts; ``prompts[i]`` is paired
            with the i-th element of ``outputs``.
        outputs: Iterable of 2-item results whose second item is the
            generated text (the first item is ignored).
        runner_name: Label stored in every record.

    Returns:
        A list of dicts with the keys ``"prompt"``, ``"runner_name"`` and
        ``"generated_text"``, in the order of ``outputs``.
    """
    return [
        {
            "prompt": prompts[position],
            "runner_name": runner_name,
            "generated_text": text,
        }
        for position, (_, text) in enumerate(outputs)
    ]


252
253
254
255
256
257
258
259
260
261
def validate_generated_texts(
    hf_runner,
    vllm_runner,
    prompts,
    model_name,
    pre_quant=False,
    hf_model_kwargs=None,
    vllm_tp_size=1,
    max_tokens=8,
):
    """Assert that vLLM and HF produce identical greedy generations.

    Args:
        hf_runner: Callable returning a context-managed HF model wrapper.
        vllm_runner: Callable returning a context-managed vLLM model wrapper.
        prompts: Prompts passed to both runners.
        model_name: Model identifier passed to both runners.
        pre_quant: If true, vLLM is started with ``quantization=None``
            (the checkpoint is expected to be quantized already); otherwise
            it is started with ``quantization="bitsandbytes"``.
        hf_model_kwargs: Extra ``model_kwargs`` for the HF runner; ``None``
            is treated as an empty dict.
        vllm_tp_size: Value passed to vLLM as ``tensor_parallel_size``.
        max_tokens: Number of tokens requested from each greedy generation.

    Raises:
        AssertionError: If the runners return a different number of results
            or any pair of generated strings differs.
    """
    # NOTE: run vLLM first, as it requires a clean process
    # when using distributed inference
    with vllm_runner(
        model_name,
        quantization=None if pre_quant else "bitsandbytes",
        tensor_parallel_size=vllm_tp_size,
        enforce_eager=False,
        default_torch_num_threads=1,
        tokenizer_mode="hf",
        load_format="hf",
        config_format="hf",
    ) as llm:
        vllm_outputs = llm.generate_greedy(prompts, max_tokens)
        vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")

    if hf_model_kwargs is None:
        hf_model_kwargs = {}

    # Run with HF runner
    with hf_runner(
        model_name, model_kwargs=hf_model_kwargs, default_torch_num_threads=1
    ) as llm:
        hf_outputs = llm.generate_greedy(prompts, max_tokens)
        hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")

    # zip() below stops at the shorter list, so a missing result would go
    # unnoticed without this explicit check.
    assert len(hf_logs) == len(vllm_logs), (
        f"Model: {model_name}\n"
        f"HF returned {len(hf_logs)} outputs but vLLM returned {len(vllm_logs)}"
    )

    # Compare the generated strings
    for hf_log, vllm_log in zip(hf_logs, vllm_logs):
        hf_str = hf_log["generated_text"]
        vllm_str = vllm_log["generated_text"]
        prompt = hf_log["prompt"]
        assert hf_str == vllm_str, (
            # The newline after the model name keeps it from running into
            # the "Mismatch ..." line of the failure message.
            f"Model: {model_name}\n"
            "Mismatch between HF and vLLM outputs:\n"
            f"Prompt: {prompt}\n"
            f"HF Output: '{hf_str}'\n"
            f"vLLM Output: '{vllm_str}'"
        )