# test_llama_tp.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
import subprocess
import sys
from typing import Union
6

7
8
import pytest

9
import vllm
10
from vllm import LLM
11
from vllm.lora.request import LoRARequest
12
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
13

14
from ..utils import VLLM_PATH, create_new_process_for_each_test, multi_gpu_test
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

# Base model loaded by every test in this module.
MODEL_PATH = "meta-llama/Llama-2-7b-hf"

# Golden completions for the six prompts in ``do_sample`` when no LoRA
# adapter is applied (``lora_id=0``). Order matches the prompt order.
# These are compared with ``==``, so they must stay byte-exact.
EXPECTED_NO_LORA_OUTPUT = [
    "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]",  # noqa: E501
    " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ",  # noqa: E501
    "\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m",  # noqa: E501
    " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ",  # noqa: E501
    " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ",  # noqa: E501
    "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE",  # noqa: E501
]
# Golden completions for the same six prompts when a LoRA adapter is
# applied (any non-zero ``lora_id``). Order matches the prompt order; the
# comparison is exact, including leading/trailing spaces.
EXPECTED_LORA_OUTPUT = [
    "  SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ",  # noqa: E501
    "  SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ",  # noqa: E501
    "  SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ",  # noqa: E501
    "  SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ",  # noqa: E501
    "  SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ",  # noqa: E501
    "  SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' "  # noqa: E501
]


def do_sample(llm: vllm.LLM,
              lora_path: str,
              lora_id: int,
              tensorizer_config_dict: Union[dict, None] = None) -> list[str]:
    """Generate completions for a fixed batch of text-to-SQL prompts.

    Decoding uses ``temperature=0`` with a 256-token cap and stops at
    ``"[/assistant]"``, so results can be compared to golden outputs.

    Args:
        llm: Engine used for generation.
        lora_path: Adapter location handed to ``LoRARequest``.
        lora_id: Adapter id. A falsy value (``0``) generates without any
            LoRA; any other value attaches ``LoRARequest(str(lora_id),
            lora_id, lora_path)``.
        tensorizer_config_dict: Optional serialized tensorizer config. When
            given (and a LoRA is requested), it is forwarded to the
            ``LoRARequest``.

    Returns:
        The first completion's text for each prompt, in the order of the
        request outputs.
    """
    prompts = [
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",  # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",  # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]",  # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]",  # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]",  # noqa: E501
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]"  # noqa: E501
    ]

    sampling_params = vllm.SamplingParams(temperature=0,
                                          max_tokens=256,
                                          skip_special_tokens=False,
                                          stop=["[/assistant]"])

    # Build the (optional) LoRA request once instead of duplicating the
    # whole generate() call per tensorizer/no-tensorizer case.
    lora_request = None
    if lora_id:
        # Only pass tensorizer_config_dict when one was supplied, so the
        # plain path constructs the request exactly as before.
        extra_kwargs = ({} if tensorizer_config_dict is None else {
            "tensorizer_config_dict": tensorizer_config_dict
        })
        lora_request = LoRARequest(str(lora_id), lora_id, lora_path,
                                   **extra_kwargs)

    outputs = llm.generate(prompts,
                           sampling_params,
                           lora_request=lora_request)

    # Collect (and print, for debugging CI failures) the generated text.
    generated_texts: list[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


def generate_and_test(llm,
                      sql_lora_files,
                      tensorizer_config_dict: Union[dict, None] = None):
    """Check golden outputs while alternating between base model and LoRAs.

    Runs ``do_sample`` four times on the same engine, in this order:
    no LoRA, LoRA id 1, no LoRA again, then LoRA id 2 (loaded from the same
    adapter files). Each run must match its golden output exactly.

    Args:
        llm: Engine created with LoRA enabled.
        sql_lora_files: Adapter location passed through to ``do_sample``.
        tensorizer_config_dict: Optional serialized tensorizer config passed
            through to ``do_sample``.

    Raises:
        AssertionError: If any step's outputs differ from the golden list.
    """
    # (progress message, lora_id, expected outputs); lora_id 0 means the
    # request is sent without any LoRA adapter.
    steps = [
        ("lora adapter created", 0, EXPECTED_NO_LORA_OUTPUT),
        ("lora 1", 1, EXPECTED_LORA_OUTPUT),
        ("no lora", 0, EXPECTED_NO_LORA_OUTPUT),
        ("lora 2", 2, EXPECTED_LORA_OUTPUT),
    ]
    for message, lora_id, expected in steps:
        print(message)
        outputs = do_sample(llm,
                            sql_lora_files,
                            tensorizer_config_dict=tensorizer_config_dict,
                            lora_id=lora_id)
        # Name the step in the failure message so it is obvious which of
        # the alternating base/LoRA runs diverged.
        assert outputs == expected, (
            f"unexpected outputs at step {message!r} (lora_id={lora_id})")

    print("removing lora")


@create_new_process_for_each_test()
def test_llama_lora(sql_lora_files):
    """Run the LoRA golden-output checks without tensor parallelism set.

    Uses chunked prefill and an odd ``max_num_seqs`` to cover that
    configuration as well.
    """
    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        # also test odd max_num_seqs
        max_num_seqs=13,
        max_loras=4,
        enable_chunked_prefill=True)

    generate_and_test(llm, sql_lora_files)


@multi_gpu_test(num_gpus=4)
@create_new_process_for_each_test()
def test_llama_lora_tp4(sql_lora_files):
    """Run the LoRA golden-output checks with ``tensor_parallel_size=4``."""
    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        max_num_seqs=16,
        max_loras=4,
        tensor_parallel_size=4,
        enable_chunked_prefill=True,
    )
    generate_and_test(llm, sql_lora_files)


@multi_gpu_test(num_gpus=4)
@create_new_process_for_each_test()
def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
    """Run the LoRA golden-output checks at TP=4 with fully sharded LoRAs."""
    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        max_num_seqs=16,
        max_loras=4,
        tensor_parallel_size=4,
        fully_sharded_loras=True,
        enable_chunked_prefill=True,
    )
    generate_and_test(llm, sql_lora_files)


@pytest.mark.skip(reason=("Skipping this test as tensorizer is not "
                          "working with LoRA as of #19619"))
@multi_gpu_test(num_gpus=2)
@create_new_process_for_each_test()
def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
                                            sql_lora_huggingface_id):
    """Serialize model + LoRA with tensorizer at TP=2, reload, and check.

    The serialization is done by running the
    ``examples/others/tensorize_vllm_model.py`` script in a subprocess. The
    result is then loaded with ``load_format="tensorizer"``, and the no-LoRA
    and LoRA outputs are compared with the golden lists.
    """
    # Run the tensorizing of the LoRA adapter and the model in a subprocess
    # to guarantee cleanup

    tp_size = 2
    model_name = "model-rank-%03d.tensors"

    model_ref = MODEL_PATH
    lora_path = sql_lora_huggingface_id
    suffix = "test"
    try:
        result = subprocess.run(
            [
                sys.executable,
                f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py",
                "--model", model_ref,
                "--lora-path", lora_path,
                "--tensor-parallel-size", str(tp_size),
                "serialize",
                "--serialized-directory", str(tmp_path),
                "--suffix", suffix,
                "--serialization-kwargs", '{"limit_cpu_concurrency": 4}',
            ],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        print("Tensorizing failed.")
        print("STDOUT:\n", e.stdout)
        print("STDERR:\n", e.stderr)
        raise

    print("STDOUT:\n", result.stdout)

    # model_name keeps its "%03d" placeholder unformatted here; presumably
    # the tensorizer loader fills in the rank per TP shard — TODO confirm.
    model_uri = tmp_path / "vllm" / model_ref / suffix / model_name
    tensorizer_config = TensorizerConfig(tensorizer_uri=str(model_uri))
    # The LoRA artifacts are expected next to the serialized model shards.
    tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir

    loaded_vllm_model = LLM(model=model_ref,
                            load_format="tensorizer",
                            enable_lora=True,
                            enforce_eager=True,
                            model_loader_extra_config=tensorizer_config,
                            max_num_seqs=13,
                            # Must match the TP size used for serialization.
                            tensor_parallel_size=tp_size,
                            max_loras=2)

    tensorizer_config_dict = tensorizer_config.to_serializable()

    print("lora adapter created")
    assert do_sample(loaded_vllm_model,
                     sql_lora_files,
                     tensorizer_config_dict=tensorizer_config_dict,
                     lora_id=0) == EXPECTED_NO_LORA_OUTPUT

    print("lora 1")
    assert do_sample(loaded_vllm_model,
                     sql_lora_files,
                     tensorizer_config_dict=tensorizer_config_dict,
                     lora_id=1) == EXPECTED_LORA_OUTPUT