test_bitsandbytes.py 8.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
4
5
'''Tests whether bitsandbytes computation is enabled correctly.

Run `pytest tests/quantization/test_bitsandbytes.py`.
'''
6
7
8

import gc

9
10
import pytest
import torch
11
from transformers import BitsAndBytesConfig
12

13
from tests.quantization.utils import is_quant_method_supported
14

15
from ..models.utils import check_embeddings_close
16
from ..utils import compare_two_settings, create_new_process_for_each_test
youkaichao's avatar
youkaichao committed
17

18
# Each entry is a (model id, description) pair fed to
# ``pytest.mark.parametrize("model_name, description", ...)`` below. The
# description is not read by the test bodies; it records what the case covers.

# Unquantized checkpoints that are quantized to 4 bits while being loaded.
models_4bit_to_test = [
    ("facebook/opt-125m", "quantize opt model inflight"),
    ("mistralai/Mistral-7B-Instruct-v0.3",
     "quantize inflight model with both HF and Mistral format weights"),
]

# Embedding checkpoints that are quantized to 4 bits while being loaded.
models_4bit_to_embedding_test = [
    ("intfloat/e5-mistral-7b-instruct", "quantize embedding model inflight"),
]

# Checkpoints that already ship bitsandbytes 4-bit weights.
# NOTE: the name is misspelt ("qaunt"), but it is kept as-is because the
# parametrize decorator of test_load_pre_quant_4bit_bnb_model refers to it.
models_pre_qaunt_4bit_to_test = [
    ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
     'read pre-quantized 4-bit FP4 model'),
    ('poedator/opt-125m-bnb-4bit', 'read pre-quantized 4-bit NF4 opt model'),
]

# Checkpoints that already ship bitsandbytes 8-bit weights.
models_pre_quant_8bit_to_test = [
    ('meta-llama/Llama-Guard-3-8B-INT8',
     'read pre-quantized llama 8-bit model'),
    ("yec019/fbopt-350m-8bit", "read pre-quantized 8-bit opt model"),
]

46
47
48
49

@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
@create_new_process_for_each_test()
def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                             model_name, description) -> None:
    """Compare vLLM and HF greedy output for in-flight 4-bit quantization.

    Only the first example prompt is used. ``description`` comes from the
    parametrization and is not read by the test body.
    """
    # Ask HF to quantize to 4 bits at load time; vLLM is told to use
    # bitsandbytes by validate_generated_texts because pre_quant is False.
    hf_model_kwargs = {
        "quantization_config": BitsAndBytesConfig(load_in_4bit=True),
    }
    validate_generated_texts(hf_runner,
                             vllm_runner,
                             example_prompts[:1],
                             model_name,
                             pre_quant=False,
                             hf_model_kwargs=hf_model_kwargs)
58
59
60
61
62
63


@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description",
                         models_pre_qaunt_4bit_to_test)
@create_new_process_for_each_test()
def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                                       model_name, description) -> None:
    """Compare vLLM and HF greedy output for pre-quantized 4-bit checkpoints.

    Only the first example prompt is used. ``description`` comes from the
    parametrization and is not read by the test body.
    """
    # pre_quant=True: no quantization method is passed to vLLM and no extra
    # model kwargs are passed to HF.
    validate_generated_texts(hf_runner,
                             vllm_runner,
                             example_prompts[:1],
                             model_name,
                             pre_quant=True)
70

71

72
73
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description",
                         models_pre_quant_8bit_to_test)
@create_new_process_for_each_test()
def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                             model_name, description) -> None:
    """Compare vLLM and HF greedy output for pre-quantized 8-bit checkpoints.

    Only the first example prompt is used. ``description`` comes from the
    parametrization and is not read by the test body.
    """
    # pre_quant=True: no quantization method is passed to vLLM and no extra
    # model kwargs are passed to HF.
    validate_generated_texts(hf_runner,
                             vllm_runner,
                             example_prompts[:1],
                             model_name,
                             pre_quant=True)
82
83


84
85
86
87
88
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason='Test requires at least 2 GPUs.')
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
@create_new_process_for_each_test()
def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                                model_name, description) -> None:
    """Compare HF output with vLLM in-flight 4-bit output at TP size 2.

    Skipped when fewer than two CUDA devices are visible. Only the first
    example prompt is used. ``description`` comes from the parametrization
    and is not read by the test body.
    """
    # Ask HF to quantize to 4 bits at load time; vLLM is told to use
    # bitsandbytes by validate_generated_texts because pre_quant is False.
    hf_model_kwargs = {
        "quantization_config": BitsAndBytesConfig(load_in_4bit=True),
    }
    validate_generated_texts(hf_runner,
                             vllm_runner,
                             example_prompts[:1],
                             model_name,
                             pre_quant=False,
                             hf_model_kwargs=hf_model_kwargs,
                             vllm_tp_size=2)


104
105
106
107
108
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason='Test requires at least 2 GPUs.')
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
@create_new_process_for_each_test()
def test_load_pp_4bit_bnb_model(model_name, description) -> None:
    """Compare bitsandbytes runs with and without pipeline parallelism.

    The same model is handed to ``compare_two_settings`` with two argument
    lists that differ only in ``--pipeline-parallel-size 2``. Skipped when
    fewer than two CUDA devices are visible. ``description`` comes from the
    parametrization and is not read by the test body.
    """
    # Arguments shared by the baseline and the pipeline-parallel setting.
    common_args = [
        "--disable-log-stats",
        "--disable-log-requests",
        "--dtype",
        "bfloat16",
        "--enable-prefix-caching",
        "--quantization",
        "bitsandbytes",
        "--gpu-memory-utilization",
        "0.7",
    ]
    # Same as the baseline, plus a pipeline-parallel size of 2.
    pp_args = [
        *common_args,
        "--pipeline-parallel-size",
        "2",
    ]
    compare_two_settings(model_name, common_args, pp_args)


130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description",
                         models_4bit_to_embedding_test)
@pytest.mark.parametrize("dtype", ["half"])
@create_new_process_for_each_test()
def test_4bit_bnb_embedding_model(
    model_name,
    description,
    hf_runner,
    vllm_runner,
    example_prompts,
    dtype: str,
) -> None:
    """Compare HF and vLLM embeddings for an in-flight 4-bit quantized model.

    The HF side is loaded as a sentence-transformer with a 4-bit
    ``BitsAndBytesConfig``; the vLLM side runs the ``embed`` task with
    ``quantization="bitsandbytes"``. The two sets of embeddings are compared
    by ``check_embeddings_close`` with a tolerance of 5e-2.
    """
    # Every example prompt ends with "\n", e.g.
    # "Write a short story about a robot that dreams for the first time.\n"
    # sentence_transformers strips its input texts, see:
    # https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
    # Without stripping here as well, hf_model and vllm_model would see
    # different input_ids and the comparison would fail.
    stripped_prompts = [str(prompt).strip() for prompt in example_prompts]

    # HF reference: quantize to 4 bits while loading.
    bnb_config = BitsAndBytesConfig(load_in_4bit=True)
    with hf_runner(
            model_name,
            dtype=dtype,
            model_kwargs={"quantization_config": bnb_config},
            is_sentence_transformer=True,
    ) as hf_model:
        hf_embeddings = hf_model.encode(stripped_prompts)

    # vLLM under test: same prompts, bitsandbytes quantization.
    with vllm_runner(model_name,
                     task="embed",
                     dtype=dtype,
                     quantization="bitsandbytes") as vllm_model:
        vllm_embeddings = vllm_model.encode(stripped_prompts)

    check_embeddings_close(
        embeddings_0_lst=hf_embeddings,
        embeddings_1_lst=vllm_embeddings,
        name_0="hf",
        name_1="vllm",
        tol=5e-2,
    )


178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
def log_generated_texts(prompts, outputs, runner_name):
    """Build one log record per generated output.

    Args:
        prompts: Indexable collection of prompts; ``prompts[i]`` is paired
            with the i-th element of ``outputs``.
        outputs: Iterable of 2-item tuples whose second item is the
            generated text (the first item is ignored).
        runner_name: Label stored in every record to identify the runner.

    Returns:
        A list of dicts with the keys ``"prompt"``, ``"runner_name"`` and
        ``"generated_text"``, in the order of ``outputs``.
    """
    return [
        {
            "prompt": prompts[position],
            "runner_name": runner_name,
            "generated_text": text,
        }
        for position, (_, text) in enumerate(outputs)
    ]


def validate_generated_texts(hf_runner,
                             vllm_runner,
                             prompts,
                             model_name,
                             pre_quant=False,
                             hf_model_kwargs=None,
                             vllm_tp_size=1):
    """Assert that vLLM and HF generate identical greedy text for a model.

    Args:
        hf_runner: Callable returning a context manager whose value exposes
            ``generate_greedy(prompts, max_tokens)``.
        vllm_runner: Same contract as ``hf_runner``, for the vLLM side.
        prompts: Prompts passed unchanged to both runners.
        model_name: Model identifier passed to both runners.
        pre_quant: When True, ``quantization=None`` is passed to vLLM;
            otherwise ``quantization='bitsandbytes'`` is passed.
        hf_model_kwargs: Optional ``model_kwargs`` for the HF runner;
            ``None`` is replaced by an empty dict.
        vllm_tp_size: Tensor-parallel size passed to the vLLM runner.

    Raises:
        AssertionError: If the runners return a different number of outputs
            or any pair of generated texts differs.
    """
    max_tokens = 8

    # NOTE: run vLLM first, as it requires a clean process
    # when using distributed inference
    with vllm_runner(model_name,
                     quantization=None if pre_quant else 'bitsandbytes',
                     tensor_parallel_size=vllm_tp_size,
                     enforce_eager=False) as llm:
        vllm_outputs = llm.generate_greedy(prompts, max_tokens)
        vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")

    # Release GPU memory held by vLLM before the HF model is loaded.
    gc.collect()
    torch.cuda.empty_cache()

    if hf_model_kwargs is None:
        hf_model_kwargs = {}

    # Run with HF runner
    with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
        hf_outputs = llm.generate_greedy(prompts, max_tokens)
        hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")

    # Clean up the GPU memory for the next test
    gc.collect()
    torch.cuda.empty_cache()

    # zip() stops at the shorter input, so a missing output on either side
    # would otherwise go unnoticed.
    assert len(hf_logs) == len(vllm_logs), (
        f"Model: {model_name}\n"
        f"HF produced {len(hf_logs)} outputs, "
        f"vLLM produced {len(vllm_logs)}")

    # Compare the generated strings
    for hf_log, vllm_log in zip(hf_logs, vllm_logs):
        hf_str = hf_log["generated_text"]
        vllm_str = vllm_log["generated_text"]
        prompt = hf_log["prompt"]

        # The newline after the model name keeps it from running into the
        # next line of the message.
        assert hf_str == vllm_str, (f"Model: {model_name}\n"
                                    f"Mismatch between HF and vLLM outputs:\n"
                                    f"Prompt: {prompt}\n"
                                    f"HF Output: '{hf_str}'\n"
                                    f"vLLM Output: '{vllm_str}'")