test_gptoss_tp.py 4.24 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
5
import pytest

6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import vllm
from vllm.lora.request import LoRARequest

from ..utils import multi_gpu_test

# Model identifier passed as the first positional argument to `vllm.LLM` by
# the tests in this file.
MODEL_PATH = "openai/gpt-oss-20b"

# Text-to-SQL prompt with a single `{context}` placeholder, filled in through
# `str.format` by `generate_and_test`. The system and user turns are already
# closed with `<|end|>` and the text stops right after the assistant's `final`
# channel marker, so the completion is compared directly against the SQL in
# EXPECTED_LORA_OUTPUT.
# NOTE(review): the lone `"` line after the task description is part of the
# prompt string; presumably it mirrors the adapter's training data, so confirm
# before "fixing" it.
PROMPT_TEMPLATE = """<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2024-06
Current date: 2025-10-29

Reasoning: medium

# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>user<|message|>I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.
"
##Instruction:
farm contains tables such as city, farm, farm_competition, competition_record. Table city has columns such as City_ID, Official_Name, Status, Area_km_2, Population, Census_Ranking. City_ID is the primary key.
Table farm has columns such as Farm_ID, Year, Total_Horses, Working_Horses, Total_Cattle, Oxen, Bulls, Cows, Pigs, Sheep_and_Goats. Farm_ID is the primary key.
Table farm_competition has columns such as Competition_ID, Year, Theme, Host_city_ID, Hosts. Competition_ID is the primary key.
Table competition_record has columns such as Competition_ID, Farm_ID, Rank. Competition_ID is the primary key.
The Host_city_ID of farm_competition is the foreign key of City_ID of city.
The Farm_ID of competition_record is the foreign key of Farm_ID of farm.
The Competition_ID of competition_record is the foreign key of Competition_ID of farm_competition.


###Input:
{context}

###Response:<|end|><|start|>assistant<|channel|>final<|message|>"""  # noqa: E501

# Expected completion prefixes, one per prompt built in `generate_and_test`
# and in the same order. They are matched with `str.startswith`, so the
# spacing (including the doubled spaces) must be kept byte-for-byte.
EXPECTED_LORA_OUTPUT = [
    "SELECT avg(Working_Horses) FROM farm WHERE Total_Horses  >  5000",
    "SELECT max(Cows) ,  min(Cows) FROM farm",
    "SELECT max(Cows) ,  min(Cows) FROM farm",
]


def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
    """Generate SQL for the fixed prompts and compare with EXPECTED_LORA_OUTPUT.

    Args:
        llm: Engine whose ``generate`` method is called with the prompts.
        lora_path: Adapter location, forwarded to ``LoRARequest``.
        lora_id: Adapter id; also used (as a string) for the first
            ``LoRARequest`` argument. A falsy value (``0``) sends no LoRA
            request at all.

    Raises:
        AssertionError: If the number of generated texts differs from the
            number of expected outputs, or if a generated text does not start
            with its expected SQL statement.
    """
    questions = [
        "Give the average number of working horses on farms with more than 5000 total horses.",  # noqa: E501
        "What are the maximum and minimum number of cows across all farms.",
        "Return the maximum and minimum number of cows across all farms.",
    ]
    prompts = [PROMPT_TEMPLATE.format(context=question) for question in questions]

    # temperature=0 keeps decoding greedy so the prefixes can be compared
    # verbatim.
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
    lora_request = (
        LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None
    )
    outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)

    # Print every prompt/completion pair so failures are easy to diagnose
    # from the captured test output.
    generated_texts: list[str] = []
    for output in outputs:
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}")

    # Fail with a clear message (rather than an IndexError) on a count mismatch.
    assert len(generated_texts) == len(EXPECTED_LORA_OUTPUT), (
        f"expected {len(EXPECTED_LORA_OUTPUT)} outputs, "
        f"got {len(generated_texts)}"
    )
    pairs = zip(generated_texts, EXPECTED_LORA_OUTPUT)
    for idx, (generated, expected) in enumerate(pairs):
        assert generated.startswith(expected), (
            f"prompt {idx} (lora_id={lora_id}): expected prefix "
            f"{expected!r}, got {generated!r}"
        )


def test_gpt_oss_lora(gptoss20b_lora_files):
    """Check LoRA generation for ``MODEL_PATH`` with no tensor-parallel setting.

    One engine is built with LoRA enabled, then the same adapter files are run
    through ``generate_and_test`` under two different LoRA ids (1 and 2).

    Args:
        gptoss20b_lora_files: Adapter location passed on as ``lora_path``
            (presumably a pytest fixture defined outside this file).
    """
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=1024,
        enable_lora=True,
        max_loras=4,
        max_lora_rank=8,
        max_num_seqs=2,
        max_num_batched_tokens=2048,
        compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
            cudagraph_specialize_lora=False,
        ),
    )

    generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
    generate_and_test(llm, gptoss20b_lora_files, lora_id=2)


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("fully_sharded_loras", [False, True])
def test_gpt_oss_lora_tp2(gptoss20b_lora_files, fully_sharded_loras):
    """Check LoRA generation for ``MODEL_PATH`` with ``tensor_parallel_size=2``.

    Parametrized to run once with ``fully_sharded_loras`` disabled and once
    with it enabled. Each run builds an engine and exercises the same adapter
    files through ``generate_and_test`` under LoRA ids 1 and 2.

    Args:
        gptoss20b_lora_files: Adapter location passed on as ``lora_path``
            (presumably a pytest fixture defined outside this file).
        fully_sharded_loras: Forwarded unchanged to ``vllm.LLM``.
    """
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=1024,
        enable_lora=True,
        max_loras=2,
        max_num_seqs=2,
        max_num_batched_tokens=2048,
        tensor_parallel_size=2,
        gpu_memory_utilization=0.8,
        fully_sharded_loras=fully_sharded_loras,
        compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
            cudagraph_specialize_lora=False,
        ),
    )

    generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
    generate_and_test(llm, gptoss20b_lora_files, lora_id=2)