# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/tests/lora/test_llama.py
import os
from dataclasses import dataclass

import pytest

import vllm
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform

from ..utils import models_path_prefix


@dataclass
class ModelWithQuantization:
    """A quantized model checkpoint paired with its quantization method."""

    # Location of the checkpoint handed to vllm.LLM(model=...); this file
    # builds it with os.path.join(models_path_prefix, ...).
    model_path: str
    # Quantization method name handed to vllm.LLM(quantization=...),
    # e.g. "awq" or "gptq".
    quantization: str


25
MODELS: list[ModelWithQuantization]
26
# AWQ quantization is currently not supported in ROCm.
27
if current_platform.is_rocm():
28
29
    MODELS = [
        ModelWithQuantization(
30
            model_path=os.path.join(models_path_prefix, "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ"), quantization="gptq"
31
        ),
32
33
34
35
    ]
else:
    MODELS = [
        ModelWithQuantization(
36
            model_path=os.path.join(models_path_prefix, "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ"), quantization="awq"
37
        ),
38
        ModelWithQuantization(
39
            model_path=os.path.join(models_path_prefix, "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ"), quantization="gptq"
40
        ),
41
    ]
42
43


def do_sample(
    llm: vllm.LLM, lora_path: str, lora_id: int, max_tokens: int = 256
) -> list[str]:
    """Greedily generate completions for two fixed color-request prompts.

    Args:
        llm: Engine to generate with (constructed with LoRA enabled by the
            tests in this file).
        lora_path: Path to the LoRA adapter to apply.
        lora_id: Integer id to register the adapter under. A falsy id (``0``)
            sends no LoRA request, i.e. runs the base model.
        max_tokens: Maximum number of tokens generated per prompt.

    Returns:
        The generated text of the first completion for each prompt, in the
        order returned by ``llm.generate``.
    """
    raw_prompts = [
        "Give me an orange-ish brown color",
        "Give me a neon pink color",
    ]

    def format_prompt(prompt: str) -> str:
        # Wrap the raw request in a ChatML-style chat template; generation is
        # stopped at the closing <|im_end|> tag (see sampling_params below).
        return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

    prompts = [format_prompt(p) for p in raw_prompts]

    # temperature=0 selects greedy decoding so outputs can be compared
    # against fixed expectations and across engine configurations.
    sampling_params = vllm.SamplingParams(
        temperature=0, max_tokens=max_tokens, stop=["<|im_end|>"]
    )
    lora_request = (
        LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None
    )
    outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)

    # Collect (and log) the first completion of each request.
    generated_texts: list[str] = []
    for output in outputs:
        generated_text = output.outputs[0].text
        generated_texts.append(generated_text)
        print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


@pytest.mark.parametrize("model", MODELS)
def test_quant_model_lora(tinyllama_lora_files, model):
    """Check LoRA generation on a quantized model against known outputs.

    The same adapter is requested under two different LoRA ids, and both
    runs must satisfy the same expectation.
    """
    # Greedy outputs observed for each quantization method (None = unquantized).
    expected_outputs_by_quantization = {
        None: [
            "#ff8050",
            "#ff8080",
        ],
        "awq": [
            "#f07700: A v",
            "#f00000: A v",
        ],
        "gptq": [
            "#f08800: This is",
            "#f07788 \n#",
        ],
    }
    if model.quantization not in expected_outputs_by_quantization:
        # Previously this fell through and raised UnboundLocalError later.
        pytest.fail(
            f"No expected LoRA output for quantization {model.quantization!r}"
        )
    expected_lora_output = expected_outputs_by_quantization[model.quantization]

    def expect_match(output: list[str]) -> None:
        # HACK: GPTQ LoRA outputs are too unstable to pin exactly, so for GPTQ
        # only check that every output starts with "#" (a color-code prefix).
        if model.quantization == "gptq":
            for i, o in enumerate(output):
                assert o.startswith("#"), (
                    f"Expected example {i} to start with # but got {o}"
                )
            return
        assert output == expected_lora_output

    max_tokens = 10

    llm = vllm.LLM(
        model=model.model_path,
        enable_lora=True,
        max_num_seqs=16,
        max_loras=4,
        max_model_len=400,
        gpu_memory_utilization=0.2,  # avoid OOM
        quantization=model.quantization,
        trust_remote_code=True,
        enable_chunked_prefill=True,
        tokenizer=tinyllama_lora_files,
    )
    # Tear the engine down even when an assertion fails, so GPU memory is
    # released before the next parametrized case starts.
    try:
        print("lora adapter created")
        print("lora 1")
        output = do_sample(
            llm, tinyllama_lora_files, lora_id=1, max_tokens=max_tokens
        )
        expect_match(output)

        print("lora 2")
        output = do_sample(
            llm, tinyllama_lora_files, lora_id=2, max_tokens=max_tokens
        )
        expect_match(output)
    finally:
        print("removing lora")
        del llm
        cleanup_dist_env_and_memory()


@pytest.mark.parametrize("model", MODELS)
def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available, model):
    """Check that TP=1 and TP=2 engines yield identical greedy LoRA outputs."""
    if num_gpus_available < 2:
        pytest.skip("Not enough GPUs for tensor parallelism 2")
    if model.quantization == "gptq":
        pytest.skip("GPTQ lora outputs are just incredibly unstable")

    # Both engines share one configuration so tensor_parallel_size is the only
    # difference between the two runs being compared.
    engine_kwargs = {
        "model": model.model_path,
        "enable_lora": True,
        "max_num_seqs": 16,
        "max_loras": 4,
        "gpu_memory_utilization": 0.2,  # avoid OOM
        "quantization": model.quantization,
        "trust_remote_code": True,
        "enable_chunked_prefill": True,
    }

    # Each engine is torn down in `finally` so a failing run cannot leave GPU
    # memory allocated for the next engine or test.
    llm_tp1 = vllm.LLM(**engine_kwargs)
    try:
        output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1)
    finally:
        del llm_tp1
        cleanup_dist_env_and_memory()

    llm_tp2 = vllm.LLM(**engine_kwargs, tensor_parallel_size=2)
    try:
        output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1)
    finally:
        del llm_tp2
        cleanup_dist_env_and_memory()

    assert output_tp1 == output_tp2