"vscode:/vscode.git/clone" did not exist on "8b5a09f6a90db8160cb2d17d7034ff6e63725a1a"
test_lora.py 3.99 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# SPDX-License-Identifier: Apache-2.0
import pytest

import vllm
from vllm.lora.request import LoRARequest

# This file contains tests to ensure that LoRA works correctly on the TPU
# backend. We use a series of custom-trained adapters for Qwen2.5-3B-Instruct.
# The adapters are:
# Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges
# from 1 to 4.

# These adapters are trained using a standard Hugging Face PEFT training
# script, where every input is "What is 1+1? \n" and every output is "x". We
# run 100 training iterations with a training batch size of 100.


@pytest.fixture(autouse=True, scope="function")
def use_v1_only(monkeypatch: pytest.MonkeyPatch):
    """Force the V1 engine for every test in this module.

    Multi-LoRA is only supported by the V1 TPU backend, so ``VLLM_USE_V1`` is
    set to ``"1"`` while each test runs; the monkeypatch context undoes the
    change when the test finishes.
    """
    with monkeypatch.context() as patcher:
        patcher.setenv("VLLM_USE_V1", "1")
        # Hand control to the test while the patched environment is active.
        yield


def setup_vllm(num_loras: int) -> vllm.LLM:
    """Build a Qwen2.5-3B-Instruct engine with LoRA support enabled.

    Args:
        num_loras: Number of LoRA adapters the engine may keep loaded at the
            same time (forwarded as ``max_loras``).

    Returns:
        A ``vllm.LLM`` instance configured for the LoRA tests in this file.
    """
    engine_kwargs = dict(
        model="Qwen/Qwen2.5-3B-Instruct",
        num_scheduler_steps=1,
        max_model_len=256,
        max_seq_len_to_capture=256,
        max_num_seqs=8,
        enable_lora=True,
        max_loras=num_loras,
        # NOTE(review): presumably the test adapters were trained with
        # rank <= 8 — confirm if new adapters are added.
        max_lora_rank=8,
    )
    return vllm.LLM(**engine_kwargs)


def test_single_lora():
    """
    This test ensures we can run a single LoRA adapter on the TPU backend.
    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter" which
    should force Qwen2.5-3B-Instruct to claim 1+1=1.
    """

    llm = setup_vllm(1)

    prompt = "What is 1+1? \n"

    lora_request = LoRARequest(
        "lora_adapter_1", 1,
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter")
    output = llm.generate(prompt,
                          sampling_params=vllm.SamplingParams(max_tokens=256,
                                                              temperature=0),
                          lora_request=lora_request)[0].outputs[0].text

    # Check for an empty generation before indexing. Otherwise it would show
    # up as an opaque IndexError instead of a readable assertion failure.
    stripped = output.strip()
    assert stripped, f"LoRA adapter produced empty output: {output!r}"

    # Only the first non-whitespace character is checked; the adapter was
    # trained to answer with a single digit.
    answer = stripped[0]

    assert answer.isdigit(), f"Expected a leading digit, got {output!r}"
    assert int(answer) == 1, f"Expected the model to claim 1+1=1, got {output!r}"


def test_lora_hotswapping():
    """
    This test ensures we can run multiple LoRA adapters on the TPU backend even
    when the engine can only hold one adapter at a time (``max_loras=1``), so
    adapters must be swapped in and out between requests.

    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
    should force Qwen2.5-3B-Instruct to claim 1+1=x, for x in 1..4.
    """

    lora_name_template = \
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
        for i in range(1, 5)
    ]

    llm = setup_vllm(1)

    prompt = "What is 1+1? \n"

    # Adapter x was trained to answer "x", so the expected digit is the
    # 1-based position of the request in the list.
    for expected, req in enumerate(lora_requests, start=1):
        output = llm.generate(prompt,
                              sampling_params=vllm.SamplingParams(
                                  max_tokens=256, temperature=0),
                              lora_request=req)[0].outputs[0].text

        # Check for an empty generation before indexing, so it fails with a
        # readable message rather than an IndexError.
        stripped = output.strip()
        assert stripped, (
            f"{req.lora_name} produced empty output: {output!r}")

        answer = stripped[0]

        assert answer.isdigit(), (
            f"{req.lora_name}: expected a leading digit, got {output!r}")
        assert int(answer) == expected, (
            f"{req.lora_name}: expected 1+1={expected}, got {output!r}")


def test_multi_lora():
    """
    This test ensures we can run multiple LoRA adapters on the TPU backend when
    the engine has room to keep all of them loaded at once (``max_loras=4``).

    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
    should force Qwen2.5-3B-Instruct to claim 1+1=x, for x in 1..4.
    """
    lora_name_template = \
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
        for i in range(1, 5)
    ]

    llm = setup_vllm(4)

    prompt = "What is 1+1? \n"

    # Adapter x was trained to answer "x", so the expected digit is the
    # 1-based position of the request in the list.
    for expected, req in enumerate(lora_requests, start=1):
        output = llm.generate(prompt,
                              sampling_params=vllm.SamplingParams(
                                  max_tokens=256, temperature=0),
                              lora_request=req)[0].outputs[0].text

        # Check for an empty generation before indexing, so it fails with a
        # readable message rather than an IndexError.
        stripped = output.strip()
        assert stripped, (
            f"{req.lora_name} produced empty output: {output!r}")

        answer = stripped[0]

        assert answer.isdigit(), (
            f"{req.lora_name}: expected a leading digit, got {output!r}")
        # Use the already-extracted `answer` instead of re-deriving it.
        assert int(answer) == expected, (
            f"{req.lora_name}: expected 1+1={expected}, got {output!r}")