"tools/vscode:/vscode.git/clone" did not exist on "379689d533642cfc1d3ab2cf4dc02f09a8318a5f"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
5
6
7
8
9
10
import os

import pytest
import torch
from safetensors.torch import load_file
from torch import nn

11
from vllm.config import ModelConfig, VllmConfig
12
from vllm.config.lora import LoRAConfig
13
14
15
16
17
from vllm.lora.layers import (
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    RowParallelLinearWithLoRA,
)
18
from vllm.lora.lora_model import LoRAModel
19
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
20
from vllm.lora.model_manager import (
21
    DEFAULT_LANGUAGE_WRAPPER_KEY,
22
23
24
25
    LoRAMapping,
    LoRAModelManager,
    LRUCacheLoRAModelManager,
)
26
from vllm.lora.peft_helper import PEFTHelper
27
from vllm.lora.request import LoRARequest
28
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager, WorkerLoRAManager
29
from vllm.platforms import current_platform
30

31
32
from .utils import create_peft_lora

Terry's avatar
Terry committed
33
34
35
36
37
38
# Embedding module name -> embedding kind. Passed as the third positional
# argument to the worker LoRA managers constructed in the tests below.
EMBEDDING_MODULES = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}


# Devices the tests are parametrized over. On CUDA-like platforms this is
# ["cuda:0"] when exactly one device is reported and ["cuda:0", "cuda:1"]
# otherwise; on every other platform it is just ["cpu"].
DEVICES = (
    [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
    if current_platform.is_cuda_alike()
    else ["cpu"]
)

# Torch default dtype captured at import time; used as ``lora_dtype`` in every
# ``LoRAConfig`` built by these tests.
DEFAULT_DTYPE = torch.get_default_dtype()

47

48
@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(qwen3_lora_files, device):
    """Check the per-module weights built by ``LoRAModel.from_lora_tensors``.

    Loads ``adapter_model.safetensors`` from the ``qwen3_lora_files`` fixture
    directory, builds a ``LoRAModel`` on ``device`` and asserts, for every
    entry in ``lora_model.loras``, that the name, rank (8), alpha (32), tensor
    placement and the shared A/B dimension are as expected.
    """
    tensors = load_file(os.path.join(qwen3_lora_files, "adapter_model.safetensors"))

    peft_helper = PEFTHelper.from_local_dir(
        qwen3_lora_files, max_position_embeddings=4096
    )
    lora_model = LoRAModel.from_lora_tensors(
        1,
        tensors,
        peft_helper=peft_helper,
        device=device,
    )

    # ``device`` is a string such as "cuda:0"; normalise it once for comparison.
    expected_device = torch.device(device)
    for module_name, lora in lora_model.loras.items():
        assert lora.module_name == module_name
        assert lora.rank == 8
        assert lora.lora_alpha == 32
        assert lora.lora_a is not None
        assert lora.lora_b is not None
        assert lora.lora_a.device == expected_device
        assert lora.lora_b.device == expected_device
        # A's first dimension and B's second dimension must agree, and both
        # must equal the rank asserted above.
        assert lora.lora_a.shape[0] == lora.lora_b.shape[1], (
            f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
        )
        assert lora.lora_a.shape[0] == 8
73
74


75
76
77
def create_lora(
    lora_id: int, model: nn.Module, sub_modules: list[str], device: torch.device
) -> LoRAModel:
    """Build a ``LoRAModel`` with random weights for the named submodules.

    Args:
        lora_id: Identifier given to the returned ``LoRAModel``.
        model: Module whose submodules' ``weight`` shapes size the tensors.
        sub_modules: Dotted submodule names (as accepted by
            ``nn.Module.get_submodule``) to create LoRA weights for.
        device: Device on which the random tensors are allocated.

    Returns:
        A ``LoRAModel`` constructed with ``lora_id``, the constant 8 and one
        ``LoRALayerWeights`` per entry of ``sub_modules``.
    """
    loras: dict[str, LoRALayerWeights] = {}
    for name in sub_modules:
        w = model.get_submodule(name).weight
        # NOTE(review): the positional 8 and 16 are presumably rank and
        # lora_alpha -- confirm against the LoRALayerWeights signature.
        loras[name] = LoRALayerWeights(
            name,
            8,
            16,
            # A is [8, w.shape[1]] and B is [w.shape[0], 8].
            torch.rand([8, w.shape[1]], device=device),
            torch.rand([w.shape[0], 8], device=device),
        )
    return LoRAModel(lora_id, 8, loras)


def create_packed_lora(
    lora_id: int,
    model: nn.Module,
    module_name: str,
    replaced_module_names: list[str],
    device: torch.device,
    empty_replaced_module_name=None,
) -> LoRAModel:
    """Build a ``LoRAModel`` holding one random LoRA per replaced sub-module.

    Args:
        lora_id: Identifier given to the returned ``LoRAModel``.
        model: Module that contains ``module_name``.
        module_name: Name of the submodule whose ``weight`` shape sizes the
            tensors.
        replaced_module_names: Names to create entries for; the weight's
            first dimension is divided evenly between them.
        device: Device on which the random tensors are allocated.
        empty_replaced_module_name: Optional name from
            ``replaced_module_names`` that gets no entry at all.

    Returns:
        A ``LoRAModel`` keyed by the replaced module names, without the
        skipped one.
    """
    w = model.get_submodule(module_name).weight
    loras: dict[str, LoRALayerWeights] = {}
    for replaced_module_name in replaced_module_names:
        if replaced_module_name == empty_replaced_module_name:
            # Deliberately leave this slice without LoRA weights.
            continue
        loras[replaced_module_name] = LoRALayerWeights(
            replaced_module_name,
            8,
            16,
            torch.rand([8, w.shape[1]], device=device),
            # Each replaced module gets an equal share of the output rows.
            torch.rand([w.shape[0] // len(replaced_module_names), 8], device=device),
        )
    return LoRAModel(lora_id, 8, loras)


114
def test_replace_submodules(default_vllm_config, dist_init, dummy_model):
    """Check that building a ``LoRAModelManager`` swaps in LoRA-aware layers.

    After construction, ``dense1``/``layer1.dense1`` of ``manager.model`` are
    expected to be ``ColumnParallelLinearWithLoRA`` and
    ``dense2``/``layer1.dense2`` to be ``RowParallelLinearWithLoRA``.
    """
    model = dummy_model
    manager = LoRAModelManager(
        model,
        1,
        1,
        1,
        LoRAConfig(
            max_lora_rank=8, max_cpu_loras=8, max_loras=8, lora_dtype=DEFAULT_DTYPE
        ),
        torch.device(DEVICES[0]),
    )
    # Inspect the model as exposed by the manager.
    model = manager.model
    assert isinstance(model.get_submodule("dense1"), ColumnParallelLinearWithLoRA)
    assert isinstance(
        model.get_submodule("layer1.dense1"), ColumnParallelLinearWithLoRA
    )
    assert isinstance(model.get_submodule("dense2"), RowParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense2"), RowParallelLinearWithLoRA)
133
134


135
@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(default_vllm_config, dist_init, dummy_model, device):
    """Exercise add/activate/remove on a ``LoRAModelManager`` with two slots.

    The manager is configured with ``max_loras=2`` and ``max_cpu_loras=3``.
    The test walks through adding, activating and removing three adapters and
    checks ``lora_index_to_id`` after each step, then checks the reported
    device and the supported LoRA module names.
    """
    model = dummy_model
    model_lora1 = create_lora(
        1, model, ["layer1.dense1", "dense2", "lm_head"], device=device
    )
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"], device=device)
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"], device=device)
    manager = LoRAModelManager(
        model,
        2,
        2,
        2,
        LoRAConfig(
            max_lora_rank=8, max_cpu_loras=3, max_loras=2, lora_dtype=DEFAULT_DTYPE
        ),
        device=device,
    )
    # Nothing is active on a fresh manager.
    assert all(x is None for x in manager.lora_index_to_id)

    # Adapter 1 takes slot 0; repeating add/activate reports False.
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)

    # Adapter 2 takes slot 1.
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)

    # Adapter 3 can be added, but with both slots occupied activating it is
    # expected to raise and to leave the slots untouched.
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    with pytest.raises(ValueError):
        assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Removing an adapter clears its slot; removing it twice reports False.
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)

    # Re-adding does not activate: both slots stay empty.
    assert manager.add_adapter(model_lora1)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] is None

    # With free slots, adapter 3 and then adapter 2 can be activated.
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] is None
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    # Device bookkeeping on the manager and on its default punica wrapper.
    assert manager.device == device
    assert (
        manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
        == device
    )
    assert hasattr(manager, "supported_lora_modules")
    assert sorted(manager.supported_lora_modules) == [
        "dense1",
        "dense2",
        "lm_head",
        "output",
    ]
199

200

201
@pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(
    default_vllm_config, dist_init, dummy_model, device
):
    """Exercise activation, deactivation and pinning on the LRU model manager.

    Uses ``LRUCacheLoRAModelManager`` with ``max_loras=2`` and
    ``max_cpu_loras=3`` and checks ``lora_index_to_id`` after every step.
    Unlike the plain manager, activating a third adapter is expected to
    succeed by taking over a slot, unless every slot holds a pinned adapter.
    """
    model = dummy_model
    model_lora1 = create_lora(
        1, model, ["layer1.dense1", "dense2", "lm_head"], device=device
    )
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"], device=device)
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"], device=device)
    manager = LRUCacheLoRAModelManager(
        model,
        2,
        2,
        2,
        LoRAConfig(
            max_lora_rank=8, max_cpu_loras=3, max_loras=2, lora_dtype=DEFAULT_DTYPE
        ),
        device=device,
    )
    # Nothing is active on a fresh manager.
    assert all(x is None for x in manager.lora_index_to_id)

    # Fill both slots with adapters 1 and 2; repeated add/activate is False.
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)

    # Activating adapter 3 succeeds and takes over slot 0 from adapter 1.
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    # Removal clears the slot; a second removal reports False.
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)

    # Re-add adapter 1 and activate it into the free slot.
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1

    # Deactivation frees a slot without removing the adapter.
    assert manager.add_adapter(model_lora2)
    assert manager.deactivate_adapter(3)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3

    # Once adapter 2 is pinned, a new activation replaces the other slot.
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.deactivate_adapter(2)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1

    # With both active adapters pinned, pinning or activating another one is
    # expected to raise RuntimeError and leave the slots unchanged.
    assert manager.pin_adapter(3)
    assert manager.pin_adapter(1)
    with pytest.raises(RuntimeError):
        assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    with pytest.raises(RuntimeError):
        assert manager.activate_adapter(2)

    # Deactivating a pinned adapter frees its slot for adapter 2.
    assert manager.deactivate_adapter(3)
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1

    # Pinning an adapter that has been removed is expected to raise.
    assert manager.remove_adapter(3)
    with pytest.raises(ValueError):
        assert manager.pin_adapter(3)

    # Device bookkeeping on the default punica wrapper and on the manager.
    assert (
        manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
        == device
    )
    assert manager.device == device

292

293
@pytest.mark.parametrize("device", DEVICES)
def test_lru_lora_model_manager(default_vllm_config, dist_init, dummy_model, device):
    """Exercise the LRU behaviour of ``LRUCacheLoRAModelManager``.

    The manager is configured with ``max_cpu_loras=2`` and ``max_loras=2``.
    The test checks both ``list_adapters()`` and ``lora_index_to_id`` while
    adding beyond capacity, removing manually, calling
    ``remove_oldest_adapter`` and pinning.
    """
    # This tests just the LRU cache functionality, everything else is
    # tested in test_lora_model_manager
    model = dummy_model
    model_lora1 = create_lora(
        1, model, ["layer1.dense1", "dense2", "lm_head"], device=device
    )
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"], device=device)
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"], device=device)
    model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"], device=device)
    manager = LRUCacheLoRAModelManager(
        model,
        2,
        2,
        2,
        LoRAConfig(
            max_lora_rank=8, max_cpu_loras=2, max_loras=2, lora_dtype=DEFAULT_DTYPE
        ),
        device=device,
    )
    assert all(x is None for x in manager.lora_index_to_id)

    # Add up to capacity
    assert manager.add_adapter(model_lora1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(1)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Add over capacity
    assert manager.add_adapter(model_lora3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(3)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Add 3 again to move it to the top and then add 2
    # should return false since it's in already
    assert not manager.add_adapter(model_lora3)
    assert not manager.activate_adapter(3)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {3, 2}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {2}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 2

    # Refill with 3 and 4; adapter 2 is expected to drop out.
    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # remove_oldest_adapter drops one adapter per call and reports False
    # once nothing is left.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    assert not manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    # pinning
    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    # Adapter 1 is no longer registered, so pinning it is expected to raise.
    with pytest.raises(ValueError):
        assert manager.pin_adapter(1)
    assert manager.pin_adapter(3)
    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    # Pinning adapter 1 places it in slot 0 without an explicit activate call.
    assert manager.add_adapter(model_lora1)
    assert manager.pin_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # The pinned adapter survives remove_oldest_adapter ...
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {1}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] is None

    # ... and once only pinned adapters remain the call is expected to raise.
    with pytest.raises(RuntimeError):
        assert manager.remove_oldest_adapter()

    assert set(manager.list_adapters()) == {1}
    assert (
        manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
        == device
    )
    assert manager.device == device
417

418

419
@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(
    default_vllm_config, dist_init, dummy_model, device, tmp_path
):
    """Exercise ``set_active_adapters`` on ``LRUCacheWorkerLoRAManager``.

    A PEFT-format adapter is written to ``tmp_path`` and reused under several
    LoRA ids. With ``max_loras=4`` and ``max_cpu_loras=4`` the test expects
    adapters that are not re-requested to stay loaded until slots are needed,
    and a request for five adapters at once to raise ``RuntimeError``.
    """
    lora_config = LoRAConfig(
        max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
    )

    # One adapter directory on disk backs every LoRARequest below.
    dummy_lora_files = f"{tmp_path}/lora_adapter"
    os.makedirs(dummy_lora_files, exist_ok=True)
    create_peft_lora(
        dummy_model,
        save_dir=dummy_lora_files,
        target_modules=["layer1.dense1", "dense2"],
        lora_dtype=DEFAULT_DTYPE,
    )

    model_config = ModelConfig(max_model_len=16)
    vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)

    vllm_config.scheduler_config.max_num_seqs = 4
    vllm_config.scheduler_config.max_num_batched_tokens = 2
    worker_adapter_manager = LRUCacheWorkerLoRAManager(
        vllm_config, device, EMBEDDING_MODULES
    )

    worker_adapter_manager.max_num_seqs = 4
    worker_adapter_manager.max_num_batched_tokens = 2

    worker_adapter_manager.create_lora_manager(dummy_model)

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters(
        [LoRARequest("1", 1, dummy_lora_files), LoRARequest("2", 2, dummy_lora_files)],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    # Adapter 2 is not requested here but is expected to stay loaded.
    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("3", 3, dummy_lora_files),
            LoRARequest("4", 4, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # All four slots are in use; adapter 5 is expected to replace adapter 3.
    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("2", 2, dummy_lora_files),
            LoRARequest("5", 5, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Requesting only adapter 1 (three times) changes nothing.
    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("1", 1, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Three new adapters replace 2, 4 and 5; adapter 1 keeps slot 0.
    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("6", 6, dummy_lora_files),
            LoRARequest("7", 7, dummy_lora_files),
            LoRARequest("8", 8, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 6

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters(
            [
                LoRARequest("10", 10, dummy_lora_files),
                LoRARequest("11", 11, dummy_lora_files),
                LoRARequest("12", 12, dummy_lora_files),
                LoRARequest("13", 13, dummy_lora_files),
                LoRARequest("14", 14, dummy_lora_files),
            ],
            mapping,
        )

    assert worker_adapter_manager.device == device
    punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
        DEFAULT_LANGUAGE_WRAPPER_KEY
    )
    assert punica_wrapper.device == device
533

534

535
@pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(
    default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
):
    """Exercise ``set_active_adapters`` on the non-LRU ``WorkerLoRAManager``.

    A PEFT-format adapter is written to ``tmp_path`` and reused under several
    LoRA ids. After every call the loaded set is expected to equal exactly the
    requested ids, and a request for five adapters with ``max_loras=4`` is
    expected to raise ``RuntimeError``.
    """
    # Should remove every LoRA not specified in the request.
    lora_config = LoRAConfig(
        max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
    )

    model_config = ModelConfig(max_model_len=16)
    vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)

    vllm_config.scheduler_config.max_num_seqs = 4
    vllm_config.scheduler_config.max_num_batched_tokens = 2

    worker_adapter_manager = WorkerLoRAManager(vllm_config, device, EMBEDDING_MODULES)
    worker_adapter_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
    worker_adapter_manager.create_lora_manager(dummy_model_gate_up)

    # One adapter directory on disk backs every LoRARequest below.
    dummy_lora_files = f"{tmp_path}/lora_adapter"
    os.makedirs(dummy_lora_files, exist_ok=True)
    create_peft_lora(
        dummy_model_gate_up,
        save_dir=dummy_lora_files,
        target_modules=["layer1.dense1", "dense2"],
        lora_dtype=DEFAULT_DTYPE,
    )

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters(
        [LoRARequest("1", 1, dummy_lora_files), LoRARequest("2", 2, dummy_lora_files)],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    # Adapter 2 is dropped because it is not part of this request.
    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("3", 3, dummy_lora_files),
            LoRARequest("4", 4, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4

    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("2", 2, dummy_lora_files),
            LoRARequest("5", 5, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 2, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5

    # Duplicate requests for adapter 1 leave only that adapter loaded.
    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("1", 1, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None

    # A completely new set replaces adapter 1 as well; the asserted slot
    # order is 8, 6, 7.
    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("6", 6, dummy_lora_files),
            LoRARequest("7", 7, dummy_lora_files),
            LoRARequest("8", 8, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 7

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters(
            [
                LoRARequest("10", 10, dummy_lora_files),
                LoRARequest("11", 11, dummy_lora_files),
                LoRARequest("12", 12, dummy_lora_files),
                LoRARequest("13", 13, dummy_lora_files),
                LoRARequest("14", 14, dummy_lora_files),
            ],
            mapping,
        )

    assert worker_adapter_manager.device == device
    punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
        DEFAULT_LANGUAGE_WRAPPER_KEY
    )
    assert punica_wrapper.device == device
642

643

644
@pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(default_vllm_config, dist_init, dummy_model_gate_up, device):
    """Check that adding an adapter packs ``gate_proj``/``up_proj`` LoRAs.

    Two adapters are built for ``gate_up_proj``: one with both sub-LoRAs and
    one without ``gate_proj``. After ``add_adapter`` the per-projection
    entries are expected to be gone and replaced by a single
    ``PackedLoRALayerWeights`` under ``gate_up_proj``. Its slices must match
    clones taken before the add, with ``None`` for the missing ``gate_proj``.
    """
    model = dummy_model_gate_up
    model_lora = create_packed_lora(
        1,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device,
    )
    # Same layout, but without any weights for gate_proj.
    model_lora1 = create_packed_lora(
        2,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device,
        empty_replaced_module_name="gate_proj",
    )

    manager = LoRAModelManager(
        model,
        2,
        2,
        2,
        LoRAConfig(
            max_lora_rank=8, max_cpu_loras=2, max_loras=2, lora_dtype=DEFAULT_DTYPE
        ),
        device=device,
    )
    model = manager.model

    assert isinstance(
        model.get_submodule("gate_up_proj"), MergedColumnParallelLinearWithLoRA
    )
    # Verify packed lora is correct
    # Clone before adding so the unpacked weights remain available for
    # comparison afterwards.
    model_lora_clone = model_lora.clone(1)
    model_lora_clone1 = model_lora1.clone(1)
    assert manager.add_adapter(model_lora)
    assert manager.add_adapter(model_lora1)

    # The individual projection entries no longer exist after the add.
    assert model_lora.get_lora("gate_proj") is None
    assert model_lora.get_lora("up_proj") is None
    assert model_lora1.get_lora("up_proj") is None
    packed_lora = model_lora.get_lora("gate_up_proj")
    assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)

    # Index 0 corresponds to gate_proj and index 1 to up_proj.
    torch.testing.assert_close(
        packed_lora.lora_a[0], model_lora_clone.get_lora("gate_proj").lora_a
    )
    torch.testing.assert_close(
        packed_lora.lora_b[0], model_lora_clone.get_lora("gate_proj").lora_b
    )
    torch.testing.assert_close(
        packed_lora.lora_a[1], model_lora_clone.get_lora("up_proj").lora_a
    )
    torch.testing.assert_close(
        packed_lora.lora_b[1], model_lora_clone.get_lora("up_proj").lora_b
    )

    packed_lora1 = model_lora1.get_lora("gate_up_proj")
    assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)

    # The omitted gate_proj slice is represented by None.
    assert packed_lora1.lora_a[0] is None
    assert packed_lora1.lora_b[0] is None
    torch.testing.assert_close(
        packed_lora1.lora_a[1], model_lora_clone1.get_lora("up_proj").lora_a
    )
    torch.testing.assert_close(
        packed_lora1.lora_b[1], model_lora_clone1.get_lora("up_proj").lora_b
    )