# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import json
import shutil
from contextlib import suppress

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio

from ...utils import RemoteOpenAIServer

# Any model with a chat template should work here.
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# Technically this needs Mistral-7B-v0.1 as the base model, but generation
# quality is not being tested here.

# Adapter-config mutations that must make ``load_lora_adapter`` fail.
# Each entry is (test_name, overrides merged into adapter_config.json,
# regex pattern that the BadRequestError message must match).
BADREQUEST_CASES = [
    (
        "test_rank",
        # Larger than the --max-lora-rank (64) the server is started with.
        {"r": 1024},
        "is greater than max_lora_rank",
    ),
    (
        "test_bias",
        {"bias": "all"},
        "Adapter bias cannot be used without bias_enabled",
    ),
    ("test_dora", {"use_dora": True}, "does not yet support DoRA"),
    (
        "test_modules_to_save",
        {"modules_to_save": ["lm_head"]},
        "only supports modules_to_save being None",
    ),
]


@pytest.fixture(scope="module")
def monkeypatch_module():
    """Module-scoped counterpart of pytest's function-scoped ``monkeypatch``.

    Module-scoped fixtures cannot request the built-in ``monkeypatch``, so this
    provides one whose patches are undone when the module's tests finish.
    """
    # Public API (pytest >= 6.2); undo() runs automatically when the context exits.
    with pytest.MonkeyPatch.context() as mpatch:
        yield mpatch


@pytest.fixture(scope="module", params=[True])
def server_with_lora_modules_json(request, monkeypatch_module, zephyr_lora_files):
    """Start one OpenAI-compatible server, shared by the whole module.

    The server runs with ``VLLM_USE_V1=1`` and has a single statically
    registered LoRA ("zephyr-lora"), passed to ``--lora-modules`` in its JSON
    form. Runtime LoRA updating is enabled so tests can call the
    ``load_lora_adapter`` endpoint.
    """
    use_v1 = request.param
    # Only the V1 engine is exercised here.
    assert use_v1
    monkeypatch_module.setenv("VLLM_USE_V1", "1")

    # JSON-format LoRA module configuration for --lora-modules.
    lora_module_1 = {
        "name": "zephyr-lora",
        "path": zephyr_lora_files,
        "base_model_name": MODEL_NAME,
    }

    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--enforce-eager",
        # lora config below
        "--enable-lora",
        "--lora-modules",
        json.dumps(lora_module_1),
        "--max-lora-rank",
        "64",
        # Kept small so that loading many adapters exercises the CPU LoRA cache.
        "--max-cpu-loras",
        "2",
        "--max-num-seqs",
        "64",
    ]

    # Enable the /v1/load_lora_adapter endpoint
    envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"}

    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server_with_lora_modules_json):
    """Yield an async OpenAI client connected to the shared LoRA server."""
    async with server_with_lora_modules_json.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
async def test_static_lora_lineage(client: openai.AsyncOpenAI, zephyr_lora_files):
    """Check lineage for the base model and the statically configured LoRA.

    The first model listed must be the base model (root is itself, no parent).
    Every following entry must be a LoRA whose root is the adapter path and
    whose parent is the base model. The first LoRA must be "zephyr-lora".
    """
    models = (await client.models.list()).data
    served_model, *lora_models = models

    assert served_model.id == MODEL_NAME
    assert served_model.root == MODEL_NAME
    assert served_model.parent is None

    # Without this, the all(...) checks below would pass vacuously.
    assert lora_models, "expected at least one LoRA model in /models"
    assert all(lora_model.root == zephyr_lora_files for lora_model in lora_models)
    assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
    assert lora_models[0].id == "zephyr-lora"


@pytest.mark.asyncio
async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI, zephyr_lora_files):
    """Check that a LoRA loaded at runtime reports correct lineage in /models."""
    response = await client.post(
        "load_lora_adapter",
        cast_to=str,
        body={"lora_name": "zephyr-lora-3", "lora_path": zephyr_lora_files},
    )
    # Ensure adapter loads before querying /models
    assert "success" in response

    models = (await client.models.list()).data
    # Look the adapter up by id rather than assuming it is listed last: the
    # server is shared across the module, so listing order is not guaranteed.
    dynamic_lora_model = next(
        (model for model in models if model.id == "zephyr-lora-3"), None
    )
    assert dynamic_lora_model is not None, "zephyr-lora-3 missing from /models"
    assert dynamic_lora_model.root == zephyr_lora_files
    assert dynamic_lora_model.parent == MODEL_NAME


@pytest.mark.asyncio
async def test_dynamic_lora_not_found(client: openai.AsyncOpenAI):
    """Loading an adapter from a nonexistent path must raise NotFoundError."""
    with pytest.raises(openai.NotFoundError):
        await client.post(
            "load_lora_adapter",
            cast_to=str,
            body={"lora_name": "notfound", "lora_path": "/not/an/adapter"},
        )


@pytest.mark.asyncio
async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI, tmp_path):
    """An adapter directory whose config is not valid JSON must be rejected.

    The server is expected to answer with a BadRequestError.
    """
    invalid_files = tmp_path / "invalid_files"
    invalid_files.mkdir()
    (invalid_files / "adapter_config.json").write_text("this is not json")

    with pytest.raises(openai.BadRequestError):
        await client.post(
            "load_lora_adapter",
            cast_to=str,
            body={"lora_name": "invalid-json", "lora_path": str(invalid_files)},
        )


@pytest.mark.asyncio
@pytest.mark.parametrize("test_name,config_change,expected_error", BADREQUEST_CASES)
async def test_dynamic_lora_badrequests(
    client: openai.AsyncOpenAI,
    tmp_path,
    zephyr_lora_files,
    test_name: str,
    config_change: dict,
    expected_error: str,
):
    """Loading an adapter with an unsupported config must raise BadRequestError.

    The test copies the known-good adapter, merges ``config_change`` into its
    ``adapter_config.json``, and tries to load the copy. The error message must
    match the ``expected_error`` regex.
    """
    # Work on a private copy so the shared adapter files stay untouched.
    test_dir = tmp_path / test_name
    shutil.copytree(zephyr_lora_files, test_dir)

    # Apply the configuration change to the copied adapter_config.json.
    config_path = test_dir / "adapter_config.json"
    adapter_config = json.loads(config_path.read_text())
    adapter_config.update(config_change)
    config_path.write_text(json.dumps(adapter_config))

    with pytest.raises(openai.BadRequestError, match=expected_error):
        await client.post(
            "load_lora_adapter",
            cast_to=str,
            body={"lora_name": test_name, "lora_path": str(test_dir)},
        )


@pytest.mark.asyncio
async def test_multiple_lora_adapters(
    client: openai.AsyncOpenAI, tmp_path, zephyr_lora_files
):
    """Validate that many LoRAs can be dynamically registered and used for
    inference concurrently."""

    # The server fixture configures --max-cpu-loras=2 and this test
    # concurrently loads 10 adapters, so it should flex the LRU cache.
    async def load_and_run_adapter(adapter_name: str):
        await client.post(
            "load_lora_adapter",
            cast_to=str,
            body={"lora_name": adapter_name, "lora_path": str(zephyr_lora_files)},
        )
        for _ in range(3):
            await client.completions.create(
                model=adapter_name,
                prompt=["Hello there", "Foo bar bazz buzz"],
                max_tokens=5,
            )

    lora_tasks = [
        asyncio.create_task(load_and_run_adapter(f"adapter_{i}")) for i in range(10)
    ]

    # gather(return_exceptions=True) returns each task's outcome. asyncio.wait
    # returns Task objects, so an isinstance(..., Exception) check on them
    # could never detect a failure.
    results = await asyncio.gather(*lora_tasks, return_exceptions=True)

    for r in results:
        assert not isinstance(r, Exception), f"Got exception {r}"


@pytest.mark.asyncio
async def test_loading_invalid_adapters_does_not_break_others(
    client: openai.AsyncOpenAI, tmp_path, zephyr_lora_files
):
    """Failed adapter loads must not disturb requests to an already-loaded LoRA.

    While completion requests keep running against "zephyr-lora", the test
    repeatedly tries to load missing and malformed adapters. All the good
    requests must succeed. Afterwards, a new valid adapter must still load and
    serve a request.
    """
    invalid_files = tmp_path / "invalid_files"
    invalid_files.mkdir()
    (invalid_files / "adapter_config.json").write_text("this is not json")

    stop_good_requests_event = asyncio.Event()

    async def run_good_requests(client):
        # Run completion requests until the event is set, collecting each
        # batch or the exception it raised.
        results = []

        while not stop_good_requests_event.is_set():
            try:
                batch = await client.completions.create(
                    model="zephyr-lora",
                    prompt=["Hello there", "Foo bar bazz buzz"],
                    max_tokens=5,
                )
                results.append(batch)
            except Exception as e:
                results.append(e)

        return results

    # Create task to run good requests
    good_task = asyncio.create_task(run_good_requests(client))

    try:
        # Run a bunch of bad adapter loads
        for _ in range(25):
            with suppress(openai.NotFoundError):
                await client.post(
                    "load_lora_adapter",
                    cast_to=str,
                    body={"lora_name": "notfound", "lora_path": "/not/an/adapter"},
                )
        for _ in range(25):
            with suppress(openai.BadRequestError):
                await client.post(
                    "load_lora_adapter",
                    cast_to=str,
                    body={"lora_name": "invalid", "lora_path": str(invalid_files)},
                )
    finally:
        # Always stop the background loop, even if a bad load raised an
        # unexpected error, so the task does not outlive the test.
        stop_good_requests_event.set()
        results = await good_task

    # Ensure all the running requests with lora adapters succeeded
    for r in results:
        assert not isinstance(r, Exception), f"Got exception {r}"

    # Ensure we can load another adapter and run it
    await client.post(
        "load_lora_adapter",
        cast_to=str,
        body={"lora_name": "valid", "lora_path": zephyr_lora_files},
    )
    await client.completions.create(
        model="valid",
        prompt=["Hello there", "Foo bar bazz buzz"],
        max_tokens=5,
    )


@pytest.mark.asyncio
async def test_beam_search_with_lora_adapters(
    client: openai.AsyncOpenAI,
    tmp_path,
    zephyr_lora_files,
):
    """Validate that async beam search can be used with LoRA adapters.

    Three adapters are loaded concurrently, and each serves completion
    requests with ``use_beam_search=True``.
    """

    async def load_and_run_adapter(adapter_name: str):
        await client.post(
            "load_lora_adapter",
            cast_to=str,
            body={"lora_name": adapter_name, "lora_path": str(zephyr_lora_files)},
        )
        for _ in range(3):
            await client.completions.create(
                model=adapter_name,
                prompt=["Hello there", "Foo bar bazz buzz"],
                max_tokens=5,
                extra_body=dict(use_beam_search=True),
            )

    # The server is module-scoped and test_multiple_lora_adapters already
    # registers "adapter_{i}". Use distinct names so this test does not depend
    # on re-registering an existing name, which the server may reject.
    lora_tasks = [
        asyncio.create_task(load_and_run_adapter(f"beam_adapter_{i}"))
        for i in range(3)
    ]

    # gather(return_exceptions=True) returns each task's outcome. asyncio.wait
    # returns Task objects, so the isinstance check could never fire.
    results = await asyncio.gather(*lora_tasks, return_exceptions=True)

    for r in results:
        assert not isinstance(r, Exception), f"Got exception {r}"