"vllm/benchmarks/sweep/serve_workload.py" did not exist on "d3a51da92a031f6c1758771a2b13976ace2eece2"
test_lora_manager.py 25.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
9
10
11
import os

import pytest
import torch
from safetensors.torch import load_file
from torch import nn

from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
12
13
                              MergedColumnParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA)
14
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
15
16
from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager)
17
from vllm.lora.peft_helper import PEFTHelper
18
19
20
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                      WorkerLoRAManager)
21
from vllm.platforms import current_platform
22

Terry's avatar
Terry committed
23
24
25
26
27
28
29
# Module-name -> embeddings-key mapping handed to the LoRA loaders/managers in
# this file as ``embedding_modules``; the values are also used as keys into the
# tensors loaded from "new_embeddings.safetensors".
EMBEDDING_MODULES = dict(
    embed_tokens="input_embeddings",
    lm_head="output_embeddings",
)

# Module names handed to the same loaders/managers as
# ``embedding_padding_modules``.
EMBEDDING_PADDING_MODULES = [
    "lm_head",
]

30
# Device strings the parametrized tests run on. On CUDA-like platforms this is
# "cuda:0" when exactly one device is reported, otherwise "cuda:0" and
# "cuda:1"; on every other platform it is just "cpu".
DEVICES = ([
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] if current_platform.is_cuda_alike() else ["cpu"])
33

34

35
@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
    """Check that ``LoRAModel.from_lora_tensors`` builds consistent LoRAs.

    Loads the adapter and extra-embedding tensors found under
    ``sql_lora_files`` and verifies, for every LoRA in the resulting model,
    its name, rank, alpha, device placement, weight shapes and embeddings
    tensor.
    """
    tensors = load_file(
        os.path.join(sql_lora_files, "adapter_model.safetensors"))
    new_embeddings = load_file(
        os.path.join(sql_lora_files, "new_embeddings.safetensors"))

    peft_helper = PEFTHelper.from_local_dir(sql_lora_files,
                                            max_position_embeddings=4096)
    lora_model = LoRAModel.from_lora_tensors(
        1,
        tensors,
        peft_helper=peft_helper,
        device=device,
        embeddings=new_embeddings,
        embedding_modules=EMBEDDING_MODULES,
        embedding_padding_modules=EMBEDDING_PADDING_MODULES)
    for module_name, lora in lora_model.loras.items():
        assert lora.module_name == module_name
        assert lora.rank == 8
        assert lora.lora_alpha == 16
        assert lora.lora_a is not None
        assert lora.lora_b is not None
        assert lora.lora_a.device == torch.device(device)
        assert lora.lora_b.device == torch.device(device)
        # lora_a's second dimension must match lora_b's first and equal 8.
        assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
                ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
        assert lora.lora_a.shape[1] == 8
        # Only modules whose name contains an EMBEDDING_MODULES key are
        # expected to carry an embeddings tensor.
        embeddings_module = next(
            (k for k in EMBEDDING_MODULES if k in module_name), None)
        if embeddings_module:
            assert torch.equal(
                lora.embeddings_tensor,
                new_embeddings[EMBEDDING_MODULES[embeddings_module]].to(
                    device=lora.embeddings_tensor.device))
        else:
            assert lora.embeddings_tensor is None


74
def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str],
                device: torch.device) -> LoRAModel:
    """Build a ``LoRAModel`` with random weights for the given submodules.

    Args:
        lora_id: Id passed through to the returned ``LoRAModel``.
        model: Module whose named submodules supply the weight shapes.
        sub_modules: Names resolved with ``model.get_submodule``.
        device: Device on which the random tensors are created.

    Returns:
        A ``LoRAModel`` holding one ``LoRALayerWeights`` per submodule name.
    """
    loras: dict[str, LoRALayerWeights] = {}
    for name in sub_modules:
        w = model.get_submodule(name).weight
        # 8 and 16 are passed positionally; presumably rank and alpha
        # (confirm against the LoRALayerWeights signature).
        loras[name] = LoRALayerWeights(
            name,
            8,
            16,
            torch.rand([w.shape[1], 8], device=device),
            torch.rand([8, w.shape[0]], device=device),
        )
    return LoRAModel(lora_id, 8, loras)


def create_packed_lora(
    lora_id: int,
    model: nn.Module,
    module_name,
    replaced_module_names,
    device: torch.device,
    empty_replaced_module_name=None,
) -> LoRAModel:
    """Build a ``LoRAModel`` with one random LoRA per replaced module name.

    Args:
        lora_id: Id passed through to the returned ``LoRAModel``.
        model: Module containing the submodule ``module_name``.
        module_name: Submodule whose weight supplies the tensor shapes.
        replaced_module_names: Names that each receive their own LoRA; the
            weight's first dimension is split evenly between them.
        device: Device on which the random tensors are created.
        empty_replaced_module_name: Optional name from
            ``replaced_module_names`` that is skipped (gets no LoRA).

    Returns:
        A ``LoRAModel`` keyed by the replaced module names that were kept.
    """
    w = model.get_submodule(module_name).weight
    loras: dict[str, LoRALayerWeights] = {}
    for replaced_module_name in replaced_module_names:
        if replaced_module_name == empty_replaced_module_name:
            continue
        loras[replaced_module_name] = LoRALayerWeights(
            replaced_module_name,
            8,
            16,
            torch.rand([w.shape[1], 8], device=device),
            torch.rand([8, w.shape[0] // len(replaced_module_names)],
                       device=device),
        )
    return LoRAModel(lora_id, 8, loras)


def test_replace_submodules(dist_init, dummy_model):
    """Check that the manager's model exposes LoRA-wrapped linear layers.

    After constructing a ``LoRAModelManager`` around ``dummy_model``, the
    ``dense1``/``dense2`` submodules (top level and under ``layer1``) must be
    instances of the column-/row-parallel LoRA layer classes.
    """
    model = dummy_model
    manager = LoRAModelManager(
        model, 1, 1, 1,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
        torch.device(DEVICES[0]))
    model = manager.model
    assert isinstance(model.get_submodule("dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("dense2"), RowParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense2"),
                      RowParallelLinearWithLoRA)


129
@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
    """Exercise add/activate/remove on a ``LoRAModelManager`` with two slots.

    Tracks ``manager.lora_index_to_id`` after each operation and finally
    checks the manager's device and its supported LoRA module names.
    """
    model = dummy_model
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=3,
                                          max_loras=2),
                               device=device)
    assert all(x is None for x in manager.lora_index_to_id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    # Re-adding / re-activating a known adapter is expected to return falsy.
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    # Slots 0 and 1 are occupied; activating a third adapter must raise and
    # leave the slot assignment untouched.
    with pytest.raises(ValueError):
        assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    # Adding alone does not occupy a slot; only activation does.
    assert manager.add_adapter(model_lora1)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] is None
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] is None
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    assert manager.device == device
    assert manager.punica_wrapper.device == device
    assert hasattr(manager, "supported_lora_modules")
    assert sorted(manager.supported_lora_modules) == [
        "dense1",
        "dense2",
        "lm_head",
        "output",
    ]
193

194

195
@pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
    """Exercise slot assignment on an ``LRUCacheLoRAModelManager``.

    Covers add/activate/deactivate/remove/pin with two slots and checks
    ``manager.lora_index_to_id`` after each step, including the errors
    expected once both slots hold pinned adapters.
    """
    model = dummy_model
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=3,
                                                  max_loras=2),
                                       device=device)
    assert all(x is None for x in manager.lora_index_to_id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    # With both slots full, activating adapter 3 takes over slot 0
    # (previously adapter 1) instead of raising.
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    assert manager.add_adapter(model_lora2)
    assert manager.deactivate_adapter(3)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    # Pinning adapter 2 keeps it in slot 0; the next activation replaces the
    # adapter in slot 1 instead.
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.deactivate_adapter(2)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    # Once both active adapters are pinned, pinning or activating another
    # adapter must raise and leave the slots unchanged.
    assert manager.pin_adapter(3)
    assert manager.pin_adapter(1)
    with pytest.raises(RuntimeError):
        assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    with pytest.raises(RuntimeError):
        assert manager.activate_adapter(2)

    assert manager.deactivate_adapter(3)
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.remove_adapter(3)
    # Adapter 3 was removed, so pinning it must raise ValueError.
    with pytest.raises(ValueError):
        assert manager.pin_adapter(3)

    assert manager.punica_wrapper.device == device
    assert manager.device == device

284

285
@pytest.mark.parametrize("device", DEVICES)
def test_lru_lora_model_manager(dist_init, dummy_model, device):
    """Exercise the LRU eviction behaviour of ``LRUCacheLoRAModelManager``.

    Uses a manager limited to two adapters and checks both
    ``manager.list_adapters()`` and ``manager.lora_index_to_id`` after adding
    beyond capacity, manual removal, ``remove_oldest_adapter`` and pinning.
    """
    # This tests just the LRU cache functionality, everything else is
    # tested in test_lora_model_manager
    model = dummy_model
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora4 = create_lora(4,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=2,
                                                  max_loras=2),
                                       device=device)

    assert all(x is None for x in manager.lora_index_to_id)

    # Add up to capacity
    assert manager.add_adapter(model_lora1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(1)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Add over capacity
    assert manager.add_adapter(model_lora3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(3)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Add 3 again to move it to the top and then add 2
    # should return false since it's in already
    assert not manager.add_adapter(model_lora3)
    assert not manager.activate_adapter(3)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {3, 2}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {2}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 2

    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Evict one adapter at a time until the cache is empty; a further
    # eviction on the empty cache is expected to return falsy.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    assert not manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    # pinning
    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    # Adapter 1 is not registered at this point, so pinning it must raise.
    with pytest.raises(ValueError):
        assert manager.pin_adapter(1)
    assert manager.pin_adapter(3)
    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.add_adapter(model_lora1)
    assert manager.pin_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Eviction skips the pinned adapter 1 and removes adapter 2.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {1}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] is None

    # Only the pinned adapter is left, so another eviction must raise.
    with pytest.raises(RuntimeError):
        assert manager.remove_oldest_adapter()

    assert set(manager.list_adapters()) == {1}
    assert manager.punica_wrapper.device == device
    assert manager.device == device
411

412

413
@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                          sql_lora_files, device):
    """Exercise ``LRUCacheWorkerLoRAManager.set_active_adapters``.

    Issues successive batches of ``LoRARequest`` objects against a manager
    configured for four LoRAs and checks the listed adapters and the
    slot-to-id table after each batch, then checks that a batch of five
    requests raises ``RuntimeError``.
    """
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = LRUCacheWorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, device,
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    # Adapter 2 is absent from this request but is still listed afterwards.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # All four slots are in use; adapter 5 takes over slot 2 from adapter 3.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Repeating the same request three times changes nothing.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Three new adapters displace 2, 4 and 5; adapter 1 keeps slot 0.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 6

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)
490

491

492
@pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                sql_lora_files, device):
    """Exercise ``WorkerLoRAManager.set_active_adapters``.

    After every call, only the adapters named in that request are expected
    to remain listed; the slot-to-id table is checked after each batch, and
    a batch of five requests must raise ``RuntimeError``.
    """
    # Should remove every LoRA not specified in the request.
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = WorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, device,
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4

    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5

    # Duplicate requests for adapter 1 leave only that adapter active.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None

    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 7

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)

567

568
@pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
    """Check that adding adapters packs per-projection LoRAs together.

    Builds two adapters for ``gate_up_proj`` (the second without a
    ``gate_proj`` LoRA), adds them to a ``LoRAModelManager`` and verifies that
    the individual ``gate_proj``/``up_proj`` entries are replaced by a single
    ``PackedLoRALayerWeights`` whose sub-weights match clones taken before
    the adapters were added.
    """
    model = dummy_model_gate_up
    model_lora = create_packed_lora(
        1,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device)
    model_lora1 = create_packed_lora(
        2,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device,
        empty_replaced_module_name="gate_proj",
    )

    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=2,
                                          max_loras=2),
                               device=device)
    model = manager.model

    assert isinstance(model.get_submodule("gate_up_proj"),
                      MergedColumnParallelLinearWithLoRA)
    # Verify packed lora is correct
    # Clone before adding: afterwards the per-projection entries are gone
    # from the original models (asserted below).
    model_lora_clone = model_lora.clone(1)
    model_lora_clone1 = model_lora1.clone(1)
    assert manager.add_adapter(model_lora)
    assert manager.add_adapter(model_lora1)

    assert model_lora.get_lora("gate_proj") is None
    assert model_lora.get_lora("up_proj") is None
    assert model_lora1.get_lora("up_proj") is None
    packed_lora = model_lora.get_lora("gate_up_proj")
    assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)

    torch.testing.assert_close(packed_lora.lora_a[0],
                               model_lora_clone.get_lora("gate_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[0],
                               model_lora_clone.get_lora("gate_proj").lora_b)
    torch.testing.assert_close(packed_lora.lora_a[1],
                               model_lora_clone.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[1],
                               model_lora_clone.get_lora("up_proj").lora_b)

    packed_lora1 = model_lora1.get_lora("gate_up_proj")
    assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)

    # The second adapter had no gate_proj LoRA, so index 0 stays empty.
    assert packed_lora1.lora_a[0] is None
    assert packed_lora1.lora_b[0] is None
    torch.testing.assert_close(packed_lora1.lora_a[1],
                               model_lora_clone1.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora1.lora_b[1],
                               model_lora_clone1.get_lora("up_proj").lora_b)