test_quant_model.py 5.66 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
# Adapted from
# https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/tests/lora/test_llama.py
from dataclasses import dataclass
from typing import List

import pytest

import vllm
from vllm.lora.request import LoRARequest

from .conftest import cleanup


@dataclass
class ModelWithQuantization:
    """Pairs a model checkpoint with the quantization method used to load it."""
    # Model identifier; the tests pass it as ``vllm.LLM(model=...)``.
    model_path: str
    # Quantization method name; the tests pass it as
    # ``vllm.LLM(quantization=...)`` and branch on "AWQ" / "GPTQ".
    quantization: str


# Quantized TinyLlama checkpoints the tests below are parametrized over,
# one entry per quantization method.
MODELS: List[ModelWithQuantization] = [
    ModelWithQuantization(model_path=checkpoint, quantization=method)
    for checkpoint, method in (
        ("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", "AWQ"),
        ("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", "GPTQ"),
    )
]


def do_sample(llm, lora_path: str, lora_id: int, max_tokens=256):
    """Run two fixed color prompts through ``llm`` and return the texts.

    Args:
        llm: Engine object; only its ``generate(prompts, sampling_params,
            lora_request=...)`` method is used.
        lora_path: Adapter location handed to ``LoRARequest``.
        lora_id: Adapter id. A falsy value (``0``) issues the request
            without any LoRA adapter.
        max_tokens: Per-prompt cap on the number of generated tokens.

    Returns:
        A list with the text of the first completion of each output, in
        the order ``generate`` returned them.
    """
    questions = (
        "Give me an orange-ish brown color",
        "Give me a neon pink color",
    )
    # Wrap every question in the chat markup before sending it.
    prompts = [
        f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"
        for question in questions
    ]

    # temperature=0 together with the explicit stop marker keeps the
    # completions short and comparable against hard-coded strings.
    params = vllm.SamplingParams(temperature=0,
                                 max_tokens=max_tokens,
                                 stop=["<|im_end|>"])

    if lora_id:
        adapter = LoRARequest(str(lora_id), lora_id, lora_path)
    else:
        adapter = None

    results = llm.generate(prompts, params, lora_request=adapter)

    # Echo each prompt/completion pair while collecting the completions.
    completions = []
    for result in results:
        shown_prompt = result.prompt
        text = result.outputs[0].text
        completions.append(text)
        print(f"Prompt: {shown_prompt!r}, Generated text: {text!r}")
    return completions


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", [1])
def test_quant_model_lora(tinyllama_lora_files, model, tp_size):
    """Check that LoRA adapters can be toggled on a quantized model.

    The engine is sampled four times in a row: without an adapter, with
    adapter id 1, without an adapter again, and with adapter id 2 (both ids
    point at the same ``tinyllama_lora_files``). Each result is compared
    with the outputs hard-coded below for the model's quantization method.

    Args:
        tinyllama_lora_files: Fixture value forwarded to ``do_sample`` as
            the LoRA adapter path.
        model: The ``ModelWithQuantization`` entry under test.
        tp_size: Tensor-parallel size the engine is created with.
    """
    # Cannot use as it will initialize torch.cuda too early...
    # if torch.cuda.device_count() < tp_size:
    #     pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

    # quantization -> (expected outputs without LoRA, expected with LoRA).
    # The values stay lists because they are compared with ``==`` against
    # the list returned by do_sample.
    expected_by_quantization = {
        None: (
            [
                "Here are some examples of orange-brown colors",
                "I'm sorry, I don't have",
            ],
            [
                "#ff8050",
                "#ff8080",
            ],
        ),
        "AWQ": (
            [
                "I'm sorry, I don't understand",
                "I'm sorry, I don't understand",
            ],
            [
                "#f07700: A v",
                "#f00000: A v",
            ],
        ),
        "GPTQ": (
            [
                "I'm sorry, I don't have",
                "I'm sorry, I don't have",
            ],
            [
                "#f08800: This is",
                "#f07788 \n#",
            ],
        ),
    }
    # Fail with a clear message *before* loading the model, instead of
    # hitting an unbound local after the engine has been created.
    if model.quantization not in expected_by_quantization:
        pytest.fail("No expected outputs defined for quantization "
                    f"{model.quantization!r}")
    expected_no_lora_output, expected_lora_output = (
        expected_by_quantization[model.quantization])

    def expect_match(output, with_lora):
        # HACK: GPTQ lora outputs are just incredibly unstable.
        # Assert that the outputs changed.
        if with_lora and model.quantization == "GPTQ":
            assert output != expected_no_lora_output
            for i, o in enumerate(output):
                assert o.startswith(
                    '#'), f"Expected example {i} to start with # but got {o}"
            return
        expected = (expected_lora_output
                    if with_lora else expected_no_lora_output)
        assert output == expected

    max_tokens = 10

    # (progress message, lora_id) for each sampling round; id 0 means
    # "no adapter" (see do_sample).
    rounds = (
        ("lora adapter created", 0),
        ("lora 1", 1),
        ("no lora", 0),
        ("lora 2", 2),
    )

    llm = vllm.LLM(model=model.model_path,
                   enable_lora=True,
                   max_num_seqs=16,
                   max_loras=4,
                   max_model_len=400,
                   tensor_parallel_size=tp_size,
                   quantization=model.quantization,
                   trust_remote_code=True)
    try:
        for message, lora_id in rounds:
            print(message)
            output = do_sample(llm,
                               tinyllama_lora_files,
                               lora_id=lora_id,
                               max_tokens=max_tokens)
            expect_match(output, with_lora=bool(lora_id))
    finally:
        # Drop the engine reference and run cleanup() even when an
        # assertion above fails, so a failing case does not leave the
        # engine alive for the next parametrized case.
        print("removing lora")
        del llm
        cleanup()


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.skip("Requires multiple GPUs")
def test_quant_model_tp_equality(tinyllama_lora_files, model):
    """Check that LoRA outputs match for tensor-parallel sizes 1 and 2.

    Two engines are created one after the other with identical arguments
    apart from ``tensor_parallel_size``; each is sampled with adapter id 1
    and the two results must be equal.

    Args:
        tinyllama_lora_files: Fixture value forwarded to ``do_sample`` as
            the LoRA adapter path.
        model: The ``ModelWithQuantization`` entry under test.
    """
    # Cannot use as it will initialize torch.cuda too early...
    # if torch.cuda.device_count() < 2:
    #     pytest.skip(f"Not enough GPUs for tensor parallelism {2}")

    outputs_by_tp_size = {}
    for tp_size in (1, 2):
        # Both engines share every setting except tensor_parallel_size, so
        # any difference in output is attributable to tensor parallelism.
        llm = vllm.LLM(model=model.model_path,
                       enable_lora=True,
                       max_num_seqs=16,
                       max_loras=4,
                       tensor_parallel_size=tp_size,
                       quantization=model.quantization,
                       trust_remote_code=True)
        try:
            outputs_by_tp_size[tp_size] = do_sample(llm,
                                                    tinyllama_lora_files,
                                                    lora_id=1)
        finally:
            # Release this engine before the next one is created, also
            # when sampling raises.
            del llm
            cleanup()

    assert outputs_by_tp_size[1] == outputs_by_tp_size[2]