test_lora_manager.py 31.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
import os

import pytest
import torch
from safetensors.torch import load_file
from torch import nn

11
from vllm.config import ModelConfig, VllmConfig
12
from vllm.config.lora import LoRAConfig
13
14
15
16
17
from vllm.lora.layers import (
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    RowParallelLinearWithLoRA,
)
18
from vllm.lora.lora_model import LoRAModel
19
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
20
from vllm.lora.model_manager import (
21
    DEFAULT_LANGUAGE_WRAPPER_KEY,
22
23
24
25
    LoRAMapping,
    LoRAModelManager,
    LRUCacheLoRAModelManager,
)
26
from vllm.lora.peft_helper import PEFTHelper
27
from vllm.lora.request import LoRARequest
28
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager, WorkerLoRAManager
29
from vllm.platforms import current_platform
30

31
32
from .utils import create_peft_lora

Terry's avatar
Terry committed
33
34
35
36
37
EMBEDDING_MODULES = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}

DEVICE_TYPE = current_platform.device_type

# On CUDA-like platforms run the device-parametrized tests on (at most) the
# first two accelerators; on every other platform fall back to the CPU.
if current_platform.is_cuda_alike():
    DEVICES = [
        f"{DEVICE_TYPE}:{idx}"
        for idx in range(min(torch.accelerator.device_count(), 2))
    ]
else:
    DEVICES = ["cpu"]

# Read once at import time so every LoRAConfig built in this module uses the
# same lora_dtype.
DEFAULT_DTYPE = torch.get_default_dtype()

47

48
@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(qwen3_lora_files, device):
    """Loading LoRA tensors from a PEFT checkpoint yields well-formed layers.

    Every parsed layer must carry the adapter's rank (8) and alpha (32), live
    on the requested device, and have ``lora_a``/``lora_b`` shapes that agree
    on the rank dimension.
    """
    tensors = load_file(os.path.join(qwen3_lora_files, "adapter_model.safetensors"))
    peft_helper = PEFTHelper.from_local_dir(
        qwen3_lora_files, max_position_embeddings=4096
    )
    lora_model = LoRAModel.from_lora_tensors(
        1,
        tensors,
        peft_helper=peft_helper,
        device=device,
    )

    # Guard against a vacuous pass: every per-layer check below lives inside
    # the loop, so an empty result would otherwise pass silently.
    assert lora_model.loras, "no LoRA layers were loaded from the checkpoint"

    expected_device = torch.device(device)
    for module_name, lora in lora_model.loras.items():
        assert lora.module_name == module_name
        assert lora.rank == 8
        assert lora.lora_alpha == 32

        assert lora.lora_a is not None
        assert lora.lora_b is not None
        assert lora.lora_a.device == expected_device
        assert lora.lora_b.device == expected_device

        # lora_a's leading dim and lora_b's trailing dim are both the rank
        # (the same layout create_lora uses below).
        assert lora.lora_a.shape[0] == lora.lora_b.shape[1], (
            f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
        )
        assert lora.lora_a.shape[0] == 8
73
74


75
76
77
def create_lora(
    lora_id: int, model: nn.Module, sub_modules: list[str], device: torch.device
) -> LoRAModel:
    """Build a rank-8 LoRA adapter with random weights for ``sub_modules``.

    For each named submodule, ``lora_a`` is ``(8, weight.shape[1])`` and
    ``lora_b`` is ``(weight.shape[0], 8)``, sized from that submodule's
    ``weight`` and allocated on ``device``. Every layer uses ``lora_alpha=16``.
    """
    rank, alpha = 8, 16

    def random_layer(name: str) -> LoRALayerWeights:
        weight = model.get_submodule(name).weight
        return LoRALayerWeights(
            name,
            rank,
            alpha,
            torch.rand([rank, weight.shape[1]], device=device),
            torch.rand([weight.shape[0], rank], device=device),
        )

    layers = {name: random_layer(name) for name in sub_modules}
    return LoRAModel(lora_id, rank, layers)


def create_packed_lora(
    lora_id: int,
    model: nn.Module,
    module_name,
    replaced_module_names,
    device: torch.device,
    empty_replaced_module_name=None,
) -> LoRAModel:
    """Build a rank-8 random LoRA adapter for the slices of a packed module.

    ``module_name`` names a packed layer (e.g. ``gate_up_proj``) whose
    ``weight.shape[0]`` is divided evenly among ``replaced_module_names``.
    One LoRA layer is created per replaced name. ``empty_replaced_module_name``,
    if given, is skipped, so the adapter covers only part of the packed module.
    """
    rank, alpha = 8, 16
    weight = model.get_submodule(module_name).weight

    def random_slice(name) -> LoRALayerWeights:
        out_per_slice = weight.shape[0] // len(replaced_module_names)
        return LoRALayerWeights(
            name,
            rank,
            alpha,
            torch.rand([rank, weight.shape[1]], device=device),
            torch.rand([out_per_slice, rank], device=device),
        )

    layers = {
        name: random_slice(name)
        for name in replaced_module_names
        if name != empty_replaced_module_name
    }
    return LoRAModel(lora_id, rank, layers)


114
def test_replace_submodules(default_vllm_config, dist_init, dummy_model):
    """Building a LoRAModelManager swaps the dense layers for LoRA wrappers."""
    lora_config = LoRAConfig(
        max_lora_rank=8, max_cpu_loras=8, max_loras=8, lora_dtype=DEFAULT_DTYPE
    )
    manager = LoRAModelManager(
        dummy_model, 1, 1, 1, lora_config, torch.device(DEVICES[0])
    )
    wrapped = manager.model

    expected = {
        "dense1": ColumnParallelLinearWithLoRA,
        "layer1.dense1": ColumnParallelLinearWithLoRA,
        "dense2": RowParallelLinearWithLoRA,
        "layer1.dense2": RowParallelLinearWithLoRA,
    }
    for path, lora_cls in expected.items():
        assert isinstance(wrapped.get_submodule(path), lora_cls), path
133
134


135
@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(default_vllm_config, dist_init, dummy_model, device):
    """Exercise add/activate/remove bookkeeping of a plain LoRAModelManager.

    The manager has two slots (max_loras=2) and room for three registered
    adapters (max_cpu_loras=3). Activating a third adapter while both slots
    are busy raises ValueError until a slot is freed.
    """
    model = dummy_model
    common_targets = ["dense1", "dense2", "lm_head"]
    model_lora1 = create_lora(
        1, model, ["layer1.dense1", "dense2", "lm_head"], device=device
    )
    model_lora2 = create_lora(2, model, common_targets, device=device)
    model_lora3 = create_lora(3, model, common_targets, device=device)
    lora_config = LoRAConfig(
        max_lora_rank=8, max_cpu_loras=3, max_loras=2, lora_dtype=DEFAULT_DTYPE
    )
    manager = LoRAModelManager(model, 2, 2, 2, lora_config, device=device)

    def slots():
        # Adapter ids occupying the first two slots (None means free).
        return list(manager.lora_index_to_id[:2])

    assert all(slot is None for slot in manager.lora_index_to_id)

    # Fill both slots; adding/activating an already-present adapter
    # returns False.
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert slots() == [1, 2]

    # A third adapter can be registered but not activated while both
    # slots are taken.
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(model_lora3)
    assert slots() == [1, 2]
    with pytest.raises(ValueError):
        assert manager.activate_adapter(3)
    assert slots() == [1, 2]

    # Removing frees the slot; a second removal returns False.
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)

    # Re-adding does not activate; activating 3 then 2 fills slots in order.
    assert manager.add_adapter(model_lora1)
    assert slots() == [None, None]
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(3)
    assert slots() == [3, None]
    assert manager.activate_adapter(2)
    assert slots() == [3, 2]

    assert manager.device == device
    wrapper = manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY)
    assert wrapper.device == device
    assert hasattr(manager, "supported_lora_modules")
    assert sorted(manager.supported_lora_modules) == [
        "dense1",
        "dense2",
        "lm_head",
        "output",
    ]
199

200

201
@pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(
    default_vllm_config, dist_init, dummy_model, device
):
    """Exercise slot assignment, deactivation and pinning with LRU eviction.

    With two slots and three registered adapters, activating a third adapter
    replaces an unpinned one. Once both slots are pinned, pinning or
    activating another adapter raises RuntimeError.
    """
    model = dummy_model
    common_targets = ["dense1", "dense2", "lm_head"]
    model_lora1 = create_lora(
        1, model, ["layer1.dense1", "dense2", "lm_head"], device=device
    )
    model_lora2 = create_lora(2, model, common_targets, device=device)
    model_lora3 = create_lora(3, model, common_targets, device=device)
    lora_config = LoRAConfig(
        max_lora_rank=8, max_cpu_loras=3, max_loras=2, lora_dtype=DEFAULT_DTYPE
    )
    manager = LRUCacheLoRAModelManager(model, 2, 2, 2, lora_config, device=device)

    def slots():
        # Adapter ids occupying the first two slots (None means free).
        return list(manager.lora_index_to_id[:2])

    assert all(slot is None for slot in manager.lora_index_to_id)

    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert slots() == [1, 2]

    # With both slots busy, activating 3 takes over adapter 1's slot.
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(model_lora3)
    assert slots() == [1, 2]
    assert manager.activate_adapter(3)
    assert slots() == [3, 2]

    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert slots() == [3, 1]

    # Deactivating frees a slot; the adapter can be re-activated later
    # without being added again.
    assert manager.add_adapter(model_lora2)
    assert manager.deactivate_adapter(3)
    assert slots() == [None, 1]
    assert manager.activate_adapter(2)
    assert slots() == [2, 1]
    assert manager.activate_adapter(3)
    assert slots() == [2, 3]

    # While 2 is pinned, activating 1 replaces the other slot.
    assert manager.pin_adapter(2)
    assert slots() == [2, 3]
    assert manager.activate_adapter(1)
    assert slots() == [2, 1]
    assert manager.deactivate_adapter(2)
    assert slots() == [None, 1]
    assert manager.activate_adapter(3)
    assert slots() == [3, 1]

    # With every slot pinned, nothing else can be pinned or activated.
    assert manager.pin_adapter(3)
    assert manager.pin_adapter(1)
    with pytest.raises(RuntimeError):
        assert manager.pin_adapter(2)
    assert slots() == [3, 1]
    with pytest.raises(RuntimeError):
        assert manager.activate_adapter(2)

    assert manager.deactivate_adapter(3)
    assert manager.pin_adapter(2)
    assert slots() == [2, 1]
    # Pinning an adapter that has been removed raises ValueError.
    assert manager.remove_adapter(3)
    with pytest.raises(ValueError):
        assert manager.pin_adapter(3)

    wrapper = manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY)
    assert wrapper.device == device
    assert manager.device == device

292

293
@pytest.mark.parametrize("device", DEVICES)
def test_lru_lora_model_manager(default_vllm_config, dist_init, dummy_model, device):
    """Check LRU eviction and pinning in LRUCacheLoRAModelManager.

    Only the LRU-cache behaviour is covered here; the rest of the manager
    API is exercised by test_lora_model_manager.
    """
    model = dummy_model
    common_targets = ["dense1", "dense2", "lm_head"]
    model_lora1 = create_lora(
        1, model, ["layer1.dense1", "dense2", "lm_head"], device=device
    )
    model_lora2, model_lora3, model_lora4 = (
        create_lora(lora_id, model, common_targets, device=device)
        for lora_id in (2, 3, 4)
    )
    lora_config = LoRAConfig(
        max_lora_rank=8, max_cpu_loras=2, max_loras=2, lora_dtype=DEFAULT_DTYPE
    )
    manager = LRUCacheLoRAModelManager(model, 2, 2, 2, lora_config, device=device)

    def registered():
        return set(manager.list_adapters())

    def slots():
        # Adapter ids occupying the first two slots (None means free).
        return list(manager.lora_index_to_id[:2])

    assert all(slot is None for slot in manager.lora_index_to_id)

    # Fill the cache up to capacity.
    for lora in (model_lora1, model_lora2):
        assert manager.add_adapter(lora)
    for lora_id in (1, 2):
        assert manager.activate_adapter(lora_id)
    assert registered() == {1, 2}
    assert slots() == [1, 2]

    # Going over capacity pushes out the older adapters.
    for lora in (model_lora3, model_lora4):
        assert manager.add_adapter(lora)
    for lora_id in (3, 4):
        assert manager.activate_adapter(lora_id)
    assert registered() == {3, 4}
    assert slots() == [3, 4]

    # Touching 3 again (both calls return False since it is already in)
    # refreshes it, so adding 2 evicts 4 instead.
    assert not manager.add_adapter(model_lora3)
    assert not manager.activate_adapter(3)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert registered() == {3, 2}
    assert slots() == [3, 2]

    # Remove manually.
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)
    assert registered() == {2}
    assert slots() == [None, 2]

    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)
    assert registered() == {3, 4}
    assert slots() == [3, 4]

    # remove_oldest_adapter drains oldest-first and returns False when empty.
    assert manager.remove_oldest_adapter()
    assert registered() == {4}
    assert slots() == [None, 4]
    assert manager.remove_oldest_adapter()
    assert registered() == set()
    assert all(slot is None for slot in manager.lora_index_to_id)
    assert not manager.remove_oldest_adapter()
    assert registered() == set()
    assert all(slot is None for slot in manager.lora_index_to_id)

    # Pinning.
    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)
    assert registered() == {3, 4}
    with pytest.raises(ValueError):
        assert manager.pin_adapter(1)
    assert manager.pin_adapter(3)
    # A pinned adapter can still be removed manually.
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)
    assert registered() == {4}
    assert slots() == [None, 4]

    assert manager.add_adapter(model_lora1)
    assert manager.pin_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert registered() == {1, 2}
    assert slots() == [1, 2]

    # Only the unpinned adapter is evicted; after that eviction fails.
    assert manager.remove_oldest_adapter()
    assert registered() == {1}
    assert slots() == [1, None]
    with pytest.raises(RuntimeError):
        assert manager.remove_oldest_adapter()
    assert registered() == {1}

    wrapper = manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY)
    assert wrapper.device == device
    assert manager.device == device
417

418

419
@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(
    default_vllm_config, dist_init, dummy_model, device, tmp_path
):
    """LRUCacheWorkerLoRAManager keeps unrequested adapters cached.

    Adapters left out of a new request stay loaded until the four slots
    (max_loras=4) run out. A single request for five distinct adapters
    raises RuntimeError.
    """
    lora_config = LoRAConfig(
        max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
    )

    dummy_lora_files = f"{tmp_path}/lora_adapter"
    os.makedirs(dummy_lora_files, exist_ok=True)
    create_peft_lora(
        dummy_model,
        save_dir=dummy_lora_files,
        target_modules=["layer1.dense1", "dense2"],
        lora_dtype=DEFAULT_DTYPE,
    )

    vllm_config = VllmConfig(
        model_config=ModelConfig(max_model_len=16), lora_config=lora_config
    )
    scheduler_config = vllm_config.scheduler_config
    scheduler_config.max_num_seqs = 4
    scheduler_config.max_num_batched_tokens = 2

    worker_adapter_manager = LRUCacheWorkerLoRAManager(
        vllm_config, device, EMBEDDING_MODULES
    )
    worker_adapter_manager.max_num_seqs = 4
    worker_adapter_manager.max_num_batched_tokens = 2
    worker_adapter_manager.create_lora_manager(dummy_model)

    mapping = LoRAMapping([], [])

    def activate(*lora_ids):
        # Every request is backed by the same on-disk PEFT adapter.
        requests = [
            LoRARequest(str(lora_id), lora_id, dummy_lora_files)
            for lora_id in lora_ids
        ]
        worker_adapter_manager.set_active_adapters(requests, mapping)

    def slots(count):
        index_to_id = worker_adapter_manager._adapter_manager.lora_index_to_id
        return list(index_to_id[:count])

    activate(1, 2)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert slots(2) == [1, 2]

    # Adapter 2 is not requested but stays cached while there is room.
    activate(1, 3, 4)
    assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
    assert slots(4) == [1, 2, 3, 4]

    # Loading 5 needs a slot, so adapter 3 is evicted.
    activate(1, 2, 5)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert slots(4) == [1, 2, 5, 4]

    # Repeated requests for a cached adapter change nothing.
    activate(1, 1, 1)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert slots(4) == [1, 2, 5, 4]

    activate(6, 7, 8)
    assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
    assert slots(4) == [1, 7, 8, 6]

    # Over capacity.
    with pytest.raises(RuntimeError):
        activate(10, 11, 12, 13, 14)

    assert worker_adapter_manager.device == device
    punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
        DEFAULT_LANGUAGE_WRAPPER_KEY
    )
    assert punica_wrapper.device == device
533

534

535
@pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(
    default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
):
    """WorkerLoRAManager keeps exactly the adapters named in each request.

    Every adapter not listed in a set_active_adapters call is removed. A
    single request for five distinct adapters (max_loras=4) raises
    RuntimeError.
    """
    lora_config = LoRAConfig(
        max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
    )

    vllm_config = VllmConfig(
        model_config=ModelConfig(max_model_len=16), lora_config=lora_config
    )
    scheduler_config = vllm_config.scheduler_config
    scheduler_config.max_num_seqs = 4
    scheduler_config.max_num_batched_tokens = 2

    worker_adapter_manager = WorkerLoRAManager(vllm_config, device, EMBEDDING_MODULES)
    worker_adapter_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
    worker_adapter_manager.create_lora_manager(dummy_model_gate_up)

    dummy_lora_files = f"{tmp_path}/lora_adapter"
    os.makedirs(dummy_lora_files, exist_ok=True)
    create_peft_lora(
        dummy_model_gate_up,
        save_dir=dummy_lora_files,
        target_modules=["layer1.dense1", "dense2"],
        lora_dtype=DEFAULT_DTYPE,
    )

    mapping = LoRAMapping([], [])

    def activate(*lora_ids):
        # Every request is backed by the same on-disk PEFT adapter.
        requests = [
            LoRARequest(str(lora_id), lora_id, dummy_lora_files)
            for lora_id in lora_ids
        ]
        worker_adapter_manager.set_active_adapters(requests, mapping)

    def slots(count):
        index_to_id = worker_adapter_manager._adapter_manager.lora_index_to_id
        return list(index_to_id[:count])

    activate(1, 2)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert slots(2) == [1, 2]

    activate(1, 3, 4)
    assert worker_adapter_manager.list_adapters() == {1, 3, 4}
    assert slots(3) == [1, 3, 4]

    activate(1, 2, 5)
    assert worker_adapter_manager.list_adapters() == {1, 2, 5}
    assert slots(3) == [1, 2, 5]

    activate(1, 1, 1)
    assert worker_adapter_manager.list_adapters() == {1}
    assert slots(3) == [1, None, None]

    activate(6, 7, 8)
    assert worker_adapter_manager.list_adapters() == {6, 7, 8}
    assert slots(3) == [8, 6, 7]

    # Over capacity.
    with pytest.raises(RuntimeError):
        activate(10, 11, 12, 13, 14)

    assert worker_adapter_manager.device == device
    punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
        DEFAULT_LANGUAGE_WRAPPER_KEY
    )
    assert punica_wrapper.device == device
642

643

644
@pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(default_vllm_config, dist_init, dummy_model_gate_up, device):
    """Per-slice gate_proj/up_proj LoRAs are packed into gate_up_proj.

    After ``add_adapter`` the per-slice entries are gone. They are replaced by
    one PackedLoRALayerWeights whose i-th ``lora_a``/``lora_b`` equal the
    original i-th slice. A slice the adapter does not provide is ``None``.
    """
    model = dummy_model_gate_up
    packed_name = "gate_up_proj"
    slice_names = ["gate_proj", "up_proj"]

    model_lora = create_packed_lora(
        1,
        model,
        module_name=packed_name,
        replaced_module_names=slice_names,
        device=device,
    )
    # Adapter 2 provides only the up_proj slice.
    model_lora1 = create_packed_lora(
        2,
        model,
        module_name=packed_name,
        replaced_module_names=slice_names,
        device=device,
        empty_replaced_module_name="gate_proj",
    )

    manager = LoRAModelManager(
        model,
        2,
        2,
        2,
        LoRAConfig(
            max_lora_rank=8, max_cpu_loras=2, max_loras=2, lora_dtype=DEFAULT_DTYPE
        ),
        device=device,
    )
    wrapped_model = manager.model
    assert isinstance(
        wrapped_model.get_submodule(packed_name), MergedColumnParallelLinearWithLoRA
    )

    # Snapshot the unpacked weights; add_adapter rewrites the adapters'
    # layer dicts in place.
    model_lora_clone = model_lora.clone(1)
    model_lora_clone1 = model_lora1.clone(1)

    assert manager.add_adapter(model_lora)
    assert manager.add_adapter(model_lora1)

    assert model_lora.get_lora("gate_proj") is None
    assert model_lora.get_lora("up_proj") is None
    assert model_lora1.get_lora("up_proj") is None

    packed_lora = model_lora.get_lora(packed_name)
    assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)
    for idx, name in enumerate(slice_names):
        expected = model_lora_clone.get_lora(name)
        torch.testing.assert_close(packed_lora.lora_a[idx], expected.lora_a)
        torch.testing.assert_close(packed_lora.lora_b[idx], expected.lora_b)

    packed_lora1 = model_lora1.get_lora(packed_name)
    assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)
    # The missing gate_proj slice is left as None.
    assert packed_lora1.lora_a[0] is None
    assert packed_lora1.lora_b[0] is None
    expected1 = model_lora_clone1.get_lora("up_proj")
    torch.testing.assert_close(packed_lora1.lora_a[1], expected1.lora_a)
    torch.testing.assert_close(packed_lora1.lora_b[1], expected1.lora_b)
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902


def _test_target_modules(
    model,
    target_modules: list[str] | None,
    device: str,
    expected_lora: list[tuple[str, type]],
    expected_no_lora: list[tuple[str, type]],
):
    """Wrap ``model`` with LoRA limited to ``target_modules`` and check it.

    After wrapping, each ``(path, cls)`` pair in ``expected_lora`` must be an
    instance of ``cls``, and each pair in ``expected_no_lora`` must not be.
    """
    lora_config = LoRAConfig(
        max_lora_rank=8,
        max_cpu_loras=2,
        max_loras=2,
        lora_dtype=DEFAULT_DTYPE,
        target_modules=target_modules,
    )
    # The checks below inspect ``model`` itself rather than the manager, so
    # the manager instance is not kept.
    LoRAModelManager(model, 2, 2, 2, lora_config, device=device)

    for path, lora_cls in expected_lora:
        assert isinstance(model.get_submodule(path), lora_cls), path
    for path, lora_cls in expected_no_lora:
        assert not isinstance(model.get_submodule(path), lora_cls), path


@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_config(default_vllm_config, dist_init, dummy_model, device):
    """Restricting target_modules to dense1 leaves the dense2 layers unwrapped."""
    column_layers = [
        (path, ColumnParallelLinearWithLoRA) for path in ("dense1", "layer1.dense1")
    ]
    row_layers = [
        (path, RowParallelLinearWithLoRA) for path in ("dense2", "layer1.dense2")
    ]
    _test_target_modules(
        dummy_model,
        ["dense1"],
        device,
        expected_lora=column_layers,
        expected_no_lora=row_layers,
    )


@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_multiple(default_vllm_config, dist_init, dummy_model, device):
    """Listing both dense1 and dense2 in target_modules wraps all four layers."""
    column_layers = [
        (path, ColumnParallelLinearWithLoRA) for path in ("dense1", "layer1.dense1")
    ]
    row_layers = [
        (path, RowParallelLinearWithLoRA) for path in ("dense2", "layer1.dense2")
    ]
    _test_target_modules(
        dummy_model,
        ["dense1", "dense2"],
        device,
        expected_lora=column_layers + row_layers,
        expected_no_lora=[],
    )


@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_none_uses_all(
    default_vllm_config, dist_init, dummy_model, device
):
    """With target_modules=None every supported layer is wrapped with LoRA."""
    column_layers = [
        (path, ColumnParallelLinearWithLoRA) for path in ("dense1", "layer1.dense1")
    ]
    row_layers = [
        (path, RowParallelLinearWithLoRA) for path in ("dense2", "layer1.dense2")
    ]
    _test_target_modules(
        dummy_model,
        None,
        device,
        expected_lora=column_layers + row_layers,
        expected_no_lora=[],
    )


@pytest.mark.parametrize("device", DEVICES)
def test_load_adapter_warns_on_unsupported_modules(
    default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
):
    """``_load_adapter`` must warn about adapter modules the model cannot target.

    A real PEFT adapter is written under ``tmp_path``. At load time an extra
    ``unsupported_module`` entry is spliced into the loaded ``LoRAModel``. The
    test passes if any ``logger.warning_once`` call recorded during loading
    mentions that module name.
    """
    from unittest.mock import patch

    import vllm.lora.worker_manager as wm_module

    cfg_lora = LoRAConfig(
        max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
    )

    adapter_dir = f"{tmp_path}/lora_adapter"
    os.makedirs(adapter_dir, exist_ok=True)
    create_peft_lora(
        dummy_model_gate_up,
        save_dir=adapter_dir,
        target_modules=["layer1.dense1", "dense2"],
        lora_dtype=DEFAULT_DTYPE,
    )

    engine_config = VllmConfig(
        model_config=ModelConfig(max_model_len=16), lora_config=cfg_lora
    )
    sched = engine_config.scheduler_config
    sched.max_num_seqs = 4
    sched.max_num_batched_tokens = 2

    manager = WorkerLoRAManager(engine_config, device, EMBEDDING_MODULES)
    manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
    manager.create_lora_manager(dummy_model_gate_up)

    # Capture the genuine loader before patching so the replacement can
    # delegate to it instead of recursing into itself.
    real_loader = LoRAModel.from_local_checkpoint

    def loader_with_extra_module(*args, **kwargs):
        # Load the adapter normally, then add a module name that is not among
        # the model's supported LoRA target modules.
        loaded = real_loader(*args, **kwargs)
        loaded.loras["unsupported_module"] = LoRALayerWeights(
            module_name="unsupported_module",
            rank=8,
            lora_alpha=16,
            lora_a=torch.randn(8, 10),
            lora_b=torch.randn(10, 8),
        )
        return loaded

    request = LoRARequest("test", 1, adapter_dir)
    with (
        patch.object(LoRAModel, "from_local_checkpoint", loader_with_extra_module),
        patch.object(wm_module.logger, "warning_once") as mock_warning,
    ):
        manager._load_adapter(request)

    # The mock keeps its recorded calls after the patch is undone.
    warning_args = mock_warning.call_args_list
    assert any("unsupported_module" in str(c) for c in warning_args), (
        f"Expected warning about 'unsupported_module', got: {warning_args}"
    )


@pytest.mark.parametrize("device", DEVICES)
def test_load_adapter_warns_on_target_modules_restriction(
    default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
):
    """Warn when the deployment-time ``target_modules`` drops adapter modules.

    The LoRA config only allows ``dense2``, while the adapter written to disk
    also carries ``layer1.dense1``. Loading it must emit a
    ``logger.warning_once`` call whose arguments mention ``target_modules``.
    """
    from unittest.mock import patch

    import vllm.lora.worker_manager as wm_module

    # Only dense2 may receive LoRA; the adapter's layer1.dense1 weights are
    # supported by the model but excluded by this restriction.
    restricted_config = LoRAConfig(
        max_lora_rank=8,
        max_cpu_loras=4,
        max_loras=4,
        lora_dtype=DEFAULT_DTYPE,
        target_modules=["dense2"],
    )

    adapter_dir = f"{tmp_path}/lora_adapter"
    os.makedirs(adapter_dir, exist_ok=True)
    create_peft_lora(
        dummy_model_gate_up,
        save_dir=adapter_dir,
        target_modules=["layer1.dense1", "dense2"],
        lora_dtype=DEFAULT_DTYPE,
    )

    engine_config = VllmConfig(
        model_config=ModelConfig(max_model_len=16), lora_config=restricted_config
    )
    sched = engine_config.scheduler_config
    sched.max_num_seqs = 4
    sched.max_num_batched_tokens = 2

    manager = WorkerLoRAManager(engine_config, device, EMBEDDING_MODULES)
    manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
    manager.create_lora_manager(dummy_model_gate_up)

    request = LoRARequest("test", 1, adapter_dir)
    with patch.object(wm_module.logger, "warning_once") as mock_warning:
        manager._load_adapter(request)

    # The mock keeps its recorded calls after the patch is undone.
    warning_args = mock_warning.call_args_list
    assert any("target_modules" in str(c) for c in warning_args), (
        f"Expected warning about target_modules restriction, got: {warning_args}"
    )