# SPDX-License-Identifier: Apache-2.0

import os

import pytest
import torch
from safetensors.torch import load_file
from torch import nn

from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager)
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                      WorkerLoRAManager)
from vllm.platforms import current_platform

# Maps each embedding module name to the key of its extra-token embedding
# tensor in an adapter's ``new_embeddings.safetensors`` file (see
# test_from_lora_tensors).
EMBEDDING_MODULES = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}

# Passed to the LoRA loaders/managers as ``embedding_padding_modules``;
# presumably the modules whose vocab dimension gets padded -- TODO confirm.
EMBEDDING_PADDING_MODULES = ["lm_head"]

# Devices for the ``device``-parametrized tests: only ``cuda:0`` when exactly
# one GPU is visible, ``cuda:0`` and ``cuda:1`` otherwise, and ``cpu`` on
# platforms that are not CUDA-like.
# NOTE(review): a CUDA-like platform reporting zero devices still yields two
# CUDA entries here.
DEVICES = ([
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] if current_platform.is_cuda_alike() else ["cpu"])

@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch: pytest.MonkeyPatch):
    """Run every test in this module with ``VLLM_USE_V1=0``.

    Some tests depend on V0 internals. Since both V0 and V1 use the same
    LoRAModelManager it is okay to just test V0. The variable is set inside
    ``monkeypatch.context()``, so it is undone when the test finishes.
    """
    with monkeypatch.context() as m:
        m.setenv('VLLM_USE_V1', '0')
        yield


@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
    """Load the SQL adapter via ``LoRAModel.from_lora_tensors`` and check it.

    Every resulting LoRA must be rank 8 / alpha 16 with its A and B matrices
    on ``device``. Only modules matching an ``EMBEDDING_MODULES`` key may
    carry an embeddings tensor, and it must equal the one stored in
    ``new_embeddings.safetensors``.
    """
    tensors = load_file(
        os.path.join(sql_lora_files, "adapter_model.safetensors"))
    new_embeddings = load_file(
        os.path.join(sql_lora_files, "new_embeddings.safetensors"))

    peft_helper = PEFTHelper.from_local_dir(sql_lora_files,
                                            max_position_embeddings=4096)
    lora_model = LoRAModel.from_lora_tensors(
        1,
        tensors,
        peft_helper=peft_helper,
        device=device,
        embeddings=new_embeddings,
        embedding_modules=EMBEDDING_MODULES,
        embedding_padding_modules=EMBEDDING_PADDING_MODULES)

    # Guard against the per-layer checks below passing vacuously.
    assert lora_model.loras, "adapter produced no LoRA layers"

    for module_name, lora in lora_model.loras.items():
        assert lora.module_name == module_name
        assert lora.rank == 8
        assert lora.lora_alpha == 16
        assert lora.lora_a is not None
        assert lora.lora_b is not None
        assert lora.lora_a.device == torch.device(device)
        assert lora.lora_b.device == torch.device(device)
        assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
                ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
        assert lora.lora_a.shape[1] == 8
        embeddings_module = next(
            (k for k in EMBEDDING_MODULES if k in module_name), None)
        if embeddings_module:
            assert torch.equal(
                lora.embeddings_tensor,
                new_embeddings[EMBEDDING_MODULES[embeddings_module]].to(
                    device=lora.embeddings_tensor.device))
        else:
            assert lora.embeddings_tensor is None


def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str],
                device: torch.device) -> LoRAModel:
    """Build a LoRAModel with random weights for the given submodules.

    For each name in ``sub_modules``, a LoRALayerWeights is created from the
    submodule's ``weight`` ``w``. Its A matrix is ``[w.shape[1], 8]`` and its
    B matrix is ``[8, w.shape[0]]``, both filled with ``torch.rand`` on
    ``device``. The literals 8 and 16 are presumably rank and lora_alpha;
    this assumes LoRALayerWeights' positional order is (name, rank, alpha,
    A, B), so verify against vllm.lora.lora.

    Args:
        lora_id: id given to the returned LoRAModel.
        model: model whose submodules provide the weight shapes.
        sub_modules: dotted submodule names to create LoRAs for.
        device: device on which the random A/B matrices are allocated.
    """
    loras: dict[str, LoRALayerWeights] = {}
    for name in sub_modules:
        w = model.get_submodule(name).weight
        loras[name] = LoRALayerWeights(
            name,
            8,
            16,
            torch.rand([w.shape[1], 8], device=device),
            torch.rand([8, w.shape[0]], device=device),
        )
    return LoRAModel(lora_id, 8, loras)


def create_packed_lora(
    lora_id: int,
    model: nn.Module,
    module_name: str,
    replaced_module_names: list[str],
    device: torch.device,
    empty_replaced_module_name=None,
) -> LoRAModel:
    """Build a LoRAModel with one random LoRA per slice of a packed module.

    ``module_name`` names a single module in ``model`` whose ``weight`` ``w``
    provides the shapes. One LoRALayerWeights is created for each entry of
    ``replaced_module_names``. Each has A of shape ``[w.shape[1], 8]`` and B
    of shape ``[8, w.shape[0] // len(replaced_module_names)]``, so the
    slices split the packed weight's first dimension evenly.

    An entry equal to ``empty_replaced_module_name`` (if given) is skipped,
    so the returned model has no LoRA for that slice.
    """
    w = model.get_submodule(module_name).weight
    loras: dict[str, LoRALayerWeights] = {}
    for replaced_module_name in replaced_module_names:
        if replaced_module_name == empty_replaced_module_name:
            continue
        loras[replaced_module_name] = LoRALayerWeights(
            replaced_module_name,
            8,
            16,
            torch.rand([w.shape[1], 8], device=device),
            torch.rand([8, w.shape[0] // len(replaced_module_names)],
                       device=device),
        )
    return LoRAModel(lora_id, 8, loras)


def test_replace_submodules(dist_init, dummy_model):
    """The manager's model exposes its dense layers as LoRA-wrapped layers.

    ``dense1`` layers must become ColumnParallelLinearWithLoRA and ``dense2``
    layers RowParallelLinearWithLoRA, both at the top level and under
    ``layer1``.
    """
    model = dummy_model
    manager = LoRAModelManager(
        model, 1, 1, 1,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
        torch.device(DEVICES[0]))
    model = manager.model
    assert isinstance(model.get_submodule("dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("dense2"), RowParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense2"),
                      RowParallelLinearWithLoRA)


@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
    """Exercise add/activate/remove on a plain LoRAModelManager.

    The manager is built with ``max_loras=2`` and ``max_cpu_loras=3``.
    ``lora_index_to_id`` is checked after each step to see which adapter id
    occupies which slot.
    """
    model = dummy_model
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=3,
                                          max_loras=2),
                               device=device)
    assert all(x is None for x in manager.lora_index_to_id)

    # Fill both slots; repeating add/activate for the same id returns False.
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)

    # A third adapter can be added, but activating it while both slots are
    # taken raises ValueError and leaves the slots unchanged.
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    with pytest.raises(ValueError):
        manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Removing frees the slot; removing an absent adapter returns False.
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)

    # Adding again does not activate; activation fills the slots in order.
    assert manager.add_adapter(model_lora1)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] is None
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] is None
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    assert manager.device == device
    assert manager.punica_wrapper.device == device
    assert hasattr(manager, "supported_lora_modules")
    assert sorted(manager.supported_lora_modules) == [
        "dense1",
        "dense2",
        "lm_head",
        "output",
    ]

@pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
    """Exercise slot replacement and pinning in LRUCacheLoRAModelManager.

    Unlike the plain manager, activating an adapter when both slots
    (``max_loras=2``) are full replaces an existing one. Pinned adapters must
    not be displaced, and pinning or activating beyond capacity raises.
    """
    model = dummy_model
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=3,
                                                  max_loras=2),
                                       device=device)
    assert all(x is None for x in manager.lora_index_to_id)

    # Fill both slots; repeating add/activate for the same id returns False.
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)

    # Activating a third adapter replaces adapter 1 in slot 0.
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    # Removal frees the slot; removing an absent adapter returns False.
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1

    # Deactivation frees a slot without removing the adapter.
    assert manager.add_adapter(model_lora2)
    assert manager.deactivate_adapter(3)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3

    # A pinned adapter (2) keeps its slot; only the other slot is replaced.
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.deactivate_adapter(2)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1

    # With both slots pinned, pinning or activating another adapter raises.
    assert manager.pin_adapter(3)
    assert manager.pin_adapter(1)
    with pytest.raises(RuntimeError):
        manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    with pytest.raises(RuntimeError):
        manager.activate_adapter(2)

    assert manager.deactivate_adapter(3)
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1

    # Pinning an adapter that has been removed raises ValueError.
    assert manager.remove_adapter(3)
    with pytest.raises(ValueError):
        manager.pin_adapter(3)

    assert manager.punica_wrapper.device == device
    assert manager.device == device

@pytest.mark.parametrize("device", DEVICES)
def test_lru_lora_model_manager(dist_init, dummy_model, device):
    """Exercise LRU eviction with ``max_cpu_loras=2`` and ``max_loras=2``."""
    # This tests just the LRU cache functionality, everything else is
    # tested in test_lora_model_manager
    model = dummy_model
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora4 = create_lora(4,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=2,
                                                  max_loras=2),
                                       device=device)

    assert all(x is None for x in manager.lora_index_to_id)

    # Add up to capacity
    assert manager.add_adapter(model_lora1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(1)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Add over capacity: 1 and 2 are evicted for 3 and 4.
    assert manager.add_adapter(model_lora3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(3)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Re-adding/activating 3 returns False since it is already present, but
    # it refreshes 3's LRU position. Adding 2 therefore evicts 4, not 3.
    assert not manager.add_adapter(model_lora3)
    assert not manager.activate_adapter(3)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {3, 2}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {2}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 2

    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # remove_oldest_adapter evicts in LRU order and returns False when empty.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    assert not manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    # pinning
    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    with pytest.raises(ValueError):
        manager.pin_adapter(1)
    assert manager.pin_adapter(3)
    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.add_adapter(model_lora1)
    assert manager.pin_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # The pinned adapter 1 is skipped by remove_oldest_adapter.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {1}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] is None

    with pytest.raises(RuntimeError):
        manager.remove_oldest_adapter()

    assert set(manager.list_adapters()) == {1}

    assert manager.punica_wrapper.device == device
    assert manager.device == device

@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                          sql_lora_files, device):
    """LRUCacheWorkerLoRAManager keeps up to four adapters, evicting by LRU.

    Adapters not in the current request stay loaded while there is room.
    Requesting more adapters than ``max_loras`` in one call raises.
    """
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = LRUCacheWorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, device,
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    # There is room for four, so 2 stays loaded even though not requested.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # 3 is the least recently used adapter and is evicted for 5.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Requesting an already-loaded adapter (even repeatedly) evicts nothing.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # 1 was requested most recently, so 2, 4 and 5 are evicted for 6, 7, 8.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 6

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)

@pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                sql_lora_files, device):
    """WorkerLoRAManager keeps exactly the adapters of the latest request.

    Every adapter not named in the current ``set_active_adapters`` call is
    removed. Requesting more adapters than ``max_loras`` raises.
    """
    # Should remove every LoRA not specified in the request.
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = WorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, device,
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4

    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5

    # Duplicate requests collapse to one adapter; the other slots are freed.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None

    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 7

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)

@pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
    """Per-slice LoRAs are packed into the merged ``gate_up_proj`` module.

    After ``add_adapter``, the ``gate_proj``/``up_proj`` entries of each
    LoRAModel are gone. A PackedLoRALayerWeights stored under
    ``gate_up_proj`` holds their A/B matrices in slice order, with ``None``
    for a slice that had no LoRA.
    """
    model = dummy_model_gate_up
    model_lora = create_packed_lora(
        1,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device)
    # Same packed module, but with no LoRA for the gate_proj slice.
    model_lora1 = create_packed_lora(
        2,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device,
        empty_replaced_module_name="gate_proj",
    )

    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=2,
                                          max_loras=2),
                               device=device)
    model = manager.model

    assert isinstance(model.get_submodule("gate_up_proj"),
                      MergedColumnParallelLinearWithLoRA)
    # Verify packed lora is correct. Clone first: add_adapter rewrites the
    # LoRAModel's entries in place, so the clones keep the unpacked weights.
    model_lora_clone = model_lora.clone(1)
    model_lora_clone1 = model_lora1.clone(1)
    assert manager.add_adapter(model_lora)
    assert manager.add_adapter(model_lora1)

    assert model_lora.get_lora("gate_proj") is None
    assert model_lora.get_lora("up_proj") is None
    assert model_lora1.get_lora("up_proj") is None
    packed_lora = model_lora.get_lora("gate_up_proj")
    assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)

    torch.testing.assert_close(packed_lora.lora_a[0],
                               model_lora_clone.get_lora("gate_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[0],
                               model_lora_clone.get_lora("gate_proj").lora_b)
    torch.testing.assert_close(packed_lora.lora_a[1],
                               model_lora_clone.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[1],
                               model_lora_clone.get_lora("up_proj").lora_b)

    packed_lora1 = model_lora1.get_lora("gate_up_proj")
    assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)

    # The missing gate_proj slice is represented by None.
    assert packed_lora1.lora_a[0] is None
    assert packed_lora1.lora_b[0] is None
    torch.testing.assert_close(packed_lora1.lora_a[1],
                               model_lora_clone1.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora1.lora_b[1],
                               model_lora_clone1.get_lora("up_proj").lora_b)