test_lora.py 4.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import pytest
4
from torch_xla._internal import tpu
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

import vllm
from vllm.lora.request import LoRARequest

# This file contains tests to ensure that LoRA works correctly on the TPU
# backend. They use a series of custom-trained adapters for
# Qwen2.5-3B-Instruct, named
# Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges
# from 1 to 4.

# These adapters were trained with a standard Hugging Face PEFT training
# script, where every input is "What is 1+1? \n" and every output is "x".
# Training ran for 100 iterations with a batch size of 100.


@pytest.fixture(scope="function", autouse=True)
def use_v1_only(monkeypatch: pytest.MonkeyPatch):
    """Force the V1 engine for every test in this module.

    Multi-LoRA is only supported on the V1 TPU backend, so ``VLLM_USE_V1`` is
    set to ``"1"`` inside a monkeypatch context. Leaving the context when the
    test finishes undoes the change.
    """
    with monkeypatch.context() as patcher:
        patcher.setenv("VLLM_USE_V1", "1")
        yield


31
def setup_vllm(num_loras: int, tp: int) -> vllm.LLM:
    """Build the LoRA-enabled Qwen2.5-3B-Instruct engine shared by these tests.

    Args:
        num_loras: Forwarded as ``max_loras``. It controls how many adapters
            the engine can hold at once.
        tp: Forwarded as ``tensor_parallel_size``.

    Returns:
        A ``vllm.LLM`` with LoRA enabled, a 256-token context and at most
        8 concurrent sequences.
    """
    engine_kwargs = {
        "model": "Qwen/Qwen2.5-3B-Instruct",
        "max_model_len": 256,
        "max_num_seqs": 8,
        "tensor_parallel_size": tp,
        # LoRA configuration. Rank 8 is presumably >= the rank of the test
        # adapters; confirm against their adapter configs.
        "enable_lora": True,
        "max_loras": num_loras,
        "max_lora_rank": 8,
    }
    return vllm.LLM(**engine_kwargs)
41
42


43
44
45
# Tensor-parallel sizes to parametrize the tests over. Always test tp=1, and
# additionally shard across every available chip when there is more than one.
# The chip count is queried once so the comparison and the value agree.
_NUM_TPU_CHIPS = tpu.num_available_chips()
TPU_TENSOR_PARALLEL_SIZES = [1, _NUM_TPU_CHIPS] if _NUM_TPU_CHIPS > 1 else [1]
46
47
48
49


@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
def test_single_lora(tp: int):
    """
    This test ensures we can run a single LoRA adapter on the TPU backend.
    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=1.

    Parametrized over every tensor-parallel size in TPU_TENSOR_PARALLEL_SIZES.
    """

    llm = setup_vllm(1, tp)

    prompt = "What is 1+1? \n"

    lora_request = LoRARequest(
        "lora_adapter_1",
        1,
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter",
    )
    output = (
        llm.generate(
            prompt,
            sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0),
            lora_request=lora_request,
        )[0]
        .outputs[0]
        .text
    )

    # Only the first non-whitespace character is checked. Guard against an
    # empty/whitespace-only completion so the failure is a readable assertion
    # showing the raw output instead of an IndexError on `[0]`.
    stripped = output.strip()
    assert stripped, f"LoRA adapter produced no output: {output!r}"
    answer = stripped[0]

    assert answer.isdigit(), f"Expected a digit, got {output!r}"
    assert int(answer) == 1, f"Expected 1+1=1, got {output!r}"


81
82
@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
def test_lora_hotswapping(tp: int):
    """
    This test ensures we can run multiple LoRA adapters on the TPU backend, even
    if we only have space to store 1.

    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.

    Parametrized over every tensor-parallel size in TPU_TENSOR_PARALLEL_SIZES.
    """

    lora_name_template = "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
        for i in range(1, 5)
    ]

    # num_loras=1: only one adapter slot, so the four adapters below have to
    # be swapped in and out between requests.
    llm = setup_vllm(1, tp)

    prompt = "What is 1+1? \n"

    # Adapter x was trained to answer "x", and the requests are built for
    # x = 1..4 in order, so the expected digit is the 1-based position.
    for expected, req in enumerate(lora_requests, start=1):
        output = (
            llm.generate(
                prompt,
                sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0),
                lora_request=req,
            )[0]
            .outputs[0]
            .text
        )

        # Guard against an empty/whitespace-only completion so the failure is
        # a readable assertion instead of an IndexError on `[0]`.
        stripped = output.strip()
        assert stripped, f"adapter {expected} produced no output: {output!r}"
        answer = stripped[0]

        assert answer.isdigit(), f"adapter {expected}: expected a digit, got {output!r}"
        assert int(answer) == expected, (
            f"adapter {expected}: expected 1+1={expected}, got {output!r}"
        )


117
118
@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
def test_multi_lora(tp: int):
    """
    This test ensures we can run multiple LoRA adapters on the TPU backend, when
    we have enough space to store all of them.

    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.

    Parametrized over every tensor-parallel size in TPU_TENSOR_PARALLEL_SIZES.
    """
    lora_name_template = "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
        for i in range(1, 5)
    ]

    # num_loras=4: one slot per adapter requested below.
    llm = setup_vllm(4, tp)

    prompt = "What is 1+1? \n"

    # Adapter x was trained to answer "x", and the requests are built for
    # x = 1..4 in order, so the expected digit is the 1-based position.
    for expected, req in enumerate(lora_requests, start=1):
        output = (
            llm.generate(
                prompt,
                sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0),
                lora_request=req,
            )[0]
            .outputs[0]
            .text
        )

        # Guard against an empty/whitespace-only completion so the failure is
        # a readable assertion instead of an IndexError on `[0]`.
        stripped = output.strip()
        assert stripped, f"adapter {expected} produced no output: {output!r}"
        answer = stripped[0]

        # Assert on the `answer` that was just validated, not a recomputation.
        assert answer.isdigit(), f"adapter {expected}: expected a digit, got {output!r}"
        assert int(answer) == expected, (
            f"adapter {expected}: expected 1+1={expected}, got {output!r}"
        )