"vllm/vscode:/vscode.git/clone" did not exist on "d3cf61b89bc53aa7709932ab43e7630b9a71f2b3"
test_lora_manager.py 26.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
import os

import pytest
import torch
from safetensors.torch import load_file
from torch import nn

11
from vllm.config.lora import LoRAConfig
12
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
13
14
                              MergedColumnParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA)
15
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
16
17
from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager)
18
from vllm.lora.peft_helper import PEFTHelper
19
20
21
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                      WorkerLoRAManager)
22
from vllm.platforms import current_platform
23

24
25
from .utils import create_peft_lora

# Maps the names of embedding modules in the model to the keys under which
# their extra-vocab embeddings are stored in ``new_embeddings.safetensors``.
EMBEDDING_MODULES = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}

# Passed to the LoRA loaders/managers as ``embedding_padding_modules``.
EMBEDDING_PADDING_MODULES = ["lm_head"]

# Parametrize over one CUDA device when exactly one is visible, otherwise two;
# non-CUDA-like platforms run on the CPU only.
DEVICES = ([
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] if current_platform.is_cuda_alike() else ["cpu"])

DEFAULT_DTYPE = torch.get_default_dtype()

39

40
@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
    """Build a LoRAModel from on-disk PEFT tensors and validate every layer.

    Each loaded layer must carry the adapter's rank (8) and alpha (16), have
    both A and B weights on ``device`` with a shared inner dimension equal to
    the rank, and carry an embeddings tensor only for modules whose name
    contains a key of ``EMBEDDING_MODULES``.
    """
    tensors = load_file(
        os.path.join(sql_lora_files, "adapter_model.safetensors"))
    new_embeddings = load_file(
        os.path.join(sql_lora_files, "new_embeddings.safetensors"))

    peft_helper = PEFTHelper.from_local_dir(sql_lora_files,
                                            max_position_embeddings=4096)
    lora_model = LoRAModel.from_lora_tensors(
        1,
        tensors,
        peft_helper=peft_helper,
        device=device,
        embeddings=new_embeddings,
        embedding_modules=EMBEDDING_MODULES,
        embedding_padding_modules=EMBEDDING_PADDING_MODULES)

    expected_device = torch.device(device)
    for module_name, lora in lora_model.loras.items():
        assert lora.module_name == module_name
        assert lora.rank == 8
        assert lora.lora_alpha == 16
        assert lora.lora_a is not None
        assert lora.lora_b is not None
        assert lora.lora_a.device == expected_device
        assert lora.lora_b.device == expected_device
        # The second dim of A and the first dim of B are both the rank.
        assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
                ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
        assert lora.lora_a.shape[1] == 8
        embeddings_module = next(
            (k for k in EMBEDDING_MODULES if k in module_name), None)
        if embeddings_module:
            assert torch.equal(
                lora.embeddings_tensor,
                new_embeddings[EMBEDDING_MODULES[embeddings_module]].to(
                    device=lora.embeddings_tensor.device))
        else:
            assert lora.embeddings_tensor is None


79
def create_lora(lora_id: int,
                model: nn.Module,
                sub_modules: list[str],
                device: torch.device,
                rank: int = 8,
                lora_alpha: int = 16) -> LoRAModel:
    """Create a LoRAModel with random A/B weights for ``sub_modules``.

    Args:
        lora_id: ID given to the returned LoRAModel.
        model: Model whose submodule weights determine the LoRA shapes.
        sub_modules: Dotted names of the submodules to attach a LoRA to.
        device: Device on which the random weights are allocated.
        rank: LoRA rank (shared inner dimension of A and B). Defaults to 8.
        lora_alpha: Alpha recorded on every layer. Defaults to 16.

    Returns:
        A LoRAModel holding one LoRALayerWeights per entry of ``sub_modules``.
    """
    loras: dict[str, LoRALayerWeights] = {}
    for name in sub_modules:
        w = model.get_submodule(name).weight
        # A is (w.shape[1], rank) and B is (rank, w.shape[0]).
        loras[name] = LoRALayerWeights(
            name,
            rank,
            lora_alpha,
            torch.rand([w.shape[1], rank], device=device),
            torch.rand([rank, w.shape[0]], device=device),
        )
    return LoRAModel(lora_id, rank, loras)


def create_packed_lora(
    lora_id: int,
    model: nn.Module,
    module_name: str,
    replaced_module_names: list[str],
    device: torch.device,
    empty_replaced_module_name=None,
    rank: int = 8,
    lora_alpha: int = 16,
) -> LoRAModel:
    """Create per-sub-module LoRAs for a packed (merged) layer.

    The weight of ``module_name`` is treated as the concatenation of
    ``len(replaced_module_names)`` equally sized output slices. Each replaced
    module gets its own random A/B pair sized for one slice.

    Args:
        lora_id: ID given to the returned LoRAModel.
        model: Model containing the packed module.
        module_name: Dotted name of the packed module (e.g. ``gate_up_proj``).
        replaced_module_names: Names of the sub-modules packed into it.
        device: Device on which the random weights are allocated.
        empty_replaced_module_name: Optional sub-module name to leave without
            a LoRA, so the resulting packed weights have a gap at that slot.
        rank: LoRA rank (shared inner dimension of A and B). Defaults to 8.
        lora_alpha: Alpha recorded on every layer. Defaults to 16.

    Returns:
        A LoRAModel keyed by the (non-skipped) replaced module names.
    """
    w = model.get_submodule(module_name).weight
    loras: dict[str, LoRALayerWeights] = {}
    for replaced_module_name in replaced_module_names:
        if replaced_module_name == empty_replaced_module_name:
            continue
        # The output split is computed inside the loop on purpose, so an
        # empty ``replaced_module_names`` never divides by zero.
        out_per_module = w.shape[0] // len(replaced_module_names)
        loras[replaced_module_name] = LoRALayerWeights(
            replaced_module_name,
            rank,
            lora_alpha,
            torch.rand([w.shape[1], rank], device=device),
            torch.rand([rank, out_per_module], device=device),
        )
    return LoRAModel(lora_id, rank, loras)


def test_replace_submodules(dist_init, dummy_model):
    """LoRAModelManager swaps linear layers for their LoRA-aware wrappers.

    Both top-level and nested (``layer1.*``) layers must be replaced on the
    manager's model.
    """
    manager = LoRAModelManager(
        dummy_model, 1, 1, 1,
        LoRAConfig(max_lora_rank=8,
                   max_cpu_loras=8,
                   max_loras=8,
                   lora_dtype=DEFAULT_DTYPE), torch.device(DEVICES[0]))
    model = manager.model

    expected_types = {
        "dense1": ColumnParallelLinearWithLoRA,
        "layer1.dense1": ColumnParallelLinearWithLoRA,
        "dense2": RowParallelLinearWithLoRA,
        "layer1.dense2": RowParallelLinearWithLoRA,
    }
    for submodule_name, lora_cls in expected_types.items():
        assert isinstance(model.get_submodule(submodule_name),
                          lora_cls), submodule_name


136
@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
    """Exercise add/activate/remove on the plain (non-LRU) LoRAModelManager.

    With ``max_loras=2`` only two adapters fit in slots. Activating a third
    while both slots are occupied raises ``ValueError`` and leaves the slot
    assignment unchanged.
    """
    model = dummy_model
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=3,
                                          max_loras=2,
                                          lora_dtype=DEFAULT_DTYPE),
                               device=device)
    assert all(x is None for x in manager.lora_index_to_id)

    # Fill both slots; re-adding or re-activating an adapter returns False.
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)

    # A third adapter can be registered but not activated while both slots
    # are taken.
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    with pytest.raises(ValueError):
        manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Removing an adapter frees its slot; removing it again returns False.
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)

    # Registering alone does not occupy a slot.
    assert manager.add_adapter(model_lora1)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] is None
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] is None
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    assert manager.device == device
    assert manager.punica_wrapper.device == device
    assert hasattr(manager, "supported_lora_modules")
    assert sorted(manager.supported_lora_modules) == [
        "dense1",
        "dense2",
        "lm_head",
        "output",
    ]
201

202

203
@pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
    """Exercise slot management of LRUCacheLoRAModelManager.

    Unlike the plain manager, activating an adapter when both ``max_loras=2``
    slots are occupied evicts the least recently used unpinned adapter from
    its slot. Once every slot is pinned, further pins or activations raise
    ``RuntimeError``. Pinning an unregistered adapter raises ``ValueError``.
    """
    model = dummy_model
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=3,
                                                  max_loras=2,
                                                  lora_dtype=DEFAULT_DTYPE),
                                       device=device)
    assert all(x is None for x in manager.lora_index_to_id)

    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Both slots are full: activating 3 evicts 1, the older of the two.
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1

    assert manager.add_adapter(model_lora2)
    assert manager.deactivate_adapter(3)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3

    # Pinning 2 protects it, so activating 1 evicts 3 instead.
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    # A pinned adapter can still be deactivated explicitly.
    assert manager.deactivate_adapter(2)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1

    # With every slot pinned, nothing can be evicted to make room.
    assert manager.pin_adapter(3)
    assert manager.pin_adapter(1)
    with pytest.raises(RuntimeError):
        manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    with pytest.raises(RuntimeError):
        manager.activate_adapter(2)

    # Freeing a slot lets pin_adapter place 2 into it.
    assert manager.deactivate_adapter(3)
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    # Pinning an adapter that is no longer registered is rejected.
    assert manager.remove_adapter(3)
    with pytest.raises(ValueError):
        manager.pin_adapter(3)

    assert manager.punica_wrapper.device == device
    assert manager.device == device

293

294
@pytest.mark.parametrize("device", DEVICES)
def test_lru_lora_model_manager(dist_init, dummy_model, device):
    """Check LRU eviction and pinning of registered adapters.

    With ``max_cpu_loras=2`` at most two adapters stay registered, and
    registering more evicts the least recently used one.
    ``remove_oldest_adapter`` skips pinned adapters and raises
    ``RuntimeError`` when only pinned adapters remain.
    """
    # This tests just the LRU cache functionality, everything else is
    # tested in test_lora_model_manager
    model = dummy_model
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora4 = create_lora(4,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=2,
                                                  max_loras=2,
                                                  lora_dtype=DEFAULT_DTYPE),
                                       device=device)
    assert all(x is None for x in manager.lora_index_to_id)

    # Add up to capacity
    assert manager.add_adapter(model_lora1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(1)
    assert manager.activate_adapter(2)
    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Add over capacity: 1 and 2 are evicted.
    assert manager.add_adapter(model_lora3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(3)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Add 3 again to move it to the top, then add 2. Re-adding 3 returns
    # False since it is already registered; adding 2 then evicts 4.
    assert not manager.add_adapter(model_lora3)
    assert not manager.activate_adapter(3)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert set(manager.list_adapters()) == {3, 2}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)
    assert set(manager.list_adapters()) == {2}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 2

    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Drain oldest-first; once empty, remove_oldest_adapter returns False.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    assert not manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    # pinning
    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    # Adapter 1 is not registered at this point, so it cannot be pinned.
    with pytest.raises(ValueError):
        manager.pin_adapter(1)
    assert manager.pin_adapter(3)
    # Remove manually: a pinned adapter can still be removed explicitly.
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)
    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.add_adapter(model_lora1)
    assert manager.pin_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Adapter 1 is pinned, so remove_oldest_adapter removes 2 instead.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {1}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] is None

    # Only a pinned adapter remains, so there is nothing to evict.
    with pytest.raises(RuntimeError):
        manager.remove_oldest_adapter()

    assert set(manager.list_adapters()) == {1}
    assert manager.punica_wrapper.device == device
    assert manager.device == device
420

421

422
@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device,
                                          tmp_path):
    """LRUCacheWorkerLoRAManager keeps up to ``max_cpu_loras=4`` adapters.

    Adapters not in the latest request set stay registered until LRU eviction
    needs their space. A single request set with more distinct adapters than
    ``max_loras=4`` raises ``RuntimeError``.
    """
    lora_config = LoRAConfig(max_lora_rank=8,
                             max_cpu_loras=4,
                             max_loras=4,
                             lora_dtype=DEFAULT_DTYPE)

    dummy_lora_files = f"{tmp_path}/lora_adapter"
    os.makedirs(dummy_lora_files, exist_ok=True)
    create_peft_lora(
        dummy_model,
        save_dir=dummy_lora_files,
        target_modules=["layer1.dense1", "dense2"],
        lora_dtype=DEFAULT_DTYPE,
    )
    worker_adapter_manager = LRUCacheWorkerLoRAManager(
        4, 2,
        dummy_model.unpadded_vocab_size - lora_config.lora_extra_vocab_size,
        lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(dummy_model)

    mapping = LoRAMapping([], [])

    def make_requests(*lora_ids: int) -> list[LoRARequest]:
        # All requests share the same on-disk adapter; only the name/ID
        # differ (e.g. LoRARequest("1", 1, dummy_lora_files)).
        return [
            LoRARequest(str(lora_id), lora_id, dummy_lora_files)
            for lora_id in lora_ids
        ]

    def slot_ids(count: int) -> list:
        # Re-read the underlying manager each call and snapshot its first
        # ``count`` slot entries.
        index_to_id = worker_adapter_manager._adapter_manager.lora_index_to_id
        return [index_to_id[i] for i in range(count)]

    worker_adapter_manager.set_active_adapters(make_requests(1, 2), mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert slot_ids(2) == [1, 2]

    # 2 is not requested but stays registered: the cache is not yet full.
    worker_adapter_manager.set_active_adapters(make_requests(1, 3, 4),
                                               mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
    assert slot_ids(4) == [1, 2, 3, 4]

    # Adding 5 evicts 3, and 5 takes over its slot.
    worker_adapter_manager.set_active_adapters(make_requests(1, 2, 5),
                                               mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert slot_ids(4) == [1, 2, 5, 4]

    # Duplicate requests for an already-active adapter change nothing.
    worker_adapter_manager.set_active_adapters(make_requests(1, 1, 1),
                                               mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert slot_ids(4) == [1, 2, 5, 4]

    worker_adapter_manager.set_active_adapters(make_requests(6, 7, 8),
                                               mapping)
    assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
    assert slot_ids(4) == [1, 7, 8, 6]

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters(
            make_requests(10, 11, 12, 13, 14), mapping)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)
510

511

512
@pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device,
                                tmp_path):
    """WorkerLoRAManager keeps exactly the adapters in the latest request.

    Every adapter not named in a ``set_active_adapters`` call is removed.
    A request set with more distinct adapters than ``max_loras=4`` raises
    ``RuntimeError``.
    """
    # Should remove every LoRA not specified in the request.
    lora_config = LoRAConfig(max_lora_rank=8,
                             max_cpu_loras=4,
                             max_loras=4,
                             lora_dtype=DEFAULT_DTYPE)
    worker_adapter_manager = WorkerLoRAManager(
        4, 2, dummy_model_gate_up.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, device,
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(dummy_model_gate_up)

    dummy_lora_files = f"{tmp_path}/lora_adapter"
    os.makedirs(dummy_lora_files, exist_ok=True)
    create_peft_lora(
        dummy_model_gate_up,
        save_dir=dummy_lora_files,
        target_modules=["layer1.dense1", "dense2"],
        lora_dtype=DEFAULT_DTYPE,
    )

    mapping = LoRAMapping([], [])

    def make_requests(*lora_ids: int) -> list[LoRARequest]:
        # All requests share the same on-disk adapter; only the name/ID
        # differ (e.g. LoRARequest("1", 1, dummy_lora_files)).
        return [
            LoRARequest(str(lora_id), lora_id, dummy_lora_files)
            for lora_id in lora_ids
        ]

    def slot_ids(count: int) -> list:
        # Re-read the underlying manager each call and snapshot its first
        # ``count`` slot entries.
        index_to_id = worker_adapter_manager._adapter_manager.lora_index_to_id
        return [index_to_id[i] for i in range(count)]

    worker_adapter_manager.set_active_adapters(make_requests(1, 2), mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert slot_ids(2) == [1, 2]

    worker_adapter_manager.set_active_adapters(make_requests(1, 3, 4),
                                               mapping)
    assert worker_adapter_manager.list_adapters() == {1, 3, 4}
    assert slot_ids(3) == [1, 3, 4]

    worker_adapter_manager.set_active_adapters(make_requests(1, 2, 5),
                                               mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 5}
    assert slot_ids(3) == [1, 2, 5]

    # Duplicate requests collapse to a single adapter; the others are freed.
    worker_adapter_manager.set_active_adapters(make_requests(1, 1, 1),
                                               mapping)
    assert worker_adapter_manager.list_adapters() == {1}
    assert slot_ids(3) == [1, None, None]

    worker_adapter_manager.set_active_adapters(make_requests(6, 7, 8),
                                               mapping)
    assert worker_adapter_manager.list_adapters() == {6, 7, 8}
    assert slot_ids(3) == [8, 6, 7]

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters(
            make_requests(10, 11, 12, 13, 14), mapping)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)

598

599
@pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
    """Adding an adapter packs gate_proj/up_proj LoRAs into gate_up_proj.

    ``model_lora`` has weights for both sub-modules. ``model_lora1`` omits
    ``gate_proj``, so slot 0 of its packed weights must be ``None``.
    """
    model = dummy_model_gate_up
    model_lora = create_packed_lora(
        1,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device)
    model_lora1 = create_packed_lora(
        2,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device,
        empty_replaced_module_name="gate_proj",
    )

    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=2,
                                          max_loras=2,
                                          lora_dtype=DEFAULT_DTYPE),
                               device=device)
    model = manager.model

    assert isinstance(model.get_submodule("gate_up_proj"),
                      MergedColumnParallelLinearWithLoRA)
    # Verify packed lora is correct. Clone first: after add_adapter the
    # per-sub-module entries are gone and only the packed one remains.
    model_lora_clone = model_lora.clone(1)
    model_lora_clone1 = model_lora1.clone(1)
    assert manager.add_adapter(model_lora)
    assert manager.add_adapter(model_lora1)

    assert model_lora.get_lora("gate_proj") is None
    assert model_lora.get_lora("up_proj") is None
    assert model_lora1.get_lora("up_proj") is None

    packed_lora = model_lora.get_lora("gate_up_proj")
    assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)
    # Slot i of the packed weights holds replaced module i's A/B.
    for slot, sub_module in enumerate(["gate_proj", "up_proj"]):
        expected = model_lora_clone.get_lora(sub_module)
        torch.testing.assert_close(packed_lora.lora_a[slot], expected.lora_a)
        torch.testing.assert_close(packed_lora.lora_b[slot], expected.lora_b)

    packed_lora1 = model_lora1.get_lora("gate_up_proj")
    assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)
    assert packed_lora1.lora_a[0] is None
    assert packed_lora1.lora_b[0] is None
    expected_up1 = model_lora_clone1.get_lora("up_proj")
    torch.testing.assert_close(packed_lora1.lora_a[1], expected_up1.lora_a)
    torch.testing.assert_close(packed_lora1.lora_b[1], expected_up1.lora_b)