# SPDX-License-Identifier: Apache-2.0

import os

import pytest
import torch
from safetensors.torch import load_file
from torch import nn

from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager)
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                      WorkerLoRAManager)
from vllm.platforms import current_platform

# Embedding module name -> key of its extra-vocabulary tensor inside
# new_embeddings.safetensors (see test_from_lora_tensors).
EMBEDDING_MODULES = dict(
    embed_tokens="input_embeddings",
    lm_head="output_embeddings",
)

# Passed to the LoRA loaders/managers as ``embedding_padding_modules``.
EMBEDDING_PADDING_MODULES = ["lm_head"]

# Devices every parametrized test runs on: on CUDA-alike platforms use one
# CUDA device when exactly one is visible and two otherwise; else the CPU.
if current_platform.is_cuda_alike():
    DEVICES = [
        f"cuda:{idx}"
        for idx in range(1 if torch.cuda.device_count() == 1 else 2)
    ]
else:
    DEVICES = ["cpu"]

# Captured once at import time; used as ``lora_dtype`` in every LoRAConfig.
DEFAULT_DTYPE = torch.get_default_dtype()


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch: pytest.MonkeyPatch):
    """Run every test in this module with ``VLLM_USE_V1`` set to ``0``.

    Some of these tests reach into V0 internals. V0 and V1 share the same
    LoRAModelManager, so covering it through V0 alone is sufficient.
    """
    with monkeypatch.context() as patcher:
        patcher.setenv('VLLM_USE_V1', '0')
        yield


@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
    """Load the SQL adapter's tensors into a LoRAModel placed on ``device``.

    Every layer must have rank 8 and alpha 16, with both A and B on the
    requested device. Layers whose name contains an EMBEDDING_MODULES key
    must carry the matching tensor from new_embeddings.safetensors. All
    other layers must carry no embeddings tensor.
    """
    adapter_path = os.path.join(sql_lora_files, "adapter_model.safetensors")
    embeddings_path = os.path.join(sql_lora_files,
                                   "new_embeddings.safetensors")
    tensors = load_file(adapter_path)
    new_embeddings = load_file(embeddings_path)

    peft_helper = PEFTHelper.from_local_dir(sql_lora_files,
                                            max_position_embeddings=4096)
    lora_model = LoRAModel.from_lora_tensors(
        1,
        tensors,
        peft_helper=peft_helper,
        device=device,
        embeddings=new_embeddings,
        embedding_modules=EMBEDDING_MODULES,
        embedding_padding_modules=EMBEDDING_PADDING_MODULES)

    target_device = torch.device(device)
    # NOTE: the loop variable must stay ``lora``; the f-string below uses the
    # ``=`` specifier, which embeds the expression text in the message.
    for module_name, lora in lora_model.loras.items():
        assert lora.module_name == module_name
        assert lora.rank == 8
        assert lora.lora_alpha == 16
        assert lora.lora_a is not None
        assert lora.lora_b is not None
        assert lora.lora_a.device == target_device
        assert lora.lora_b.device == target_device
        assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
                ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
        assert lora.lora_a.shape[1] == 8

        # First embedding-module key that occurs in this layer's name.
        matches = [key for key in EMBEDDING_MODULES if key in module_name]
        if not matches:
            assert lora.embeddings_tensor is None
            continue
        expected = new_embeddings[EMBEDDING_MODULES[matches[0]]]
        assert torch.equal(lora.embeddings_tensor,
                           expected.to(device=lora.embeddings_tensor.device))


def create_lora(lora_id: int,
                model: nn.Module,
                sub_modules: list[str],
                device: torch.device,
                rank: int = 8,
                lora_alpha: int = 16) -> LoRAModel:
    """Build a LoRAModel with random weights for ``sub_modules`` of ``model``.

    For every listed submodule, the A matrix has shape
    ``(weight.shape[1], rank)`` and the B matrix has shape
    ``(rank, weight.shape[0])``. Both are taken from that submodule's
    ``weight``.

    Args:
        lora_id: id given to the returned LoRAModel.
        model: model whose submodules are targeted.
        sub_modules: dotted submodule names to attach LoRA weights to.
        device: device the random weights are allocated on. The tests in
            this file pass device strings such as ``"cuda:0"``, which
            ``torch.rand`` also accepts.
        rank: LoRA rank (default 8, the value the managers in this file are
            configured with via ``max_lora_rank``).
        lora_alpha: LoRA alpha (default 16).

    Returns:
        A LoRAModel of the given ``rank`` holding one LoRALayerWeights per
        submodule name.
    """
    loras: dict[str, LoRALayerWeights] = {}
    for name in sub_modules:
        weight = model.get_submodule(name).weight
        loras[name] = LoRALayerWeights(
            name,
            rank,
            lora_alpha,
            torch.rand([weight.shape[1], rank], device=device),
            torch.rand([rank, weight.shape[0]], device=device),
        )
    return LoRAModel(lora_id, rank, loras)


def create_packed_lora(
    lora_id: int,
    model: nn.Module,
    module_name: str,
    replaced_module_names: list[str],
    device: torch.device,
    empty_replaced_module_name=None,
    rank: int = 8,
    lora_alpha: int = 16,
) -> LoRAModel:
    """Build a LoRAModel targeting the sub-projections of ``module_name``.

    One LoRALayerWeights entry is created per name in
    ``replaced_module_names``. The A matrix has shape
    ``(weight.shape[1], rank)`` and the B matrix has shape
    ``(rank, weight.shape[0] // len(replaced_module_names))``, so the packed
    module's output dimension is split evenly across the replaced modules.
    ``weight`` belongs to ``module_name``.

    Args:
        lora_id: id given to the returned LoRAModel.
        model: model that contains ``module_name``.
        module_name: dotted name of the packed (merged) submodule.
        replaced_module_names: names of the projections packed into it.
        device: device the random weights are allocated on.
        empty_replaced_module_name: optional name from
            ``replaced_module_names`` that gets no LoRA weights. Callers use
            this to build a partially populated packed LoRA.
        rank: LoRA rank (default 8).
        lora_alpha: LoRA alpha (default 16).

    Returns:
        A LoRAModel of the given ``rank`` holding the generated layers.
    """
    weight = model.get_submodule(module_name).weight
    loras: dict[str, LoRALayerWeights] = {}
    for replaced_module_name in replaced_module_names:
        if replaced_module_name == empty_replaced_module_name:
            continue
        # Computed per iteration so an empty ``replaced_module_names``
        # never reaches the division.
        out_per_module = weight.shape[0] // len(replaced_module_names)
        loras[replaced_module_name] = LoRALayerWeights(
            replaced_module_name,
            rank,
            lora_alpha,
            torch.rand([weight.shape[1], rank], device=device),
            torch.rand([rank, out_per_module], device=device),
        )
    return LoRAModel(lora_id, rank, loras)


def test_replace_submodules(dist_init, dummy_model):
    """Creating a LoRAModelManager swaps in LoRA-aware linear layers."""
    lora_config = LoRAConfig(max_lora_rank=8,
                             max_cpu_loras=8,
                             max_loras=8,
                             lora_dtype=DEFAULT_DTYPE)
    manager = LoRAModelManager(dummy_model, 1, 1, 1, lora_config,
                               torch.device(DEVICES[0]))
    model = manager.model

    expected_layer_types = {
        "dense1": ColumnParallelLinearWithLoRA,
        "layer1.dense1": ColumnParallelLinearWithLoRA,
        "dense2": RowParallelLinearWithLoRA,
        "layer1.dense2": RowParallelLinearWithLoRA,
    }
    for submodule_name, layer_type in expected_layer_types.items():
        assert isinstance(model.get_submodule(submodule_name), layer_type)


@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
    """Exercise add/activate/remove on a plain LoRAModelManager.

    The manager has two GPU slots and room for three registered adapters.
    Activating a third adapter while both slots are taken must raise
    ``ValueError`` and leave the slots untouched.
    """
    model = dummy_model
    shared_modules = ("dense1", "dense2", "lm_head")
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2, model_lora3 = (
        create_lora(adapter_id, model, list(shared_modules), device=device)
        for adapter_id in (2, 3))
    lora_config = LoRAConfig(max_lora_rank=8,
                             max_cpu_loras=3,
                             max_loras=2,
                             lora_dtype=DEFAULT_DTYPE)
    manager = LoRAModelManager(model, 2, 2, 2, lora_config, device=device)

    def slots():
        # Adapter ids bound to the two GPU slots, in slot order.
        return tuple(manager.lora_index_to_id[idx] for idx in range(2))

    assert all(x is None for x in manager.lora_index_to_id)

    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert slots()[0] == 1
    # Adding or activating an adapter that is already present reports False.
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert slots() == (1, 2)

    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(model_lora3)
    assert slots() == (1, 2)
    # Both slots are busy; this manager refuses rather than replacing one.
    with pytest.raises(ValueError):
        assert manager.activate_adapter(3)
    assert slots() == (1, 2)

    assert manager.remove_adapter(model_lora2.id)
    assert slots()[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    assert manager.add_adapter(model_lora1)
    assert slots() == (None, None)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(3)
    assert slots() == (3, None)
    assert manager.activate_adapter(2)
    assert slots() == (3, 2)

    assert manager.device == device
    assert manager.punica_wrapper.device == device
    assert hasattr(manager, "supported_lora_modules")
    assert sorted(manager.supported_lora_modules) == [
        "dense1",
        "dense2",
        "lm_head",
        "output",
    ]


@pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
    """Exercise slot reuse, deactivation and pinning on the LRU manager."""
    model = dummy_model
    shared_modules = ("dense1", "dense2", "lm_head")
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2, model_lora3 = (
        create_lora(adapter_id, model, list(shared_modules), device=device)
        for adapter_id in (2, 3))
    lora_config = LoRAConfig(max_lora_rank=8,
                             max_cpu_loras=3,
                             max_loras=2,
                             lora_dtype=DEFAULT_DTYPE)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       lora_config,
                                       device=device)

    def slots():
        # Adapter ids bound to the two GPU slots, in slot order.
        return tuple(manager.lora_index_to_id[idx] for idx in range(2))

    assert all(x is None for x in manager.lora_index_to_id)

    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert slots()[0] == 1
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert slots() == (1, 2)

    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(model_lora3)
    assert slots() == (1, 2)
    # Unlike the plain manager, activation succeeds by replacing slot 0.
    assert manager.activate_adapter(3)
    assert slots() == (3, 2)

    assert manager.remove_adapter(model_lora2.id)
    assert slots()[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert slots() == (3, 1)

    assert manager.add_adapter(model_lora2)
    assert manager.deactivate_adapter(3)
    assert slots() == (None, 1)
    assert manager.activate_adapter(2)
    assert slots() == (2, 1)
    assert manager.activate_adapter(3)
    assert slots() == (2, 3)

    # A pinned adapter keeps its slot while others are swapped around it.
    assert manager.pin_adapter(2)
    assert slots() == (2, 3)
    assert manager.activate_adapter(1)
    assert slots() == (2, 1)
    assert manager.deactivate_adapter(2)
    assert slots() == (None, 1)
    assert manager.activate_adapter(3)
    assert slots() == (3, 1)

    # With both slots pinned, nothing else can be pinned or activated.
    assert manager.pin_adapter(3)
    assert manager.pin_adapter(1)
    with pytest.raises(RuntimeError):
        assert manager.pin_adapter(2)
    assert slots() == (3, 1)
    with pytest.raises(RuntimeError):
        assert manager.activate_adapter(2)

    assert manager.deactivate_adapter(3)
    assert manager.pin_adapter(2)
    assert slots() == (2, 1)
    assert manager.remove_adapter(3)
    # Pinning an adapter that has been removed is an error.
    with pytest.raises(ValueError):
        assert manager.pin_adapter(3)

    assert manager.punica_wrapper.device == device
    assert manager.device == device


@pytest.mark.parametrize("device", DEVICES)
def test_lru_lora_model_manager(dist_init, dummy_model, device):
    """Check only the LRU cache bookkeeping of LRUCacheLoRAModelManager.

    All other manager behaviour is covered by test_lora_model_manager.
    """
    model = dummy_model
    shared_modules = ("dense1", "dense2", "lm_head")
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2, model_lora3, model_lora4 = (
        create_lora(adapter_id, model, list(shared_modules), device=device)
        for adapter_id in (2, 3, 4))
    lora_config = LoRAConfig(max_lora_rank=8,
                             max_cpu_loras=2,
                             max_loras=2,
                             lora_dtype=DEFAULT_DTYPE)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       lora_config,
                                       device=device)

    def slots():
        # Adapter ids bound to the two GPU slots, in slot order.
        return tuple(manager.lora_index_to_id[idx] for idx in range(2))

    def registered():
        return set(manager.list_adapters())

    assert all(x is None for x in manager.lora_index_to_id)

    # Fill the cache exactly to capacity.
    assert manager.add_adapter(model_lora1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(1)
    assert manager.activate_adapter(2)
    assert registered() == {1, 2}
    assert slots() == (1, 2)

    # Going over capacity pushes the earlier adapters out.
    assert manager.add_adapter(model_lora3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(3)
    assert manager.activate_adapter(4)
    assert registered() == {3, 4}
    assert slots() == (3, 4)

    # Touching 3 again reports False (it is already cached) but makes 4 the
    # entry that is evicted when 2 is added.
    assert not manager.add_adapter(model_lora3)
    assert not manager.activate_adapter(3)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert registered() == {3, 2}
    assert slots() == (3, 2)

    # Explicit removal.
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)
    assert registered() == {2}
    assert slots() == (None, 2)

    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)
    assert registered() == {3, 4}
    assert slots() == (3, 4)

    # Draining via remove_oldest_adapter.
    assert manager.remove_oldest_adapter()
    assert registered() == {4}
    assert slots() == (None, 4)

    assert manager.remove_oldest_adapter()
    assert registered() == set()
    assert all(x is None for x in manager.lora_index_to_id)

    # Nothing left to remove.
    assert not manager.remove_oldest_adapter()
    assert registered() == set()
    assert all(x is None for x in manager.lora_index_to_id)

    # Pinning.
    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)
    assert registered() == {3, 4}
    with pytest.raises(ValueError):
        assert manager.pin_adapter(1)
    assert manager.pin_adapter(3)
    # A pinned adapter can still be removed explicitly.
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)
    assert registered() == {4}
    assert slots() == (None, 4)

    assert manager.add_adapter(model_lora1)
    assert manager.pin_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert registered() == {1, 2}
    assert slots() == (1, 2)

    # 1 is pinned, so remove_oldest_adapter drops 2 instead.
    assert manager.remove_oldest_adapter()
    assert registered() == {1}
    assert slots() == (1, None)

    # Only the pinned adapter remains.
    with pytest.raises(RuntimeError):
        assert manager.remove_oldest_adapter()
    assert registered() == {1}

    assert manager.punica_wrapper.device == device
    assert manager.device == device


@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                          sql_lora_files, device):
    """Drive LRUCacheWorkerLoRAManager through successive activations.

    Adapters loaded by earlier requests stay listed until the four-adapter
    capacity forces an eviction. A single request for more adapters than
    fit raises ``RuntimeError``.
    """
    model = llama_2_7b_model_extra_embeddings
    lora_config = LoRAConfig(max_lora_rank=8,
                             max_cpu_loras=4,
                             max_loras=4,
                             lora_dtype=DEFAULT_DTYPE)
    vocab_size = model.unpadded_vocab_size - lora_config.lora_extra_vocab_size
    worker_adapter_manager = LRUCacheWorkerLoRAManager(
        4, 2, vocab_size, lora_config, device, EMBEDDING_MODULES,
        EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(model)

    mapping = LoRAMapping([], [])

    def activate(*adapter_ids):
        # One batch of requests; every adapter is backed by the same files.
        worker_adapter_manager.set_active_adapters(
            [LoRARequest(str(i), i, sql_lora_files) for i in adapter_ids],
            mapping)

    def slots(count):
        table = worker_adapter_manager._adapter_manager.lora_index_to_id
        return tuple(table[i] for i in range(count))

    activate(1, 2)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert slots(2) == (1, 2)

    # Previously loaded adapters stay resident while new ones fill slots.
    activate(1, 3, 4)
    assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
    assert slots(4) == (1, 2, 3, 4)

    activate(1, 2, 5)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert slots(4) == (1, 2, 5, 4)

    # Repeating the same adapter within a request changes nothing.
    activate(1, 1, 1)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert slots(4) == (1, 2, 5, 4)

    activate(6, 7, 8)
    assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
    assert slots(4) == (1, 7, 8, 6)

    # Over capacity
    with pytest.raises(RuntimeError):
        activate(10, 11, 12, 13, 14)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)


@pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                sql_lora_files, device):
    """Every LoRA not named in the latest request must be removed."""
    model = llama_2_7b_model_extra_embeddings
    lora_config = LoRAConfig(max_lora_rank=8,
                             max_cpu_loras=4,
                             max_loras=4,
                             lora_dtype=DEFAULT_DTYPE)
    vocab_size = model.unpadded_vocab_size - lora_config.lora_extra_vocab_size
    worker_adapter_manager = WorkerLoRAManager(4, 2, vocab_size, lora_config,
                                               device, EMBEDDING_MODULES,
                                               EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(model)

    mapping = LoRAMapping([], [])

    def activate(*adapter_ids):
        # One batch of requests; every adapter is backed by the same files.
        worker_adapter_manager.set_active_adapters(
            [LoRARequest(str(i), i, sql_lora_files) for i in adapter_ids],
            mapping)

    def slots(count):
        table = worker_adapter_manager._adapter_manager.lora_index_to_id
        return tuple(table[i] for i in range(count))

    activate(1, 2)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert slots(2) == (1, 2)

    activate(1, 3, 4)
    assert worker_adapter_manager.list_adapters() == {1, 3, 4}
    assert slots(3) == (1, 3, 4)

    activate(1, 2, 5)
    assert worker_adapter_manager.list_adapters() == {1, 2, 5}
    assert slots(3) == (1, 2, 5)

    activate(1, 1, 1)
    assert worker_adapter_manager.list_adapters() == {1}
    assert slots(3) == (1, None, None)

    activate(6, 7, 8)
    assert worker_adapter_manager.list_adapters() == {6, 7, 8}
    assert slots(3) == (8, 6, 7)

    # Over capacity
    with pytest.raises(RuntimeError):
        activate(10, 11, 12, 13, 14)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)


@pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
    """Per-projection LoRAs are packed into one gate_up_proj LoRA on add.

    ``model_lora`` has weights for both gate_proj and up_proj.
    ``model_lora1`` leaves gate_proj empty, so slot 0 of its packed LoRA
    must hold ``None``.
    """
    model = dummy_model_gate_up
    projections = ("gate_proj", "up_proj")
    model_lora = create_packed_lora(1,
                                    model,
                                    module_name="gate_up_proj",
                                    replaced_module_names=list(projections),
                                    device=device)
    model_lora1 = create_packed_lora(2,
                                     model,
                                     module_name="gate_up_proj",
                                     replaced_module_names=list(projections),
                                     device=device,
                                     empty_replaced_module_name="gate_proj")

    lora_config = LoRAConfig(max_lora_rank=8,
                             max_cpu_loras=2,
                             max_loras=2,
                             lora_dtype=DEFAULT_DTYPE)
    manager = LoRAModelManager(model, 2, 2, 2, lora_config, device=device)
    model = manager.model

    assert isinstance(model.get_submodule("gate_up_proj"),
                      MergedColumnParallelLinearWithLoRA)

    # Snapshot the unpacked weights; add_adapter repacks the models in place.
    model_lora_clone = model_lora.clone(1)
    model_lora_clone1 = model_lora1.clone(1)
    assert manager.add_adapter(model_lora)
    assert manager.add_adapter(model_lora1)

    # The individual projections are gone once packed.
    assert model_lora.get_lora("gate_proj") is None
    assert model_lora.get_lora("up_proj") is None
    assert model_lora1.get_lora("up_proj") is None

    packed_lora = model_lora.get_lora("gate_up_proj")
    assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)
    for slot, projection in enumerate(projections):
        reference = model_lora_clone.get_lora(projection)
        torch.testing.assert_close(packed_lora.lora_a[slot], reference.lora_a)
        torch.testing.assert_close(packed_lora.lora_b[slot], reference.lora_b)

    packed_lora1 = model_lora1.get_lora("gate_up_proj")
    assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)
    # gate_proj was left empty, so its packed slot carries no weights.
    assert packed_lora1.lora_a[0] is None
    assert packed_lora1.lora_b[0] is None
    up_reference = model_lora_clone1.get_lora("up_proj")
    torch.testing.assert_close(packed_lora1.lora_a[1], up_reference.lora_a)
    torch.testing.assert_close(packed_lora1.lora_b[1], up_reference.lora_b)