test_lora.py 4.39 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import pytest
4
from torch_xla._internal import tpu
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

import vllm
from vllm.lora.request import LoRARequest

# This file contains tests to ensure that LoRA works correctly on the TPU
# backend. We use a series of custom-trained adapters for Qwen2.5-3B-Instruct:
# Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges
# from 1 to 4.

# Each adapter was trained with a standard Hugging Face PEFT training script in
# which every input is "What is 1+1? \n" and every output is "x". Training ran
# for 100 iterations with a batch size of 100.


@pytest.fixture(scope="function", autouse=True)
def use_v1_only(monkeypatch: pytest.MonkeyPatch):
    """Force the vLLM V1 engine for every test in this module.

    Multi-LoRA is only supported on the v1 TPU backend, so ``VLLM_USE_V1`` is
    set to ``"1"`` while each test runs. The variable is set inside a
    ``monkeypatch.context()``, so the change is undone when the context exits
    after the test has finished.
    """
    with monkeypatch.context() as patcher:
        patcher.setenv("VLLM_USE_V1", "1")
        # Hand control to the test while the override is active.
        yield


def setup_vllm(num_loras: int, tp: int) -> vllm.LLM:
    """Create a LoRA-enabled Qwen2.5-3B-Instruct engine for these tests.

    Args:
        num_loras: Value passed as ``max_loras`` (how many adapters the engine
            is configured to hold at once).
        tp: Tensor-parallel size for the engine.

    Returns:
        A ``vllm.LLM`` with LoRA enabled, a 256-token context, up to 8
        concurrent sequences, and a maximum adapter rank of 8.
    """
    engine_kwargs = dict(
        model="Qwen/Qwen2.5-3B-Instruct",
        max_model_len=256,
        max_num_seqs=8,
        tensor_parallel_size=tp,
        enable_lora=True,
        max_loras=num_loras,
        # The test adapters are rank-8 (or smaller) — presumably; confirm
        # against the adapter configs if they are retrained.
        max_lora_rank=8,
    )
    return vllm.LLM(**engine_kwargs)


# Tensor-parallel sizes to parametrize every test over: always tp=1, plus
# tp=<all chips> when more than one TPU chip is visible. The chip count is
# queried once so both uses below see the same value and the TPU runtime is
# probed a single time at import.
_NUM_TPU_CHIPS = tpu.num_available_chips()
TPU_TENSOR_PARALLEL_SIZES = ([1, _NUM_TPU_CHIPS]
                             if _NUM_TPU_CHIPS > 1 else [1])


@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
def test_single_lora(tp: int):
    """
    This test ensures we can run a single LoRA adapter on the TPU backend.
    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=1.

    Args:
        tp: Tensor-parallel size, from TPU_TENSOR_PARALLEL_SIZES.
    """

    llm = setup_vllm(1, tp)

    prompt = "What is 1+1? \n"

    lora_request = LoRARequest(
        "lora_adapter_1", 1,
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter")
    # temperature=0 requests greedy decoding so the answer is deterministic.
    output = llm.generate(prompt,
                          sampling_params=vllm.SamplingParams(max_tokens=256,
                                                              temperature=0),
                          lora_request=lora_request)[0].outputs[0].text

    # Slice instead of index: an empty or whitespace-only completion then
    # fails the assertion below with the raw output, not with an IndexError.
    answer = output.strip()[:1]

    # isdecimal() accepts exactly the characters int() can parse.
    assert answer.isdecimal(), f"expected a leading digit, got {output!r}"
    assert int(answer) == 1, (
        f"expected the adapter to answer 1, got {output!r}")


@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
def test_lora_hotswapping(tp: int):
    """
    This test ensures we can run multiple LoRA adapters on the TPU backend, even
    if we only have space to store 1.

    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.

    Args:
        tp: Tensor-parallel size, from TPU_TENSOR_PARALLEL_SIZES.
    """

    lora_name_template = \
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    # Adapter x is trained to answer "x", so each adapter id is also the
    # expected answer for its request.
    adapter_ids = range(1, 5)
    lora_requests = [
        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
        for i in adapter_ids
    ]

    # max_loras=1: only one adapter fits at a time, so moving to the next
    # request presumably forces the engine to swap adapters.
    llm = setup_vllm(1, tp)

    prompt = "What is 1+1? \n"
    # Loop-invariant: greedy decoding (temperature=0) for deterministic output.
    sampling_params = vllm.SamplingParams(max_tokens=256, temperature=0)

    for expected, req in zip(adapter_ids, lora_requests):
        output = llm.generate(prompt,
                              sampling_params=sampling_params,
                              lora_request=req)[0].outputs[0].text
        # Slice instead of index so an empty completion fails the assertion
        # below (with the raw output) rather than raising IndexError.
        answer = output.strip()[:1]

        assert answer.isdecimal(), (
            f"adapter {expected}: expected a leading digit, got {output!r}")
        assert int(answer) == expected, (
            f"adapter {expected}: expected {expected}, got {output!r}")


@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES)
def test_multi_lora(tp: int):
    """
    This test ensures we can run multiple LoRA adapters on the TPU backend, when
    we have enough space to store all of them.

    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.

    Args:
        tp: Tensor-parallel size, from TPU_TENSOR_PARALLEL_SIZES.
    """
    lora_name_template = \
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    # Adapter x is trained to answer "x", so each adapter id is also the
    # expected answer for its request.
    adapter_ids = range(1, 5)
    lora_requests = [
        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
        for i in adapter_ids
    ]

    # max_loras equals the number of adapters, so all four can be resident.
    llm = setup_vllm(len(lora_requests), tp)

    prompt = "What is 1+1? \n"
    # Loop-invariant: greedy decoding (temperature=0) for deterministic output.
    sampling_params = vllm.SamplingParams(max_tokens=256, temperature=0)

    # NOTE(review): requests are issued one generate() call at a time, so this
    # covers several resident adapters but not adapters mixed in one batch.
    for expected, req in zip(adapter_ids, lora_requests):
        output = llm.generate(prompt,
                              sampling_params=sampling_params,
                              lora_request=req)[0].outputs[0].text

        # Slice instead of index so an empty completion fails the assertion
        # below (with the raw output) rather than raising IndexError.
        answer = output.strip()[:1]

        assert answer.isdecimal(), (
            f"adapter {expected}: expected a leading digit, got {output!r}")
        assert int(answer) == expected, (
            f"adapter {expected}: expected {expected}, got {output!r}")