test_lora_manager.py 27.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
import os

import pytest
import torch
from safetensors.torch import load_file
from torch import nn

11
from vllm.config import ModelConfig, VllmConfig
12
from vllm.config.lora import LoRAConfig
13
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
14
15
                              MergedColumnParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA)
16
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
17
18
from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager)
19
from vllm.lora.peft_helper import PEFTHelper
20
21
22
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                      WorkerLoRAManager)
23
from vllm.platforms import current_platform
24

25
26
from .utils import create_peft_lora

Terry's avatar
Terry committed
27
28
29
30
31
32
33
# Maps module-name keys to the embedding tensor names stored in an
# adapter's new_embeddings file (see test_from_lora_tensors).
EMBEDDING_MODULES = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}

EMBEDDING_PADDING_MODULES = ["lm_head"]

# Devices used by the device-parametrized tests: up to two CUDA devices on a
# CUDA-like platform, otherwise the CPU. ``min`` keeps the list within the
# devices that actually exist; with zero visible devices the parameter set is
# empty and pytest skips those tests instead of targeting absent devices.
DEVICES = ([f"cuda:{i}" for i in range(min(torch.cuda.device_count(), 2))]
           if current_platform.is_cuda_alike() else ["cpu"])

DEFAULT_DTYPE = torch.get_default_dtype()

40

41
@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
    """Load the SQL test adapter with ``LoRAModel.from_lora_tensors``.

    Every loaded layer must report rank 8 / alpha 16 and hold its A and B
    factors on ``device`` with matching inner dimensions. A layer carries an
    ``embeddings_tensor`` exactly when its module name contains one of the
    ``EMBEDDING_MODULES`` keys. In that case the tensor must equal the
    matching entry from ``new_embeddings.safetensors``.
    """
    tensors = load_file(
        os.path.join(sql_lora_files, "adapter_model.safetensors"))
    new_embeddings = load_file(
        os.path.join(sql_lora_files, "new_embeddings.safetensors"))

    peft_helper = PEFTHelper.from_local_dir(sql_lora_files,
                                            max_position_embeddings=4096)
    lora_model = LoRAModel.from_lora_tensors(
        1,
        tensors,
        peft_helper=peft_helper,
        device=device,
        embeddings=new_embeddings,
        embedding_modules=EMBEDDING_MODULES,
        embedding_padding_modules=EMBEDDING_PADDING_MODULES)

    # Without this check the loop below would pass vacuously if loading
    # produced no layers at all.
    assert lora_model.loras, "no LoRA layers were loaded from the adapter"

    expected_device = torch.device(device)
    for module_name, lora in lora_model.loras.items():
        assert lora.module_name == module_name
        assert lora.rank == 8
        assert lora.lora_alpha == 16
        assert lora.lora_a is not None
        assert lora.lora_b is not None
        assert lora.lora_a.device == expected_device
        assert lora.lora_b.device == expected_device
        # The shared (rank) dimension of A and B must line up.
        assert lora.lora_a.shape[1] == lora.lora_b.shape[0], (
            f"{lora.lora_a.shape=}, {lora.lora_b.shape=}")
        assert lora.lora_a.shape[1] == 8
        embeddings_module = next(
            (k for k in EMBEDDING_MODULES if k in module_name), None)
        if embeddings_module:
            assert torch.equal(
                lora.embeddings_tensor,
                new_embeddings[EMBEDDING_MODULES[embeddings_module]].to(
                    device=lora.embeddings_tensor.device))
        else:
            assert lora.embeddings_tensor is None


80
def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str],
                device: torch.device) -> LoRAModel:
    """Build a LoRAModel with random rank-8, alpha-16 weights.

    For each name in ``sub_modules`` the submodule's ``weight`` is looked up.
    Factors are drawn with ``torch.rand`` on ``device``: ``lora_a`` has shape
    ``(weight.shape[1], 8)`` and ``lora_b`` has shape ``(8, weight.shape[0])``.
    """
    rank, alpha = 8, 16

    def make_layer(module_name: str) -> LoRALayerWeights:
        weight_shape = model.get_submodule(module_name).weight.shape
        # A is drawn before B so the random stream is consumed in a fixed
        # order for every module.
        lora_a = torch.rand([weight_shape[1], rank], device=device)
        lora_b = torch.rand([rank, weight_shape[0]], device=device)
        return LoRALayerWeights(module_name, rank, alpha, lora_a, lora_b)

    layers: dict[str, LoRALayerWeights] = {
        module_name: make_layer(module_name)
        for module_name in sub_modules
    }
    return LoRAModel(lora_id, rank, layers)


def create_packed_lora(
    lora_id: int,
    model: nn.Module,
    module_name,
    replaced_module_names,
    device: torch.device,
    empty_replaced_module_name=None,
) -> LoRAModel:
    """Build per-slice random LoRA weights for a packed module.

    ``module_name`` names the packed layer (e.g. ``gate_up_proj`` in the
    tests). One rank-8, alpha-16 entry is created per name in
    ``replaced_module_names``. Each entry's ``lora_b`` covers an equal share
    (floor division) of the packed weight's ``shape[0]``. The name equal to
    ``empty_replaced_module_name``, if any, gets no entry. This models an
    adapter that targets only some slices.
    """
    rank, alpha = 8, 16
    packed_shape = model.get_submodule(module_name).weight.shape
    layers: dict[str, LoRALayerWeights] = {}
    for slice_name in replaced_module_names:
        if slice_name == empty_replaced_module_name:
            # Deliberately leave this slice without LoRA weights.
            continue
        slice_out_dim = packed_shape[0] // len(replaced_module_names)
        lora_a = torch.rand([packed_shape[1], rank], device=device)
        lora_b = torch.rand([rank, slice_out_dim], device=device)
        layers[slice_name] = LoRALayerWeights(slice_name, rank, alpha,
                                              lora_a, lora_b)
    return LoRAModel(lora_id, rank, layers)


def test_replace_submodules(dist_init, dummy_model):
    """Constructing a LoRAModelManager wraps the dense layers with LoRA."""
    lora_config = LoRAConfig(max_lora_rank=8,
                             max_cpu_loras=8,
                             max_loras=8,
                             lora_dtype=DEFAULT_DTYPE)
    manager = LoRAModelManager(dummy_model, 1, 1, 1, lora_config,
                               torch.device(DEVICES[0]))
    wrapped_model = manager.model

    # Expected wrapper class per submodule, checked in this order.
    expected_wrappers = {
        "dense1": ColumnParallelLinearWithLoRA,
        "layer1.dense1": ColumnParallelLinearWithLoRA,
        "dense2": RowParallelLinearWithLoRA,
        "layer1.dense2": RowParallelLinearWithLoRA,
    }
    for submodule_name, wrapper_cls in expected_wrappers.items():
        assert isinstance(wrapped_model.get_submodule(submodule_name),
                          wrapper_cls)


137
@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
    """Exercise add/activate/remove on a plain (non-LRU) LoRAModelManager.

    The config uses ``max_loras=2`` and ``max_cpu_loras=3``. The test checks:
    - ``lora_index_to_id`` reports which adapter sits at indices 0 and 1.
    - Re-adding or re-activating a present adapter returns False.
    - Activating while both indices are occupied raises ``ValueError``.
    - The manager reports the expected device and supported modules.
    """
    model = dummy_model
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=3,
                                          max_loras=2,
                                          lora_dtype=DEFAULT_DTYPE),
                               device=device)
    # Nothing is active before the first activation.
    assert all(x is None for x in manager.lora_index_to_id)
    # Adapter 1 goes into the first free index.
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    # Adding/activating an adapter that is already present returns False.
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    # A third adapter can be registered. Both indices are taken, so
    # activating it raises, and the assignment stays unchanged (no eviction
    # in the non-LRU manager).
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    with pytest.raises(ValueError):
        assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    # Removing an adapter frees its index; removing it again returns False.
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    # Re-adding alone does not activate: both indices remain empty.
    assert manager.add_adapter(model_lora1)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] is None
    assert manager.add_adapter(model_lora2)
    # Adapter 3 (registered earlier, never removed) now fits at index 0.
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] is None
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    # Both the manager and its punica wrapper report the requested device.
    assert manager.device == device
    assert manager.punica_wrapper.device == device
    assert hasattr(manager, "supported_lora_modules")
    assert sorted(manager.supported_lora_modules) == [
        "dense1",
        "dense2",
        "lm_head",
        "output",
    ]
202

203

204
@pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
    """Exercise LRU eviction and pinning of active LoRA indices.

    This uses the same two-index configuration as test_lora_model_manager.
    Here, activating into a full set of indices evicts the earliest-activated
    unpinned adapter instead of raising. Pinned adapters are never evicted.
    """
    model = dummy_model
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=3,
                                                  max_loras=2,
                                                  lora_dtype=DEFAULT_DTYPE),
                                       device=device)
    assert all(x is None for x in manager.lora_index_to_id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    # Both indices are full: activating 3 evicts adapter 1 (activated
    # before 2) instead of raising as the non-LRU manager does.
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    assert manager.add_adapter(model_lora2)
    # Deactivation frees the index but keeps the adapter registered.
    assert manager.deactivate_adapter(3)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    # Full again: 1 is older than 2, so 1 is evicted for 3.
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    # Pinning an already active adapter keeps it where it is.
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    # 2 is pinned, so activating 1 can only evict 3.
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    # A pinned adapter can still be deactivated explicitly.
    assert manager.deactivate_adapter(2)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    # With every active index pinned, neither pinning nor activating another
    # adapter can find room; both raise RuntimeError and nothing moves.
    assert manager.pin_adapter(3)
    assert manager.pin_adapter(1)
    with pytest.raises(RuntimeError):
        assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    with pytest.raises(RuntimeError):
        assert manager.activate_adapter(2)

    # Freeing index 0 lets pin_adapter(2) both activate and pin 2 there.
    assert manager.deactivate_adapter(3)
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    # Pinning an adapter that was removed raises ValueError.
    assert manager.remove_adapter(3)
    with pytest.raises(ValueError):
        assert manager.pin_adapter(3)

    assert manager.punica_wrapper.device == device
    assert manager.device == device

294

295
@pytest.mark.parametrize("device", DEVICES)
def test_lru_lora_model_manager(dist_init, dummy_model, device):
    """Exercise the registered-adapter LRU cache (``max_cpu_loras=2``).

    Registering more than two adapters evicts the least recently used one.
    ``remove_oldest_adapter`` drops the oldest unpinned adapter. Pinned
    adapters survive eviction.
    """
    # This tests just the LRU cache functionality, everything else is
    # tested in test_lora_model_manager
    model = dummy_model
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora4 = create_lora(4,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=2,
                                                  max_loras=2,
                                                  lora_dtype=DEFAULT_DTYPE),
                                       device=device)

    assert all(x is None for x in manager.lora_index_to_id)

    # Add up to capacity
    assert manager.add_adapter(model_lora1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(1)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Add over capacity: registering 3 and 4 evicts 1 and 2 entirely.
    assert manager.add_adapter(model_lora3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(3)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Re-adding and re-activating 3 return False because it is already
    # present, but they mark it as most recently used. Adding 2 therefore
    # evicts 4 rather than 3.
    assert not manager.add_adapter(model_lora3)
    assert not manager.activate_adapter(3)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {3, 2}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {2}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 2

    # Adding 3 fills the free index; adding 4 then evicts 2.
    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # remove_oldest_adapter drops 3 first, then 4.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    # On an empty cache it returns False and changes nothing.
    assert not manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    # pinning
    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    # Adapter 1 is not registered, so it cannot be pinned.
    with pytest.raises(ValueError):
        assert manager.pin_adapter(1)
    assert manager.pin_adapter(3)
    # Remove manually (explicit removal works even for a pinned adapter)
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    # pin_adapter(1) also activates 1 (it appears at index 0 without an
    # explicit activate_adapter call).
    assert manager.add_adapter(model_lora1)
    assert manager.pin_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # The oldest *unpinned* adapter (2) is removed, not the pinned 1.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {1}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] is None

    # Only a pinned adapter remains, so there is nothing evictable.
    with pytest.raises(RuntimeError):
        assert manager.remove_oldest_adapter()

    assert set(manager.list_adapters()) == {1}
    assert manager.punica_wrapper.device == device
    assert manager.device == device
421

422

423
@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device,
                                          tmp_path):
    """Drive LRUCacheWorkerLoRAManager through successive request sets.

    Adapters from earlier requests stay cached (up to four) and are evicted
    least-recently-used first. Requesting five adapters at once raises
    ``RuntimeError``.
    """
    lora_config = LoRAConfig(max_lora_rank=8,
                             max_cpu_loras=4,
                             max_loras=4,
                             lora_dtype=DEFAULT_DTYPE)

    adapter_dir = f"{tmp_path}/lora_adapter"
    os.makedirs(adapter_dir, exist_ok=True)
    create_peft_lora(
        dummy_model,
        save_dir=adapter_dir,
        target_modules=["layer1.dense1", "dense2"],
        lora_dtype=DEFAULT_DTYPE,
    )

    vllm_config = VllmConfig(model_config=ModelConfig(max_model_len=16),
                             lora_config=lora_config)
    vllm_config.scheduler_config.max_num_seqs = 4
    vllm_config.scheduler_config.max_num_batched_tokens = 2

    manager = LRUCacheWorkerLoRAManager(vllm_config, device,
                                        EMBEDDING_MODULES,
                                        EMBEDDING_PADDING_MODULES)
    manager.max_num_seqs = 4
    manager.max_num_batched_tokens = 2
    manager.create_lora_manager(dummy_model)

    mapping = LoRAMapping([], [])

    def request(*lora_ids):
        """Activate one LoRARequest per id, all backed by ``adapter_dir``."""
        manager.set_active_adapters(
            [LoRARequest(str(i), i, adapter_dir) for i in lora_ids], mapping)

    def check(expected_ids, expected_slots):
        """Assert the registered ids and the leading slot assignments."""
        assert manager.list_adapters() == expected_ids
        slots = manager._adapter_manager.lora_index_to_id
        for index, lora_id in enumerate(expected_slots):
            if lora_id is None:
                assert slots[index] is None
            else:
                assert slots[index] == lora_id

    request(1, 2)
    check({1, 2}, [1, 2])

    request(1, 3, 4)
    check({1, 2, 3, 4}, [1, 2, 3, 4])

    # 3 is the least recently used adapter, so 5 takes its slot.
    request(1, 2, 5)
    check({1, 2, 4, 5}, [1, 2, 5, 4])

    # Duplicate requests for an active adapter change nothing.
    request(1, 1, 1)
    check({1, 2, 4, 5}, [1, 2, 5, 4])

    request(6, 7, 8)
    check({1, 6, 7, 8}, [1, 7, 8, 6])

    # Over capacity
    with pytest.raises(RuntimeError):
        request(10, 11, 12, 13, 14)

    assert manager.device == device
    assert manager._adapter_manager.punica_wrapper.device == device
520

521

522
@pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device,
                                tmp_path):
    """Drive the non-LRU WorkerLoRAManager through successive request sets.

    Should remove every LoRA not specified in the request: after each
    ``set_active_adapters`` call only the requested ids remain. Requesting
    five adapters at once raises ``RuntimeError``.
    """
    lora_config = LoRAConfig(max_lora_rank=8,
                             max_cpu_loras=4,
                             max_loras=4,
                             lora_dtype=DEFAULT_DTYPE)
    vllm_config = VllmConfig(model_config=ModelConfig(max_model_len=16),
                             lora_config=lora_config)
    vllm_config.scheduler_config.max_num_seqs = 4
    vllm_config.scheduler_config.max_num_batched_tokens = 2

    manager = WorkerLoRAManager(vllm_config, device, EMBEDDING_MODULES,
                                EMBEDDING_PADDING_MODULES)
    manager.vocab_size = (dummy_model_gate_up.unpadded_vocab_size -
                          lora_config.lora_extra_vocab_size)
    manager.create_lora_manager(dummy_model_gate_up)

    adapter_dir = f"{tmp_path}/lora_adapter"
    os.makedirs(adapter_dir, exist_ok=True)
    create_peft_lora(
        dummy_model_gate_up,
        save_dir=adapter_dir,
        target_modules=["layer1.dense1", "dense2"],
        lora_dtype=DEFAULT_DTYPE,
    )

    mapping = LoRAMapping([], [])

    def request(*lora_ids):
        """Activate one LoRARequest per id, all backed by ``adapter_dir``."""
        manager.set_active_adapters(
            [LoRARequest(str(i), i, adapter_dir) for i in lora_ids], mapping)

    def check(expected_ids, expected_slots):
        """Assert the registered ids and the leading slot assignments."""
        assert manager.list_adapters() == expected_ids
        slots = manager._adapter_manager.lora_index_to_id
        for index, lora_id in enumerate(expected_slots):
            if lora_id is None:
                assert slots[index] is None
            else:
                assert slots[index] == lora_id

    request(1, 2)
    check({1, 2}, [1, 2])

    request(1, 3, 4)
    check({1, 3, 4}, [1, 3, 4])

    request(1, 2, 5)
    check({1, 2, 5}, [1, 2, 5])

    # Duplicate ids collapse to a single active adapter.
    request(1, 1, 1)
    check({1}, [1, None, None])

    request(6, 7, 8)
    check({6, 7, 8}, [8, 6, 7])

    # Over capacity
    with pytest.raises(RuntimeError):
        request(10, 11, 12, 13, 14)

    assert manager.device == device
    assert manager._adapter_manager.punica_wrapper.device == device

618

619
@pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
    """Per-slice LoRAs are packed into ``gate_up_proj`` when added.

    ``model_lora`` has weights for both ``gate_proj`` and ``up_proj``.
    ``model_lora1`` has weights only for ``up_proj``. After ``add_adapter``:
    - The per-slice entries are gone from each adapter.
    - A single PackedLoRALayerWeights under ``gate_up_proj`` holds them.
    - A slice without weights is ``None`` in the packed entry.
    Clones taken before adding serve as the reference values.
    """
    model = dummy_model_gate_up
    model_lora = create_packed_lora(
        1,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device)
    model_lora1 = create_packed_lora(
        2,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device,
        empty_replaced_module_name="gate_proj",
    )

    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=2,
                                          max_loras=2,
                                          lora_dtype=DEFAULT_DTYPE),
                               device=device)
    model = manager.model

    assert isinstance(model.get_submodule("gate_up_proj"),
                      MergedColumnParallelLinearWithLoRA)

    # Snapshot the unpacked weights before add_adapter repacks the adapters
    # in place. Each clone keeps the id of the adapter it was taken from.
    model_lora_clone = model_lora.clone(model_lora.id)
    model_lora_clone1 = model_lora1.clone(model_lora1.id)

    assert manager.add_adapter(model_lora)
    assert manager.add_adapter(model_lora1)

    # Per-slice entries no longer exist on either adapter.
    assert model_lora.get_lora("gate_proj") is None
    assert model_lora.get_lora("up_proj") is None
    assert model_lora1.get_lora("gate_proj") is None
    assert model_lora1.get_lora("up_proj") is None

    # Packed slice order follows replaced_module_names: 0 = gate, 1 = up.
    packed_lora = model_lora.get_lora("gate_up_proj")
    assert isinstance(packed_lora, PackedLoRALayerWeights)

    torch.testing.assert_close(packed_lora.lora_a[0],
                               model_lora_clone.get_lora("gate_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[0],
                               model_lora_clone.get_lora("gate_proj").lora_b)
    torch.testing.assert_close(packed_lora.lora_a[1],
                               model_lora_clone.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[1],
                               model_lora_clone.get_lora("up_proj").lora_b)

    # The slice with no weights is represented as None in the packed entry.
    packed_lora1 = model_lora1.get_lora("gate_up_proj")
    assert isinstance(packed_lora1, PackedLoRALayerWeights)

    assert packed_lora1.lora_a[0] is None
    assert packed_lora1.lora_b[0] is None
    torch.testing.assert_close(packed_lora1.lora_a[1],
                               model_lora_clone1.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora1.lora_b[1],
                               model_lora_clone1.get_lora("up_proj").lora_b)