# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Script to test add_lora, remove_lora, pin_lora, list_loras functions.
"""
import pytest

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
9
from vllm.engine.llm_engine import LLMEngine
10
11
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)
12
13
14
15
16
17
from vllm.lora.request import LoRARequest

# Model identifier handed to the engine args (``model=``) in both tests.
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
# Adapter location used as ``lora_path`` for every request built by
# make_lora_request(); the same adapter is registered under many ids.
LORA_MODULE_PATH = "yard1/llama-2-7b-sql-lora-test"
# Passed as ``max_lora_rank`` to the engine args.
# NOTE(review): presumably matches the rank of the adapter above - confirm.
LORA_RANK = 8

# @pytest.fixture(autouse=True)
# def v1(run_with_both_engines_lora):
#     # Simple autouse wrapper to run both engines for each test
#     # This can be promoted up to conftest.py to run for every
#     # test in a package
#     pass


def make_lora_request(lora_id: int):
    """Build a ``LoRARequest`` for the shared test adapter under ``lora_id``.

    The integer id is used both as ``lora_int_id`` and, stringified, as the
    adapter name; every request points at ``LORA_MODULE_PATH``.
    """
    request_fields = {
        "lora_name": f"{lora_id}",
        "lora_int_id": lora_id,
        "lora_path": LORA_MODULE_PATH,
    }
    return LoRARequest(**request_fields)


def test_lora_functions_sync():
    """Check add_lora / pin_lora / remove_lora / list_loras on ``LLMEngine``.

    The engine is created with ``max_loras=4``. After every call the set of
    ids returned by ``list_loras()`` is compared with the expected set, so
    the test pins down three things:

    * adding a fifth adapter replaces one of the resident adapters (per the
      expected sets below, the oldest non-pinned one),
    * a pinned adapter (id 1) survives all subsequent adds,
    * ``remove_lora`` drops exactly the requested id, including a pinned one.
    """
    max_loras = 4
    # Create engine in eager-mode. Due to high max_loras, the CI can
    # OOM during cuda-graph capture.
    engine_args = EngineArgs(model=MODEL_PATH,
                             enable_lora=True,
                             max_loras=max_loras,
                             max_lora_rank=LORA_RANK,
                             max_model_len=128,
                             gpu_memory_utilization=0.8,
                             enforce_eager=True)

    llm = LLMEngine.from_engine_args(engine_args)

    def run_check(fn, args, expected: list):
        """Call ``fn(args)``, then assert the engine lists exactly ``expected``.

        Comparison is set-based, so the order of ids in ``expected`` is
        irrelevant.
        """
        fn(args)
        assert set(llm.list_loras()) == set(expected)

    run_check(llm.add_lora, make_lora_request(1), [1])
    run_check(llm.add_lora, make_lora_request(2), [1, 2])

    # Pin LoRA 1 and test that it is never removed on subsequent adds.
    run_check(llm.pin_lora, 1, [1, 2])
    run_check(llm.add_lora, make_lora_request(3), [1, 2, 3])
    run_check(llm.add_lora, make_lora_request(4), [1, 2, 3, 4])
    # All max_loras slots are in use from here on: each add is expected to
    # replace one non-pinned adapter while LoRA 1 stays resident.
    run_check(llm.add_lora, make_lora_request(5), [1, 5, 3, 4])
    run_check(llm.add_lora, make_lora_request(6), [1, 5, 6, 4])
    run_check(llm.add_lora, make_lora_request(7), [1, 5, 6, 7])
    run_check(llm.add_lora, make_lora_request(8), [1, 8, 6, 7])
    run_check(llm.add_lora, make_lora_request(9), [1, 8, 9, 7])
    run_check(llm.add_lora, make_lora_request(10), [1, 8, 9, 10])

    # Remove LoRA 1 and continue adding.
    run_check(llm.remove_lora, 1, [8, 9, 10])
    run_check(llm.add_lora, make_lora_request(11), [8, 9, 10, 11])
    run_check(llm.add_lora, make_lora_request(12), [12, 9, 10, 11])
    run_check(llm.add_lora, make_lora_request(13), [12, 13, 10, 11])

    # Remove all LoRAs.
    run_check(llm.remove_lora, 13, [12, 10, 11])
    run_check(llm.remove_lora, 12, [10, 11])
    run_check(llm.remove_lora, 11, [10])
    run_check(llm.remove_lora, 10, [])


@pytest.mark.asyncio
async def test_lora_functions_async():
    """Check add_lora / pin_lora / remove_lora / list_loras on the async client.

    Runs the same sequence of operations as the synchronous test, but
    against the client yielded by
    ``build_async_engine_client_from_engine_args``, awaiting every call.
    The engine is created with ``max_loras=4``; after each operation the set
    of ids from ``list_loras()`` must equal the expected set.
    """
    max_loras = 4
    # Eager mode, as in the sync test, so no cuda-graph capture takes place.
    engine_args = AsyncEngineArgs(model=MODEL_PATH,
                                  enable_lora=True,
                                  max_loras=max_loras,
                                  max_lora_rank=LORA_RANK,
                                  max_model_len=128,
                                  gpu_memory_utilization=0.8,
                                  enforce_eager=True)

    async def run_check(fn, args, expected: list):
        """Await ``fn(args)``, then assert the client lists exactly ``expected``.

        ``llm`` is looked up when the helper runs, so it resolves to the
        client bound by the ``async with`` statement below. Comparison is
        set-based, so the order of ids in ``expected`` is irrelevant.
        """
        await fn(args)
        assert set(await llm.list_loras()) == set(expected)

    async with build_async_engine_client_from_engine_args(engine_args) as llm:
        await run_check(llm.add_lora, make_lora_request(1), [1])
        await run_check(llm.add_lora, make_lora_request(2), [1, 2])

        # Pin LoRA 1 and test that it is never removed on subsequent adds.
        await run_check(llm.pin_lora, 1, [1, 2])
        await run_check(llm.add_lora, make_lora_request(3), [1, 2, 3])
        await run_check(llm.add_lora, make_lora_request(4), [1, 2, 3, 4])
        # All max_loras slots are in use from here on: each add is expected
        # to replace one non-pinned adapter while LoRA 1 stays resident.
        await run_check(llm.add_lora, make_lora_request(5), [1, 5, 3, 4])
        await run_check(llm.add_lora, make_lora_request(6), [1, 5, 6, 4])
        await run_check(llm.add_lora, make_lora_request(7), [1, 5, 6, 7])
        await run_check(llm.add_lora, make_lora_request(8), [1, 8, 6, 7])
        await run_check(llm.add_lora, make_lora_request(9), [1, 8, 9, 7])
        await run_check(llm.add_lora, make_lora_request(10), [1, 8, 9, 10])

        # Remove LoRA 1 and continue adding.
        await run_check(llm.remove_lora, 1, [8, 9, 10])
        await run_check(llm.add_lora, make_lora_request(11), [8, 9, 10, 11])
        await run_check(llm.add_lora, make_lora_request(12), [12, 9, 10, 11])
        await run_check(llm.add_lora, make_lora_request(13), [12, 13, 10, 11])

        # Remove all LoRAs
        await run_check(llm.remove_lora, 13, [12, 10, 11])
        await run_check(llm.remove_lora, 12, [10, 11])
        await run_check(llm.remove_lora, 11, [10])
        await run_check(llm.remove_lora, 10, [])