# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

import vllm
from vllm.lora.request import LoRARequest

# This file contains tests to ensure that LoRA works correctly on the TPU
# backend. We use a series of custom-trained adapters for Qwen2.5-3B-Instruct
# for this. The adapters are:
# Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges
# from 1 to 4.

# These adapters are trained using a standard Hugging Face PEFT training
# script, where all the inputs are "What is 1+1? \n" and all the outputs are
# "x". We run 100 training iterations with a training batch size of 100.


@pytest.fixture(scope="function", autouse=True)
def use_v1_only(monkeypatch: pytest.MonkeyPatch):
    """
    Autouse fixture that exports VLLM_USE_V1=1 for every test in this file.

    Multi-LoRA is only supported on the v1 TPU backend, so each test runs
    with the variable set. The change is scoped to a monkeypatch context and
    is undone once the test finishes.
    """
    with monkeypatch.context() as patched_env:
        patched_env.setenv("VLLM_USE_V1", "1")
        # Hand control to the test while the variable is still in place.
        yield


def setup_vllm(num_loras: int) -> vllm.LLM:
    """
    Build the LoRA-enabled Qwen2.5-3B-Instruct engine shared by the tests.

    Args:
        num_loras: Value forwarded as ``max_loras`` to the engine.

    Returns:
        The constructed ``vllm.LLM`` instance.
    """
    engine_options = dict(
        model="Qwen/Qwen2.5-3B-Instruct",
        num_scheduler_steps=1,
        max_model_len=256,
        max_seq_len_to_capture=256,
        max_num_seqs=8,
        enable_lora=True,
        max_loras=num_loras,
        max_lora_rank=8,
    )
    return vllm.LLM(**engine_options)


def test_single_lora():
    """
    Check that one LoRA adapter can be served on the TPU backend.

    The adapter "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter"
    makes Qwen2.5-3B-Instruct answer 1+1=1, so the first character of the
    generated text must be the digit 1.
    """

    llm = setup_vllm(1)

    prompt = "What is 1+1? \n"

    adapter_path = "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter"
    lora_request = LoRARequest("lora_adapter_1", 1, adapter_path)

    # Greedy decoding (temperature=0) so the result is deterministic.
    greedy_params = vllm.SamplingParams(max_tokens=256, temperature=0)
    results = llm.generate(prompt,
                           sampling_params=greedy_params,
                           lora_request=lora_request)
    generated_text = results[0].outputs[0].text

    # Only the first non-whitespace character carries the answer.
    first_char = generated_text.strip()[0]

    assert first_char.isdigit()
    assert int(first_char) == 1


def test_lora_hotswapping():
    """
    Check that several LoRA adapters can be used one after another on the TPU
    backend when the engine only has room for a single adapter.

    Each "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" makes
    Qwen2.5-3B-Instruct answer 1+1=x; we try x = 1..4 in turn.
    """

    adapter_template = \
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = []
    for adapter_id in range(1, 5):
        lora_requests.append(
            LoRARequest(f"lora_adapter_{adapter_id}", adapter_id,
                        adapter_template.format(adapter_id)))

    # max_loras=1: every new adapter has to replace the previous one.
    llm = setup_vllm(1)

    prompt = "What is 1+1? \n"

    for expected, request in enumerate(lora_requests, start=1):
        greedy_params = vllm.SamplingParams(max_tokens=256, temperature=0)
        results = llm.generate(prompt,
                               sampling_params=greedy_params,
                               lora_request=request)
        generated_text = results[0].outputs[0].text

        # Only the first non-whitespace character carries the answer.
        first_char = generated_text.strip()[0]

        assert first_char.isdigit()
        assert int(first_char) == expected


def test_multi_lora():
    """
    This test ensures we can run multiple LoRA adapters on the TPU backend, when
    we have enough space to store all of them.

    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
    """
    lora_name_template = \
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
        for i in range(1, 5)
    ]

    # Size max_loras from the adapter list so it always matches the number of
    # adapters under test (4 here).
    llm = setup_vllm(len(lora_requests))

    prompt = "What is 1+1? \n"

    # Adapter ids start at 1 and adapter x is trained to answer "x", so the
    # enumeration index doubles as the expected answer.
    for expected, req in enumerate(lora_requests, start=1):
        output = llm.generate(prompt,
                              sampling_params=vllm.SamplingParams(
                                  max_tokens=256, temperature=0),
                              lora_request=req)[0].outputs[0].text

        stripped = output.strip()
        # Fail with a readable message instead of an IndexError when the
        # model returns nothing but whitespace.
        assert stripped, (
            f"adapter {expected} produced no visible text: {output!r}")

        answer = stripped[0]

        assert answer.isdigit(), (
            f"adapter {expected} answered with a non-digit: {output!r}")
        assert int(answer) == expected, (
            f"adapter {expected} answered {answer!r}: {output!r}")