# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import json
import shutil
from contextlib import suppress

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio

from ...utils import RemoteOpenAIServer

# any model with a chat template should work here
MODEL_NAME = "Qwen/Qwen3-0.6B"

# Parametrization table for test_dynamic_lora_badrequests. Each entry is
# (test name, overrides merged into adapter_config.json, fragment of the
# error message the load request is expected to produce).
BADREQUEST_CASES = [
    (
        "test_rank",
        {"r": 1024},
        "is greater than max_lora_rank",
    ),
    ("test_dora", {"use_dora": True}, "does not yet support DoRA"),
    (
        "test_modules_to_save",
        {"modules_to_save": ["lm_head"]},
        "only supports modules_to_save being None",
    ),
]

33

34
@pytest.fixture(scope="module", params=[True])
def server_with_lora_modules_json(request, qwen3_lora_files):
    """Yield a module-scoped server with one LoRA registered at startup.

    The adapter is passed to ``--lora-modules`` in its JSON form under the
    name ``qwen3-lora``, and the environment enables runtime LoRA updates so
    tests can call the ``load_lora_adapter`` endpoint.
    """
    # Define the json format LoRA module configurations
    lora_module_1 = {
        "name": "qwen3-lora",
        "path": qwen3_lora_files,
        "base_model_name": MODEL_NAME,
    }

    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--enforce-eager",
        # lora config below
        "--enable-lora",
        "--lora-modules",
        json.dumps(lora_module_1),
        "--max-lora-rank",
        "64",
        "--max-cpu-loras",
        "2",
        "--max-num-seqs",
        "64",
    ]

    # Enable the /v1/load_lora_adapter endpoint
    envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"}

    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server_with_lora_modules_json):
    """Yield an async OpenAI client bound to the shared LoRA-enabled server."""
    async with server_with_lora_modules_json.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
async def test_static_lora_lineage(client: openai.AsyncOpenAI, qwen3_lora_files):
    """Check root/parent metadata of the base model and listed LoRA models.

    The first listed model must be the base model (its own root, no parent);
    every following entry must have the adapter path as root and the base
    model as parent, and the first of them is the startup adapter.
    """
    models = await client.models.list()
    models = models.data
    served_model = models[0]
    lora_models = models[1:]
    assert served_model.id == MODEL_NAME
    assert served_model.root == MODEL_NAME
    assert served_model.parent is None
    assert all(lora_model.root == qwen3_lora_files for lora_model in lora_models)
    assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
    assert lora_models[0].id == "qwen3-lora"
87
88
89


@pytest.mark.asyncio
async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI, qwen3_lora_files):
    """Check that a runtime-loaded adapter is listed with correct lineage."""
    adapter_name = "qwen3-lora-3"
    response = await client.post(
        "load_lora_adapter",
        cast_to=str,
        body={"lora_name": adapter_name, "lora_path": qwen3_lora_files},
    )
    # Ensure adapter loads before querying /models
    assert "success" in response.lower()

    models = await client.models.list()
    # Look the adapter up by id rather than by list position: the server is
    # shared by every test in this module, so other tests may have registered
    # adapters too and the listing order should not affect this check.
    models_by_id = {model.id: model for model in models.data}
    assert adapter_name in models_by_id, (
        f"{adapter_name!r} missing from served models: {sorted(models_by_id)}"
    )
    dynamic_lora_model = models_by_id[adapter_name]
    assert dynamic_lora_model.root == qwen3_lora_files
    assert dynamic_lora_model.parent == MODEL_NAME
105
106


107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
@pytest.mark.asyncio
async def test_load_lora_adapter_with_same_name_replaces_inplace(
    client: openai.AsyncOpenAI, qwen3_meowing_lora_files, qwen3_woofing_lora_files
):
    """Loading a second adapter under an existing name swaps it in place.

    The same adapter name is first bound to the meowing LoRA and then, with
    ``load_inplace`` set, to the woofing LoRA; after each load the chat
    output must contain the matching animal noise.
    """
    adapter_name = "replaceable-adapter"
    messages = [
        {"content": "Follow the instructions to make animal noises", "role": "system"},
        {"content": "Make your favorite animal noise.", "role": "user"},
    ]

    # (load request body, text expected in the reply once it is loaded)
    stages = (
        # Initial load: LoRA that makes the model meow
        (
            {"lora_name": adapter_name, "lora_path": qwen3_meowing_lora_files},
            "Meow Meow Meow",
        ),
        # In-place replacement: LoRA that makes the model woof
        (
            {
                "lora_name": adapter_name,
                "lora_path": qwen3_woofing_lora_files,
                "load_inplace": True,
            },
            "Woof Woof Woof",
        ),
    )

    for load_body, expected_noise in stages:
        load_reply = await client.post(
            "load_lora_adapter",
            cast_to=str,
            body=load_body,
        )
        assert "success" in load_reply.lower()

        chat = await client.chat.completions.create(
            model=adapter_name,
            messages=messages,
            max_tokens=10,
        )
        assert expected_noise in chat.choices[0].message.content


@pytest.mark.asyncio
async def test_load_lora_adapter_with_load_inplace_false_errors(
    client: openai.AsyncOpenAI, qwen3_meowing_lora_files
):
    """Re-loading an existing adapter name without load_inplace is rejected."""
    adapter_name = "test-load-inplace-false"
    # Both requests send the same body; ``load_inplace`` is left out of it
    # (presumably defaulting to False on the server - TODO confirm).
    load_body = {"lora_name": adapter_name, "lora_path": qwen3_meowing_lora_files}

    # The first load registers the adapter and must succeed
    first_reply = await client.post("load_lora_adapter", cast_to=str, body=load_body)
    assert "success" in first_reply.lower()

    # The repeated load of the same name must be refused
    with pytest.raises(openai.BadRequestError) as rejected:
        await client.post("load_lora_adapter", cast_to=str, body=load_body)

    # Verify the error message
    assert "already been loaded" in str(rejected.value)


183
184
185
@pytest.mark.asyncio
async def test_dynamic_lora_not_found(client: openai.AsyncOpenAI):
    """Loading an adapter from a non-existent path raises NotFoundError."""
    with pytest.raises(openai.NotFoundError):
        await client.post(
            "load_lora_adapter",
            cast_to=str,
            body={"lora_name": "notfound", "lora_path": "/not/an/adapter"},
        )
191
192
193


@pytest.mark.asyncio
async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI, tmp_path):
    """Loading an adapter whose config is not valid JSON raises a server error."""
    # Build an adapter directory containing only a malformed config file
    invalid_files = tmp_path / "invalid_files"
    invalid_files.mkdir()
    (invalid_files / "adapter_config.json").write_text("this is not json")

    with pytest.raises(openai.InternalServerError):
        await client.post(
            "load_lora_adapter",
            cast_to=str,
            body={"lora_name": "invalid-json", "lora_path": str(invalid_files)},
        )
205
206
207


@pytest.mark.asyncio
@pytest.mark.parametrize("test_name,config_change,expected_error", BADREQUEST_CASES)
async def test_dynamic_lora_badrequests(
    client: openai.AsyncOpenAI,
    tmp_path,
    qwen3_lora_files,
    test_name: str,
    config_change: dict,
    expected_error: str,
):
    """Loading an adapter with an unsupported config fails with a clear error.

    A copy of the valid adapter is made under ``tmp_path``, its
    ``adapter_config.json`` is patched with ``config_change``, and the load
    request must raise an error whose message matches ``expected_error``.
    """
    # NOTE(review): despite the test name, the exception expected here is
    # openai.InternalServerError rather than openai.BadRequestError - confirm
    # this is the status code the server is meant to return.

    # Copy the adapter files into a per-case directory
    test_dir = tmp_path / test_name
    shutil.copytree(qwen3_lora_files, test_dir)

    # Load the configuration, apply the overrides and write it back
    config_path = test_dir / "adapter_config.json"
    adapter_config = json.loads(config_path.read_text(encoding="utf-8"))
    adapter_config.update(config_change)
    config_path.write_text(json.dumps(adapter_config), encoding="utf-8")

    # Test loading the adapter
    with pytest.raises(openai.InternalServerError, match=expected_error):
        await client.post(
            "load_lora_adapter",
            cast_to=str,
            body={"lora_name": test_name, "lora_path": str(test_dir)},
        )
241
242
243


@pytest.mark.asyncio
async def test_multiple_lora_adapters(
    client: openai.AsyncOpenAI, tmp_path, qwen3_lora_files
):
    """Validate that many LoRAs can be dynamically registered and used for
    inference concurrently."""

    # This test file configures the server with --max-cpu-loras=2 and this test
    # will concurrently load 10 adapters, so it should flex the LRU cache
    async def load_and_run_adapter(adapter_name: str):
        await client.post(
            "load_lora_adapter",
            cast_to=str,
            body={"lora_name": adapter_name, "lora_path": str(qwen3_lora_files)},
        )
        for _ in range(3):
            await client.completions.create(
                model=adapter_name,
                prompt=["Hello there", "Foo bar bazz buzz"],
                max_tokens=5,
            )

    lora_tasks = [
        asyncio.create_task(load_and_run_adapter(f"adapter_{i}")) for i in range(10)
    ]

    # gather(return_exceptions=True) hands back each task's exception as its
    # result. asyncio.wait() would return Task objects instead, which are
    # never Exception instances, so failures would go unnoticed.
    results = await asyncio.gather(*lora_tasks, return_exceptions=True)

    for r in results:
        assert not isinstance(r, BaseException), f"Got exception {r!r}"


@pytest.mark.asyncio
async def test_loading_invalid_adapters_does_not_break_others(
    client: openai.AsyncOpenAI, tmp_path, qwen3_lora_files
):
    """Failed adapter loads must not disturb requests to a working adapter.

    A background task keeps sending completion requests to the startup
    ``qwen3-lora`` adapter while many invalid load requests are issued. None
    of the background requests may fail, and a valid adapter must still be
    loadable and usable afterwards.
    """
    invalid_files = tmp_path / "invalid_files"
    invalid_files.mkdir()
    (invalid_files / "adapter_config.json").write_text("this is not json")

    stop_good_requests_event = asyncio.Event()

    async def run_good_requests(client):
        # Run completions requests until the event is set, recording either
        # the response or the exception of every attempt
        results = []

        while not stop_good_requests_event.is_set():
            try:
                batch = await client.completions.create(
                    model="qwen3-lora",
                    prompt=["Hello there", "Foo bar bazz buzz"],
                    max_tokens=5,
                )
                results.append(batch)
            except Exception as e:
                results.append(e)

        return results

    # Create task to run good requests
    good_task = asyncio.create_task(run_good_requests(client))

    try:
        # Run a bunch of bad adapter loads
        for _ in range(25):
            with suppress(openai.NotFoundError):
                await client.post(
                    "load_lora_adapter",
                    cast_to=str,
                    body={"lora_name": "notfound", "lora_path": "/not/an/adapter"},
                )
        for _ in range(25):
            with suppress(openai.InternalServerError):
                await client.post(
                    "load_lora_adapter",
                    cast_to=str,
                    body={"lora_name": "invalid", "lora_path": str(invalid_files)},
                )
    finally:
        # Always stop the background loop, even if a bad load raised an
        # unexpected error, so the task does not keep running after the test
        stop_good_requests_event.set()
        results = await good_task

    # Ensure all the running requests with lora adapters succeeded
    assert results, "background task did not complete any request"
    for r in results:
        assert not isinstance(r, Exception), f"Got exception {r}"

    # Ensure we can load another adapter and run it
    await client.post(
        "load_lora_adapter",
        cast_to=str,
        body={"lora_name": "valid", "lora_path": qwen3_lora_files},
    )
    await client.completions.create(
        model="valid",
        prompt=["Hello there", "Foo bar bazz buzz"],
        max_tokens=5,
    )
339
340
341
342
343
344


@pytest.mark.asyncio
async def test_beam_search_with_lora_adapters(
    client: openai.AsyncOpenAI,
    tmp_path,
    qwen3_lora_files,
):
    """Validate that async beam search can be used with lora."""

    async def load_and_run_adapter(adapter_name: str):
        await client.post(
            "load_lora_adapter",
            cast_to=str,
            body={"lora_name": adapter_name, "lora_path": str(qwen3_lora_files)},
        )
        for _ in range(3):
            await client.completions.create(
                model=adapter_name,
                prompt=["Hello there", "Foo bar bazz buzz"],
                max_tokens=5,
                extra_body=dict(use_beam_search=True),
            )

    # Use names specific to this test: the server is shared across the module
    # and test_multiple_lora_adapters already registers "adapter_<i>". As
    # test_load_lora_adapter_with_load_inplace_false_errors shows, re-loading
    # an existing name is rejected, which would skip the beam-search requests.
    lora_tasks = [
        asyncio.create_task(load_and_run_adapter(f"beam_search_adapter_{i}"))
        for i in range(3)
    ]

    # gather(return_exceptions=True) hands back each task's exception as its
    # result. asyncio.wait() would return Task objects instead, which are
    # never Exception instances, so failures would go unnoticed.
    results = await asyncio.gather(*lora_tasks, return_exceptions=True)

    for r in results:
        assert not isinstance(r, BaseException), f"Got exception {r!r}"