test_lora_manager.py 36.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
import os

import pytest
import torch
from safetensors.torch import load_file
from torch import nn

11
from vllm.config import ModelConfig, VllmConfig
12
from vllm.config.lora import LoRAConfig
13
14
15
from vllm.lora.layers import (
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
16
    ReplicatedLinearWithLoRA,
17
18
    RowParallelLinearWithLoRA,
)
19
from vllm.lora.lora_model import LoRAModel
20
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
21
from vllm.lora.model_manager import (
22
    DEFAULT_LANGUAGE_WRAPPER_KEY,
23
24
25
26
    LoRAMapping,
    LoRAModelManager,
    LRUCacheLoRAModelManager,
)
27
from vllm.lora.peft_helper import PEFTHelper
28
from vllm.lora.request import LoRARequest
29
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager, WorkerLoRAManager
30
from vllm.model_executor.layers.fused_moe import GateLinear
31
from vllm.platforms import current_platform
32

33
34
from .utils import create_peft_lora

Terry's avatar
Terry committed
35
36
37
38
39
# Mapping handed to the worker LoRA managers constructed in the tests below
# (module name -> embedding role string).
EMBEDDING_MODULES = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}

DEVICE_TYPE = current_platform.device_type
# Devices the parametrized tests run on: at most the first two accelerator
# devices on CUDA-like platforms, otherwise only the CPU.
DEVICES = (
    [f"{DEVICE_TYPE}:{i}" for i in range(min(torch.accelerator.device_count(), 2))]
    if current_platform.is_cuda_alike()
    else ["cpu"]
)

# Default dtype captured once at import time and passed as `lora_dtype` to
# every LoRAConfig built in this module.
DEFAULT_DTYPE = torch.get_default_dtype()

49

50
@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(qwen3_lora_files, device):
    """Check that `LoRAModel.from_lora_tensors` builds consistent per-module LoRAs.

    The adapter tensors are read from ``adapter_model.safetensors`` inside the
    ``qwen3_lora_files`` directory, and every resulting LoRA entry is checked
    for its name, rank, alpha, device placement and A/B shape agreement.
    """
    tensors = load_file(os.path.join(qwen3_lora_files, "adapter_model.safetensors"))

    peft_helper = PEFTHelper.from_local_dir(
        qwen3_lora_files, max_position_embeddings=4096
    )
    lora_model = LoRAModel.from_lora_tensors(
        1,
        tensors,
        peft_helper=peft_helper,
        device=device,
    )
    for module_name, lora in lora_model.loras.items():
        # The dict key and the name stored on the entry must agree.
        assert lora.module_name == module_name
        assert lora.rank == 8
        assert lora.lora_alpha == 32
        assert lora.lora_a is not None
        assert lora.lora_b is not None
        # Both matrices must live on the requested device.
        assert lora.lora_a.device == torch.device(device)
        assert lora.lora_b.device == torch.device(device)
        # A's first dim and B's second dim must match each other and equal 8
        # (the same value asserted for `lora.rank` above).
        assert lora.lora_a.shape[0] == lora.lora_b.shape[1], (
            f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
        )
        assert lora.lora_a.shape[0] == 8
75
76


77
78
79
def create_lora(
    lora_id: int, model: nn.Module, sub_modules: list[str], device: torch.device
) -> LoRAModel:
    """Build a `LoRAModel` with random weights for the given submodules.

    Args:
        lora_id: Identifier passed through to the returned `LoRAModel`.
        model: Model whose submodules supply the weight shapes.
        sub_modules: Names resolvable via ``model.get_submodule``; each named
            submodule must expose a 2-D ``weight``.
        device: Device on which the random tensors are created (callers in
            this file pass device strings such as the entries of ``DEVICES``).

    Returns:
        A `LoRAModel` holding one `LoRALayerWeights` per entry in
        ``sub_modules``.
    """
    loras: dict[str, LoRALayerWeights] = {}
    for name in sub_modules:
        w = model.get_submodule(name).weight
        # Positional 8 and 16 are presumably rank and lora_alpha -- confirm
        # against the LoRALayerWeights signature. The tensors are shaped
        # [8, w.shape[1]] and [w.shape[0], 8] respectively.
        loras[name] = LoRALayerWeights(
            name,
            8,
            16,
            torch.rand([8, w.shape[1]], device=device),
            torch.rand([w.shape[0], 8], device=device),
        )
    return LoRAModel(lora_id, 8, loras)


def create_packed_lora(
    lora_id: int,
    model: nn.Module,
    module_name: str,
    replaced_module_names: list[str],
    device: torch.device,
    empty_replaced_module_name=None,
) -> LoRAModel:
    """Build a `LoRAModel` whose entries target the parts of a packed module.

    Args:
        lora_id: Identifier passed through to the returned `LoRAModel`.
        model: Model containing the packed module.
        module_name: Name of the packed submodule whose ``weight`` supplies
            the shapes.
        replaced_module_names: Names of the sub-projections; the packed
            weight's first dimension is divided evenly among them.
        device: Device on which the random tensors are created.
        empty_replaced_module_name: Optional name from
            ``replaced_module_names`` for which no LoRA entry is created.

    Returns:
        A `LoRAModel` with one `LoRALayerWeights` per replaced module name,
        except the one named by ``empty_replaced_module_name``.
    """
    w = model.get_submodule(module_name).weight
    loras: dict[str, LoRALayerWeights] = {}
    for replaced_module_name in replaced_module_names:
        if replaced_module_name == empty_replaced_module_name:
            # Deliberately leave this slice without a LoRA.
            continue
        # The second tensor covers only this sub-projection's share of the
        # packed weight's first dimension.
        loras[replaced_module_name] = LoRALayerWeights(
            replaced_module_name,
            8,
            16,
            torch.rand([8, w.shape[1]], device=device),
            torch.rand([w.shape[0] // len(replaced_module_names), 8], device=device),
        )
    return LoRAModel(lora_id, 8, loras)


116
def test_replace_submodules(default_vllm_config, dist_init, dummy_model):
    """Constructing a `LoRAModelManager` swaps linear layers for LoRA wrappers.

    After construction, the ``dense1`` modules must be
    `ColumnParallelLinearWithLoRA` and the ``dense2`` modules must be
    `RowParallelLinearWithLoRA`, both at the top level and under ``layer1``.
    """
    model = dummy_model
    manager = LoRAModelManager(
        model,
        1,
        1,
        1,
        LoRAConfig(
            max_lora_rank=8, max_cpu_loras=8, max_loras=8, lora_dtype=DEFAULT_DTYPE
        ),
        torch.device(DEVICES[0]),
    )
    # Inspect the model as held by the manager.
    model = manager.model
    assert isinstance(model.get_submodule("dense1"), ColumnParallelLinearWithLoRA)
    assert isinstance(
        model.get_submodule("layer1.dense1"), ColumnParallelLinearWithLoRA
    )
    assert isinstance(model.get_submodule("dense2"), RowParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense2"), RowParallelLinearWithLoRA)
135
136


137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
def test_wrap_replicated_linear_subclasses(default_vllm_config, dist_init, dummy_model):
    """A subclass of `ReplicatedLinear` must be wrapped like its base class.

    A trivial subclass instance is attached to the dummy model as
    ``custom_gate``; after the manager is built, that submodule is expected
    to be a `ReplicatedLinearWithLoRA`.
    """
    from vllm.model_executor.layers.linear import ReplicatedLinear

    class CustomReplicatedLinear(ReplicatedLinear):
        pass

    dummy_model.add_module("custom_gate", CustomReplicatedLinear(10, 10, bias=False))

    lora_config = LoRAConfig(
        max_lora_rank=8, max_cpu_loras=8, max_loras=8, lora_dtype=DEFAULT_DTYPE
    )
    target_device = torch.device(DEVICES[0])
    manager = LoRAModelManager(dummy_model, 1, 1, 1, lora_config, target_device)

    wrapped_gate = manager.model.get_submodule("custom_gate")
    assert isinstance(wrapped_gate, ReplicatedLinearWithLoRA)


def test_wrap_gate_linear(default_vllm_config, dist_init, dummy_model):
    """A `GateLinear` module must be wrapped as `ReplicatedLinearWithLoRA`.

    The gate is attached to the dummy model as ``router_gate`` before the
    manager is constructed, then looked up on the manager's model.
    """
    dummy_model.add_module("router_gate", GateLinear(10, 4, bias=False))

    lora_config = LoRAConfig(
        max_lora_rank=8, max_cpu_loras=8, max_loras=8, lora_dtype=DEFAULT_DTYPE
    )
    target_device = torch.device(DEVICES[0])
    manager = LoRAModelManager(dummy_model, 1, 1, 1, lora_config, target_device)

    wrapped_gate = manager.model.get_submodule("router_gate")
    assert isinstance(wrapped_gate, ReplicatedLinearWithLoRA)


def test_skip_unsupported_matched_modules(default_vllm_config, dist_init, dummy_model):
    """Name-matched modules of an unsupported type are left untouched.

    Without explicit ``target_modules`` in the config, building the manager
    must succeed, keep ``unsupported.dense1`` as a plain `nn.Linear`, and not
    register it in ``manager.modules``.
    """

    class UnsupportedContainer(nn.Module):
        def __init__(self):
            super().__init__()
            # This name matches a supported target suffix ("dense1"),
            # but nn.Linear is not currently a LoRA-wrappable layer type.
            self.dense1 = nn.Linear(10, 10, bias=False)

    dummy_model.add_module("unsupported", UnsupportedContainer())

    lora_config = LoRAConfig(
        max_lora_rank=8, max_cpu_loras=8, max_loras=8, lora_dtype=DEFAULT_DTYPE
    )
    target_device = torch.device(DEVICES[0])
    manager = LoRAModelManager(dummy_model, 1, 1, 1, lora_config, target_device)

    # Should not crash and should keep unsupported matched modules unchanged.
    untouched = manager.model.get_submodule("unsupported.dense1")
    assert isinstance(untouched, nn.Linear)
    assert "unsupported.dense1" not in manager.modules


def test_target_modules_fail_closed_on_unsupported_matched_modules(
    default_vllm_config, dist_init, dummy_model
):
    """Explicit ``target_modules`` must reject unsupported matched modules.

    When the config names ``dense1`` as a target and the model contains an
    ``unsupported.dense1`` that is a plain `nn.Linear`, constructing the
    manager is expected to raise a `ValueError` naming that module.
    """

    class UnsupportedContainer(nn.Module):
        def __init__(self):
            super().__init__()
            self.dense1 = nn.Linear(10, 10, bias=False)

    model = dummy_model
    model.add_module("unsupported", UnsupportedContainer())

    # Build the inputs outside the `raises` block so that only the manager
    # construction itself can satisfy the expectation.
    lora_config = LoRAConfig(
        max_lora_rank=8,
        max_cpu_loras=8,
        max_loras=8,
        lora_dtype=DEFAULT_DTYPE,
        target_modules=["dense1"],
    )
    device = torch.device(DEVICES[0])

    # `match` is a regular expression; escape the dot so it only matches the
    # literal module path rather than any character in that position.
    with pytest.raises(ValueError, match=r"unsupported\.dense1"):
        LoRAModelManager(
            model,
            1,
            1,
            1,
            lora_config,
            device,
        )


def test_get_dummy_lora_warmup_rank_for_fully_sharded_moe():
    """Warmup rank for a fully sharded setup with tp_size 32 must be 32.

    The manager is created without running ``__init__`` and given only the
    attributes this test sets: a config with ``fully_sharded_loras=True`` and
    ``max_lora_rank=64``, plus two stand-in modules that both report
    ``tp_size=32`` and ``fully_sharded=True``. Asking for rank 8 is expected
    to return 32.
    """

    class DummyModule:
        def __init__(self, tp_size: int, fully_sharded: bool):
            self.tp_size = tp_size
            self.fully_sharded = fully_sharded

    # Bypass __init__ so no real model or device is needed.
    manager = LoRAModelManager.__new__(LoRAModelManager)
    manager.lora_config = LoRAConfig(
        max_lora_rank=64,
        max_cpu_loras=1,
        max_loras=1,
        lora_dtype=DEFAULT_DTYPE,
        fully_sharded_loras=True,
    )

    module_names = (
        "model.layers.0.self_attn.q_proj",
        "model.layers.0.mlp.experts",
    )
    manager.modules = {
        module_name: DummyModule(tp_size=32, fully_sharded=True)
        for module_name in module_names
    }

    warmup_rank = manager.get_dummy_lora_warmup_rank(8)
    assert warmup_rank == 32


266
@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(default_vllm_config, dist_init, dummy_model, device):
    """Exercise add/activate/remove on `LoRAModelManager` with two slots.

    The manager is configured with ``max_loras=2`` and ``max_cpu_loras=3``.
    The test tracks ``manager.lora_index_to_id`` after each operation and
    finishes by checking device placement and the supported module list.
    """
    model = dummy_model
    model_lora1 = create_lora(
        1, model, ["layer1.dense1", "dense2", "lm_head"], device=device
    )
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"], device=device)
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"], device=device)
    manager = LoRAModelManager(
        model,
        2,
        2,
        2,
        LoRAConfig(
            max_lora_rank=8, max_cpu_loras=3, max_loras=2, lora_dtype=DEFAULT_DTYPE
        ),
        device=device,
    )
    # No slot is occupied initially.
    assert all(x is None for x in manager.lora_index_to_id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    # Adding or activating an adapter that is already present is falsy.
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    # A third adapter can be added, but both slots stay as they were.
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    # Activating it while both slots are occupied must raise.
    with pytest.raises(ValueError):
        assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    # Removing an adapter clears its slot; removing it again is falsy.
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    # Adding alone does not occupy a slot.
    assert manager.add_adapter(model_lora1)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] is None
    assert manager.add_adapter(model_lora2)
    # Adapter 3 was added earlier and never removed, so it can be activated.
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] is None
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2
    assert manager.device == device
    assert (
        manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
        == device
    )
    assert hasattr(manager, "supported_lora_modules")
    assert sorted(manager.supported_lora_modules) == [
        "dense1",
        "dense2",
        "lm_head",
        "output",
    ]
330

331

332
@pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(
    default_vllm_config, dist_init, dummy_model, device
):
    """Exercise activation, deactivation and pinning on the LRU manager.

    The manager is configured with ``max_loras=2`` and ``max_cpu_loras=3``.
    Unlike the plain manager test, activating a third adapter here is
    expected to succeed by replacing an occupied slot, unless every slot is
    held by a pinned adapter.
    """
    model = dummy_model
    model_lora1 = create_lora(
        1, model, ["layer1.dense1", "dense2", "lm_head"], device=device
    )
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"], device=device)
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"], device=device)
    manager = LRUCacheLoRAModelManager(
        model,
        2,
        2,
        2,
        LoRAConfig(
            max_lora_rank=8, max_cpu_loras=3, max_loras=2, lora_dtype=DEFAULT_DTYPE
        ),
        device=device,
    )
    # No slot is occupied initially.
    assert all(x is None for x in manager.lora_index_to_id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    # Adding or activating an adapter that is already present is falsy.
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    # With both slots occupied, activating adapter 3 takes over slot 0.
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2
    # Removing an adapter clears its slot; removing it again is falsy.
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    assert manager.add_adapter(model_lora2)
    # Deactivation frees the slot.
    assert manager.deactivate_adapter(3)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    # Pinning adapter 2 leaves the slots unchanged; the next activation then
    # replaces the other slot instead.
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.deactivate_adapter(2)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    # With both slots held by pinned adapters, pinning or activating another
    # adapter must raise.
    assert manager.pin_adapter(3)
    assert manager.pin_adapter(1)
    with pytest.raises(RuntimeError):
        assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    with pytest.raises(RuntimeError):
        assert manager.activate_adapter(2)

    # After adapter 3 is deactivated, adapter 2 can be pinned into slot 0.
    assert manager.deactivate_adapter(3)
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    # Pinning an adapter that has been removed must raise.
    assert manager.remove_adapter(3)
    with pytest.raises(ValueError):
        assert manager.pin_adapter(3)
    assert (
        manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
        == device
    )
    assert manager.device == device

423

424
@pytest.mark.parametrize("device", DEVICES)
def test_lru_lora_model_manager(default_vllm_config, dist_init, dummy_model, device):
    """Exercise registration-level eviction on `LRUCacheLoRAModelManager`.

    The manager is configured with ``max_loras=2`` and ``max_cpu_loras=2``,
    so adding adapters beyond two is expected to drop older ones from
    ``list_adapters()``. Pinned adapters must survive
    ``remove_oldest_adapter``.
    """
    # This tests just the LRU cache functionality, everything else is
    # tested in test_lora_model_manager
    model = dummy_model
    model_lora1 = create_lora(
        1, model, ["layer1.dense1", "dense2", "lm_head"], device=device
    )
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"], device=device)
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"], device=device)
    model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"], device=device)
    manager = LRUCacheLoRAModelManager(
        model,
        2,
        2,
        2,
        LoRAConfig(
            max_lora_rank=8, max_cpu_loras=2, max_loras=2, lora_dtype=DEFAULT_DTYPE
        ),
        device=device,
    )
    assert all(x is None for x in manager.lora_index_to_id)

    # Add up to capacity
    assert manager.add_adapter(model_lora1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(1)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Add over capacity
    assert manager.add_adapter(model_lora3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(3)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Add 3 again to move it to the top and then add 2
    # should return false since it's in already
    assert not manager.add_adapter(model_lora3)
    assert not manager.activate_adapter(3)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {3, 2}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {2}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 2

    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Each call drops one adapter until none remain; the call on an empty
    # manager is falsy.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    assert not manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    # pinning
    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    # Adapter 1 is no longer registered, so pinning it must raise.
    with pytest.raises(ValueError):
        assert manager.pin_adapter(1)
    assert manager.pin_adapter(3)
    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    # Pinning adapter 1 right after adding it places it in slot 0; adding
    # adapter 2 then pushes adapter 4 out of the registered set.
    assert manager.add_adapter(model_lora1)
    assert manager.pin_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # The unpinned adapter 2 is dropped; pinned adapter 1 stays.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {1}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] is None

    # Only a pinned adapter is left, so there is nothing that may be removed.
    with pytest.raises(RuntimeError):
        assert manager.remove_oldest_adapter()

    assert set(manager.list_adapters()) == {1}
    assert (
        manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
        == device
    )
    assert manager.device == device
548

549

550
@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(
    default_vllm_config, dist_init, dummy_model, device, tmp_path
):
    """Exercise `LRUCacheWorkerLoRAManager.set_active_adapters` with four slots.

    A PEFT adapter is written under ``tmp_path`` and reused for every
    request. Adapters that a request does not mention are expected to stay
    loaded until room is needed, and a request for more adapters than
    ``max_loras`` must raise.
    """
    lora_config = LoRAConfig(
        max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
    )

    # All LoRARequests below point at this single on-disk adapter.
    dummy_lora_files = f"{tmp_path}/lora_adapter"
    os.makedirs(dummy_lora_files, exist_ok=True)
    create_peft_lora(
        dummy_model,
        save_dir=dummy_lora_files,
        target_modules=["layer1.dense1", "dense2"],
        lora_dtype=DEFAULT_DTYPE,
    )

    model_config = ModelConfig(max_model_len=16)
    vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)

    vllm_config.scheduler_config.max_num_seqs = 4
    vllm_config.scheduler_config.max_num_batched_tokens = 2
    worker_adapter_manager = LRUCacheWorkerLoRAManager(
        vllm_config, device, EMBEDDING_MODULES
    )

    worker_adapter_manager.max_num_seqs = 4
    worker_adapter_manager.max_num_batched_tokens = 2

    worker_adapter_manager.create_lora_manager(dummy_model)

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters(
        [LoRARequest("1", 1, dummy_lora_files), LoRARequest("2", 2, dummy_lora_files)],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    # Adapter 2 is not requested here but remains loaded.
    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("3", 3, dummy_lora_files),
            LoRARequest("4", 4, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # All four slots are occupied; adapter 5 takes the slot adapter 3 held.
    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("2", 2, dummy_lora_files),
            LoRARequest("5", 5, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Requesting the same adapter repeatedly changes nothing.
    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("1", 1, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Three new adapters replace three of the four loaded ones; adapter 1
    # keeps slot 0.
    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("6", 6, dummy_lora_files),
            LoRARequest("7", 7, dummy_lora_files),
            LoRARequest("8", 8, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 6

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters(
            [
                LoRARequest("10", 10, dummy_lora_files),
                LoRARequest("11", 11, dummy_lora_files),
                LoRARequest("12", 12, dummy_lora_files),
                LoRARequest("13", 13, dummy_lora_files),
                LoRARequest("14", 14, dummy_lora_files),
            ],
            mapping,
        )

    assert worker_adapter_manager.device == device
    punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
        DEFAULT_LANGUAGE_WRAPPER_KEY
    )
    assert punica_wrapper.device == device
664

665

666
@pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(
    default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
):
    """Exercise `WorkerLoRAManager.set_active_adapters` with four slots.

    After every call, the loaded set is expected to equal exactly the
    adapters named in that request. A request for more adapters than
    ``max_loras`` must raise.
    """
    # Should remove every LoRA not specified in the request.
    lora_config = LoRAConfig(
        max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
    )

    model_config = ModelConfig(max_model_len=16)
    vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)

    vllm_config.scheduler_config.max_num_seqs = 4
    vllm_config.scheduler_config.max_num_batched_tokens = 2

    worker_adapter_manager = WorkerLoRAManager(vllm_config, device, EMBEDDING_MODULES)
    worker_adapter_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
    worker_adapter_manager.create_lora_manager(dummy_model_gate_up)

    # All LoRARequests below point at this single on-disk adapter.
    dummy_lora_files = f"{tmp_path}/lora_adapter"
    os.makedirs(dummy_lora_files, exist_ok=True)
    create_peft_lora(
        dummy_model_gate_up,
        save_dir=dummy_lora_files,
        target_modules=["layer1.dense1", "dense2"],
        lora_dtype=DEFAULT_DTYPE,
    )

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters(
        [LoRARequest("1", 1, dummy_lora_files), LoRARequest("2", 2, dummy_lora_files)],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    # Adapter 2 is absent from this request and is dropped.
    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("3", 3, dummy_lora_files),
            LoRARequest("4", 4, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4

    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("2", 2, dummy_lora_files),
            LoRARequest("5", 5, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 2, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5

    # Requesting only adapter 1 (three times) leaves just that adapter loaded.
    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("1", 1, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None

    # A request with entirely new adapters replaces everything that was loaded.
    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("6", 6, dummy_lora_files),
            LoRARequest("7", 7, dummy_lora_files),
            LoRARequest("8", 8, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 7

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters(
            [
                LoRARequest("10", 10, dummy_lora_files),
                LoRARequest("11", 11, dummy_lora_files),
                LoRARequest("12", 12, dummy_lora_files),
                LoRARequest("13", 13, dummy_lora_files),
                LoRARequest("14", 14, dummy_lora_files),
            ],
            mapping,
        )

    assert worker_adapter_manager.device == device
    punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
        DEFAULT_LANGUAGE_WRAPPER_KEY
    )
    assert punica_wrapper.device == device
773

774

775
@pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(default_vllm_config, dist_init, dummy_model_gate_up, device):
    """Check that per-projection LoRAs are merged into one packed LoRA.

    Two adapters targeting ``gate_up_proj`` are built from the replaced module
    names ``gate_proj`` and ``up_proj``. The second adapter is created with
    ``empty_replaced_module_name="gate_proj"``. After both are added to a
    ``LoRAModelManager`` the test asserts that:

    * the individual ``gate_proj`` / ``up_proj`` entries are no longer
      retrievable from the adapters,
    * ``gate_up_proj`` holds a ``PackedLoRALayerWeights`` whose slots match
      the weights captured in clones taken before the adapters were added,
    * for the second adapter, slot 0 of the packed weights is ``None``.
    """
    model = dummy_model_gate_up
    model_lora = create_packed_lora(
        1,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device,
    )
    # Same layout as above, but with ``gate_proj`` named as the empty
    # replaced module (presumably no weights are generated for it — confirm
    # against create_packed_lora).
    model_lora1 = create_packed_lora(
        2,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device,
        empty_replaced_module_name="gate_proj",
    )

    manager = LoRAModelManager(
        model,
        2,
        2,
        2,
        LoRAConfig(
            max_lora_rank=8, max_cpu_loras=2, max_loras=2, lora_dtype=DEFAULT_DTYPE
        ),
        device=device,
    )
    model = manager.model

    assert isinstance(
        model.get_submodule("gate_up_proj"), MergedColumnParallelLinearWithLoRA
    )
    # Verify packed lora is correct
    # Clone before add_adapter so the original per-projection weights stay
    # available for comparison after the adapters have been packed.
    model_lora_clone = model_lora.clone(1)
    model_lora_clone1 = model_lora1.clone(1)
    assert manager.add_adapter(model_lora)
    assert manager.add_adapter(model_lora1)

    # The unpacked entries must be gone once the adapters are registered.
    assert model_lora.get_lora("gate_proj") is None
    assert model_lora.get_lora("up_proj") is None
    assert model_lora1.get_lora("up_proj") is None
    packed_lora = model_lora.get_lora("gate_up_proj")
    assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)

    # Slot 0 corresponds to gate_proj, slot 1 to up_proj.
    torch.testing.assert_close(
        packed_lora.lora_a[0], model_lora_clone.get_lora("gate_proj").lora_a
    )
    torch.testing.assert_close(
        packed_lora.lora_b[0], model_lora_clone.get_lora("gate_proj").lora_b
    )
    torch.testing.assert_close(
        packed_lora.lora_a[1], model_lora_clone.get_lora("up_proj").lora_a
    )
    torch.testing.assert_close(
        packed_lora.lora_b[1], model_lora_clone.get_lora("up_proj").lora_b
    )

    packed_lora1 = model_lora1.get_lora("gate_up_proj")
    assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)

    # The empty gate_proj slot stays None; only up_proj carries weights.
    assert packed_lora1.lora_a[0] is None
    assert packed_lora1.lora_b[0] is None
    torch.testing.assert_close(
        packed_lora1.lora_a[1], model_lora_clone1.get_lora("up_proj").lora_a
    )
    torch.testing.assert_close(
        packed_lora1.lora_b[1], model_lora_clone1.get_lora("up_proj").lora_b
    )
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928


def _test_target_modules(
    model,
    target_modules: list[str] | None,
    device: str,
    expected_lora: list[tuple[str, type]],
    expected_no_lora: list[tuple[str, type]],
):
    """Build a LoRAModelManager around ``model`` and verify LoRA wrapping.

    Args:
        model: Module handed to ``LoRAModelManager``; its submodules are
            inspected after the manager has been constructed.
        target_modules: Value forwarded to ``LoRAConfig(target_modules=...)``.
        device: Device string forwarded to the manager.
        expected_lora: ``(submodule path, class)`` pairs that must be
            instances of the given class afterwards.
        expected_no_lora: ``(submodule path, class)`` pairs that must NOT be
            instances of the given class afterwards.
    """
    lora_config = LoRAConfig(
        max_lora_rank=8,
        max_cpu_loras=2,
        max_loras=2,
        lora_dtype=DEFAULT_DTYPE,
        target_modules=target_modules,
    )
    # The manager instance itself is discarded; the checks below look at
    # ``model`` after construction.
    LoRAModelManager(model, 2, 2, 2, lora_config, device=device)

    for path, wrapper_cls in expected_lora:
        submodule = model.get_submodule(path)
        assert isinstance(submodule, wrapper_cls)

    for path, wrapper_cls in expected_no_lora:
        submodule = model.get_submodule(path)
        assert not isinstance(submodule, wrapper_cls)


@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_config(default_vllm_config, dist_init, dummy_model, device):
    """With ``target_modules=["dense1"]`` only the dense1 layers get LoRA."""
    wrapped = [
        ("dense1", ColumnParallelLinearWithLoRA),
        ("layer1.dense1", ColumnParallelLinearWithLoRA),
    ]
    # dense2 is not listed in target_modules, so it must stay unwrapped.
    not_wrapped = [
        ("dense2", RowParallelLinearWithLoRA),
        ("layer1.dense2", RowParallelLinearWithLoRA),
    ]
    _test_target_modules(
        model=dummy_model,
        target_modules=["dense1"],
        device=device,
        expected_lora=wrapped,
        expected_no_lora=not_wrapped,
    )


@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_multiple(default_vllm_config, dist_init, dummy_model, device):
    """Listing both ``dense1`` and ``dense2`` wraps all four dense layers."""
    wrapped = [
        ("dense1", ColumnParallelLinearWithLoRA),
        ("layer1.dense1", ColumnParallelLinearWithLoRA),
        ("dense2", RowParallelLinearWithLoRA),
        ("layer1.dense2", RowParallelLinearWithLoRA),
    ]
    _test_target_modules(
        model=dummy_model,
        target_modules=["dense1", "dense2"],
        device=device,
        expected_lora=wrapped,
        expected_no_lora=[],
    )


@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_none_uses_all(
    default_vllm_config, dist_init, dummy_model, device
):
    """Leaving ``target_modules`` as ``None`` wraps every dense layer checked."""
    wrapped = [
        ("dense1", ColumnParallelLinearWithLoRA),
        ("layer1.dense1", ColumnParallelLinearWithLoRA),
        ("dense2", RowParallelLinearWithLoRA),
        ("layer1.dense2", RowParallelLinearWithLoRA),
    ]
    _test_target_modules(
        model=dummy_model,
        target_modules=None,
        device=device,
        expected_lora=wrapped,
        expected_no_lora=[],
    )


929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_match_packed_runtime_modules(
    default_vllm_config, dist_init, dummy_model_gate_up, device
):
    """Targeting ``gate_proj`` must wrap the packed ``gate_up_proj`` module.

    The adapter-visible name ``gate_proj`` is given as the only target; the
    runtime module ``gate_up_proj`` is expected to receive LoRA while the
    plain dense layers are expected to stay unwrapped.
    """
    wrapped = [("gate_up_proj", MergedColumnParallelLinearWithLoRA)]
    not_wrapped = [
        ("dense1", ColumnParallelLinearWithLoRA),
        ("dense2", RowParallelLinearWithLoRA),
        ("layer1.dense1", ColumnParallelLinearWithLoRA),
        ("layer1.dense2", RowParallelLinearWithLoRA),
    ]
    _test_target_modules(
        model=dummy_model_gate_up,
        target_modules=["gate_proj"],
        device=device,
        expected_lora=wrapped,
        expected_no_lora=not_wrapped,
    )


948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
@pytest.mark.parametrize("device", DEVICES)
def test_load_adapter_warns_on_unsupported_modules(
    default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
):
    """``_load_adapter`` should warn about adapter modules the model lacks.

    A real adapter is written to ``tmp_path``; the checkpoint loader is then
    patched so the loaded LoRA additionally carries an ``unsupported_module``
    entry. The test passes when any ``warning_once`` call on the worker
    manager's logger mentions that module name.
    """
    from unittest.mock import patch

    import vllm.lora.worker_manager as wm_module

    cfg = LoRAConfig(
        max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
    )

    # Write a PEFT-style adapter for two modules of the dummy model.
    adapter_dir = f"{tmp_path}/lora_adapter"
    os.makedirs(adapter_dir, exist_ok=True)
    create_peft_lora(
        dummy_model_gate_up,
        save_dir=adapter_dir,
        target_modules=["layer1.dense1", "dense2"],
        lora_dtype=DEFAULT_DTYPE,
    )

    vllm_config = VllmConfig(
        model_config=ModelConfig(max_model_len=16), lora_config=cfg
    )
    scheduler = vllm_config.scheduler_config
    scheduler.max_num_seqs = 4
    scheduler.max_num_batched_tokens = 2

    manager = WorkerLoRAManager(vllm_config, device, EMBEDDING_MODULES)
    manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
    manager.create_lora_manager(dummy_model_gate_up)

    # Keep a handle on the real loader so the patched version can delegate.
    real_loader = LoRAModel.from_local_checkpoint

    def loader_with_extra_module(*args, **kwargs):
        # Load normally, then inject a module the model does not support.
        loaded = real_loader(*args, **kwargs)
        extra = LoRALayerWeights(
            module_name="unsupported_module",
            rank=8,
            lora_alpha=16,
            lora_a=torch.randn(8, 10),
            lora_b=torch.randn(10, 8),
        )
        loaded.loras["unsupported_module"] = extra
        return loaded

    request = LoRARequest("test", 1, adapter_dir)
    with (
        patch.object(LoRAModel, "from_local_checkpoint", loader_with_extra_module),
        patch.object(wm_module.logger, "warning_once") as mock_warning,
    ):
        manager._load_adapter(request)

    # The mock keeps its recorded calls after the patch is undone.
    recorded = mock_warning.call_args_list
    mentions = [c for c in recorded if "unsupported_module" in str(c)]
    assert mentions, f"Expected warning about 'unsupported_module', got: {recorded}"


@pytest.mark.parametrize("device", DEVICES)
def test_load_adapter_warns_on_target_modules_restriction(
    default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
):
    """``_load_adapter`` should warn when ``target_modules`` excludes a module.

    The deployment config only allows ``dense2`` while the adapter on disk
    also covers ``layer1.dense1``. The test passes when any ``warning_once``
    call on the worker manager's logger mentions ``target_modules``.
    """
    from unittest.mock import patch

    import vllm.lora.worker_manager as wm_module

    # Only dense2 is allowed; the adapter's dense1 weights fall outside this.
    cfg = LoRAConfig(
        max_lora_rank=8,
        max_cpu_loras=4,
        max_loras=4,
        lora_dtype=DEFAULT_DTYPE,
        target_modules=["dense2"],
    )

    adapter_dir = f"{tmp_path}/lora_adapter"
    os.makedirs(adapter_dir, exist_ok=True)
    create_peft_lora(
        dummy_model_gate_up,
        save_dir=adapter_dir,
        target_modules=["layer1.dense1", "dense2"],
        lora_dtype=DEFAULT_DTYPE,
    )

    vllm_config = VllmConfig(
        model_config=ModelConfig(max_model_len=16), lora_config=cfg
    )
    scheduler = vllm_config.scheduler_config
    scheduler.max_num_seqs = 4
    scheduler.max_num_batched_tokens = 2

    manager = WorkerLoRAManager(vllm_config, device, EMBEDDING_MODULES)
    manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
    manager.create_lora_manager(dummy_model_gate_up)

    request = LoRARequest("test", 1, adapter_dir)
    with patch.object(wm_module.logger, "warning_once") as mock_warning:
        manager._load_adapter(request)

    # dense1 is supported by the model but excluded by target_modules
    recorded = mock_warning.call_args_list
    mentions = [c for c in recorded if "target_modules" in str(c)]
    assert mentions, (
        f"Expected warning about target_modules restriction, got: {recorded}"
    )