# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
4
from torch_xla._internal import tpu
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19

import vllm
from vllm.lora.request import LoRARequest

# This file contains tests to ensure that LoRA works correctly on the TPU
# backend. We use a series of custom-trained adapters for Qwen2.5-3B-Instruct
# for this. The adapters are:
# Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges
# from 1 to 4.

# These adapters are trained using a standard Hugging Face PEFT training
# script, where all the inputs are "What is 1+1? \n" and all the outputs are
# "x". We run 100 training iterations with a training batch size of 100.


20
def setup_vllm(num_loras: int, tp: int) -> vllm.LLM:
    """Build the LoRA-enabled Qwen2.5-3B-Instruct engine used by these tests.

    Args:
        num_loras: Forwarded to the engine as ``max_loras``.
        tp: Forwarded to the engine as ``tensor_parallel_size``.

    Returns:
        A ``vllm.LLM`` created with ``enable_lora=True``, ``max_lora_rank=8``,
        ``max_model_len=256`` and ``max_num_seqs=8``.
    """
    return vllm.LLM(
        model="Qwen/Qwen2.5-3B-Instruct",
        max_model_len=256,
        max_num_seqs=8,
        tensor_parallel_size=tp,
        enable_lora=True,
        max_loras=num_loras,
        max_lora_rank=8,
    )
30
31


32
33
34
# Tensor-parallel sizes the tests are parametrized over: the single-chip case
# is always included, and the all-chips case is added only when more than one
# chip is reported.
TPU_TENSOR_PARALLEL_SIZES = [1]
if tpu.num_available_chips() > 1:
    TPU_TENSOR_PARALLEL_SIZES.append(tpu.num_available_chips())
35
36
37
38


@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
def test_single_lora(tp: int):
    """
    This test ensures we can run a single LoRA adapter on the TPU backend.
    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=1.
    """

    llm = setup_vllm(1, tp)

    prompt = "What is 1+1? \n"

    lora_request = LoRARequest(
        "lora_adapter_1",
        1,
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter",
    )
    # temperature=0 requests greedy sampling from vLLM.
    output = (
        llm.generate(
            prompt,
            sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0),
            lora_request=lora_request,
        )[0]
        .outputs[0]
        .text
    )

    stripped = output.strip()
    # Fail with a readable assertion (showing the raw text) instead of an
    # IndexError when the model produced nothing but whitespace.
    assert stripped, f"adapter produced no text: {output!r}"

    # Only the first non-whitespace character is checked.
    answer = stripped[0]

    assert answer.isdigit(), f"expected a digit, got {output!r}"
    assert int(answer) == 1


70
71
@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
def test_lora_hotswapping(tp: int):
    """
    This test ensures we can run multiple LoRA adapters on the TPU backend, even
    if we only have space to store 1.

    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
    """

    lora_name_template = "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
        for i in range(1, 5)
    ]

    # max_loras=1 while four adapters are requested in turn.
    llm = setup_vllm(1, tp)

    prompt = "What is 1+1? \n"

    # Adapter ids start at 1 and each adapter is trained to answer with its
    # own id, so enumerate from 1 to get the expected answer directly.
    for expected, req in enumerate(lora_requests, start=1):
        # temperature=0 requests greedy sampling from vLLM.
        output = (
            llm.generate(
                prompt,
                sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0),
                lora_request=req,
            )[0]
            .outputs[0]
            .text
        )

        stripped = output.strip()
        # Fail with a readable assertion (showing the raw text) instead of an
        # IndexError when the model produced nothing but whitespace.
        assert stripped, f"adapter {expected} produced no text: {output!r}"

        # Only the first non-whitespace character is checked.
        answer = stripped[0]

        assert answer.isdigit(), f"adapter {expected}: expected a digit, got {output!r}"
        assert int(answer) == expected


106
107
@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
def test_multi_lora(tp: int):
    """
    This test ensures we can run multiple LoRA adapters on the TPU backend, when
    we have enough space to store all of them.

    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
    """
    lora_name_template = "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
        for i in range(1, 5)
    ]

    # max_loras=4 matches the number of adapters requested below.
    llm = setup_vllm(4, tp)

    prompt = "What is 1+1? \n"

    # Adapter ids start at 1 and each adapter is trained to answer with its
    # own id, so enumerate from 1 to get the expected answer directly.
    for expected, req in enumerate(lora_requests, start=1):
        # temperature=0 requests greedy sampling from vLLM.
        output = (
            llm.generate(
                prompt,
                sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0),
                lora_request=req,
            )[0]
            .outputs[0]
            .text
        )

        stripped = output.strip()
        # Fail with a readable assertion (showing the raw text) instead of an
        # IndexError when the model produced nothing but whitespace.
        assert stripped, f"adapter {expected} produced no text: {output!r}"

        # Only the first non-whitespace character is checked.
        answer = stripped[0]

        assert answer.isdigit(), f"adapter {expected}: expected a digit, got {output!r}"
        # Compare the already-extracted answer, as the sibling tests do,
        # rather than re-deriving it from the raw output.
        assert int(answer) == expected