test_chatglm3_tp.py 5.11 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
import vllm
6
import vllm.config
7
from vllm.lora.request import LoRARequest
8
from ..utils import models_path_prefix
9

10
from ..utils import create_new_process_for_each_test, multi_gpu_test
11

12
MODEL_PATH = os.path.join(models_path_prefix, "zai-org/chatglm3-6b")
13
14
15

# Text-to-SQL prompt for the concert_singer schema; ``{query}`` is filled in
# with the natural-language question. The text is kept verbatim, including
# the missing space in "me.Below" and the lone '"' line before
# "##Instruction:". Sampling is greedy and outputs are compared exactly
# against EXPECTED_LORA_OUTPUT, so changing any character here may change
# what the model generates.
PROMPT_TEMPLATE = (
    "I want you to act as a SQL terminal in front of an example database, "
    "you need only to return the sql command to me."
    "Below is an instruction that describes a task, "
    "Write a response that appropriately completes the request.\n"
    '"\n'
    "##Instruction:\n"
    "concert_singer contains tables such as stadium, singer, concert, "
    "singer_in_concert. "
    "Table stadium has columns such as Stadium_ID, Location, Name, "
    "Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\n"
    "Table singer has columns such as Singer_ID, Name, Country, Song_Name, "
    "Song_release_year, Age, Is_male. Singer_ID is the primary key.\n"
    "Table concert has columns such as concert_ID, concert_Name, Theme, "
    "Stadium_ID, Year. concert_ID is the primary key.\n"
    "Table singer_in_concert has columns such as concert_ID, Singer_ID. "
    "concert_ID is the primary key.\n"
    "The Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\n"
    "The Singer_ID of singer_in_concert is the foreign key of Singer_ID "
    "of singer.\n"
    "The concert_ID of singer_in_concert is the foreign key of concert_ID "
    "of concert.\n"
    "\n"
    "###Input:\n"
    "{query}\n"
    "\n"
    "###Response:"
)

16
17
# Exact SQL each LoRA adapter must generate, one entry per query issued by
# ``do_sample`` and in the same order. Whitespace inside the strings is
# significant: the tests compare generated text with ``==``.
EXPECTED_LORA_OUTPUT = [
    # Query: "How many singers do we have?"
    "SELECT count(*) FROM singer",
    # Query: average / minimum / maximum age of singers from France.
    "SELECT avg(age) ,  min(age) ,  max(age) FROM singer WHERE country  =  'France'",  # noqa: E501
    # Query: name, country, age of all singers ordered by age.
    "SELECT name ,  country ,  age FROM singer ORDER BY age",
]

22

23
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
    """Run the three fixed concert_singer queries through ``llm`` greedily.

    Args:
        llm: Engine used for generation.
        lora_path: Adapter location passed to ``LoRARequest``.
        lora_id: Adapter id. A falsy id (e.g. ``0``) sends no LoRA request,
            so the base model answers.

    Returns:
        The whitespace-stripped text of the first completion for each
        prompt, in the order the queries are listed below.
    """
    queries = (
        "How many singers do we have?",
        (
            "What is the average, minimum, and maximum "
            "age of all singers from France?"
        ),
        (
            "Show name, country, age for all singers ordered "
            "by age from the oldest to the youngest."
        ),
    )
    prompts = [PROMPT_TEMPLATE.format(query=q) for q in queries]

    # temperature=0 makes decoding greedy, so outputs can be compared exactly.
    params = vllm.SamplingParams(temperature=0, max_tokens=32)
    lora_request = (
        LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None
    )
    results = llm.generate(prompts, params, lora_request=lora_request)

    texts: list[str] = []
    for result in results:
        text = result.outputs[0].text.strip()
        texts.append(text)
        # Echo each prompt/completion pair to help debug mismatches.
        print(f"Prompt: {result.prompt!r}, Generated text: {text!r}")
    return texts


55
@create_new_process_for_each_test()
def test_chatglm3_lora(chatglm3_lora_files):
    """Single-GPU ChatGLM3 + LoRA correctness test.

    Loads the same adapter files under two different LoRA ids and checks that
    each id reproduces ``EXPECTED_LORA_OUTPUT`` exactly.
    """
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=512,
        enable_lora=True,
        max_loras=2,
        max_num_seqs=16,
        max_lora_rank=64,
        trust_remote_code=True,
    )

    for lora_id in (1, 2):
        output = do_sample(llm, chatglm3_lora_files, lora_id=lora_id)
        # Compare whole lists. Unlike an index loop, this also fails on a
        # length mismatch, and pytest prints a readable diff.
        assert output == EXPECTED_LORA_OUTPUT, f"mismatch for lora_id={lora_id}"
73
74


75
76
@multi_gpu_test(num_gpus=4)
def test_chatglm3_lora_tp4(chatglm3_lora_files):
    """ChatGLM3 + LoRA with tensor parallelism 4 and non-fully-sharded LoRA.

    Loads the same adapter files under two different LoRA ids and checks that
    each id reproduces ``EXPECTED_LORA_OUTPUT`` exactly.
    """
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=512,
        enable_lora=True,
        max_loras=2,
        max_lora_rank=64,
        max_num_seqs=16,
        tensor_parallel_size=4,
        trust_remote_code=True,
        fully_sharded_loras=False,
        compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
            cudagraph_specialize_lora=False,
        ),
    )

    for lora_id in (1, 2):
        output = do_sample(llm, chatglm3_lora_files, lora_id=lora_id)
        # Compare whole lists. Unlike an index loop, this also fails on a
        # length mismatch, and pytest prints a readable diff.
        assert output == EXPECTED_LORA_OUTPUT, f"mismatch for lora_id={lora_id}"


@multi_gpu_test(num_gpus=4)
def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
    """ChatGLM3 + LoRA with tensor parallelism 4 and fully sharded LoRA.

    Loads the same adapter files under two different LoRA ids and checks that
    each id reproduces ``EXPECTED_LORA_OUTPUT`` exactly.
    """
    # See https://github.com/NVIDIA/nccl/issues/1790: NCCL >= 2.26.3 seems to
    # use more GPU memory, which makes vLLM OOM, so this test sets a lower
    # gpu_memory_utilization.
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=512,
        enable_lora=True,
        max_loras=2,
        max_lora_rank=64,
        tensor_parallel_size=4,
        trust_remote_code=True,
        fully_sharded_loras=True,
        gpu_memory_utilization=0.8,
        compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
            cudagraph_specialize_lora=False,
        ),
    )

    for lora_id in (1, 2):
        output = do_sample(llm, chatglm3_lora_files, lora_id=lora_id)
        # Compare whole lists. Unlike an index loop, this also fails on a
        # length mismatch, and pytest prints a readable diff.
        assert output == EXPECTED_LORA_OUTPUT, f"mismatch for lora_id={lora_id}"