# SPDX-License-Identifier: Apache-2.0

import os

import pytest
import torch
from safetensors.torch import load_file
from torch import nn

10
from vllm import envs
11
12
from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
13
14
                              MergedColumnParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA)
15
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
16
17
from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager)
18
from vllm.lora.peft_helper import PEFTHelper
19
20
21
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                      WorkerLoRAManager)
22
from vllm.platforms import current_platform
23

Terry's avatar
Terry committed
24
25
26
27
28
29
30
# Maps embedding sub-module names to the tensor keys used in the separately
# stored embeddings file (looked up in test_from_lora_tensors).
EMBEDDING_MODULES = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}

# Forwarded as the ``embedding_padding_modules`` argument of the LoRA loaders
# and worker managers used in this file.
EMBEDDING_PADDING_MODULES = ["lm_head"]

# Devices the parametrized tests run on. On CUDA-alike platforms this is
# "cuda:0" alone when torch reports exactly one device, otherwise "cuda:0"
# and "cuda:1"; on every other platform it is just "cpu".
DEVICES = ([
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] if current_platform.is_cuda_alike() else ["cpu"])
34

35

36
@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
    """Build a ``LoRAModel`` from checkpoint tensors and validate each layer.

    Reads ``adapter_model.safetensors`` and ``new_embeddings.safetensors``
    from the ``sql_lora_files`` directory, builds the model with
    ``LoRAModel.from_lora_tensors`` on ``device`` and then checks every entry
    of ``lora_model.loras``: module name, rank 8, alpha 16, tensor placement,
    matching A/B shapes and the presence of the embeddings tensor.
    """
    tensors = load_file(
        os.path.join(sql_lora_files, "adapter_model.safetensors"))
    new_embeddings = load_file(
        os.path.join(sql_lora_files, "new_embeddings.safetensors"))

    peft_helper = PEFTHelper.from_local_dir(sql_lora_files,
                                            max_position_embeddings=4096)
    lora_model = LoRAModel.from_lora_tensors(
        1,
        tensors,
        peft_helper=peft_helper,
        device=device,
        embeddings=new_embeddings,
        embedding_modules=EMBEDDING_MODULES,
        embedding_padding_modules=EMBEDDING_PADDING_MODULES)
    for module_name, lora in lora_model.loras.items():
        assert lora.module_name == module_name
        assert lora.rank == 8
        assert lora.lora_alpha == 16
        assert lora.lora_a is not None
        assert lora.lora_b is not None
        assert lora.lora_a.device == torch.device(device)
        assert lora.lora_b.device == torch.device(device)
        # A's second dimension must line up with B's first one (both 8 here).
        assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
                ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
        assert lora.lora_a.shape[1] == 8
        # Only modules whose name contains an EMBEDDING_MODULES key are
        # expected to carry an embeddings tensor.
        embeddings_module = next(
            (k for k in EMBEDDING_MODULES if k in module_name), None)
        if embeddings_module:
            assert torch.equal(
                lora.embeddings_tensor,
                new_embeddings[EMBEDDING_MODULES[embeddings_module]].to(
                    device=lora.embeddings_tensor.device))
        else:
            assert lora.embeddings_tensor is None


75
def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str],
                device: torch.device) -> LoRAModel:
    """Build a ``LoRAModel`` with random weights for the given sub-modules.

    For each name in ``sub_modules`` the ``weight`` of
    ``model.get_submodule(name)`` is read and a ``LoRALayerWeights`` is
    created from two ``torch.rand`` tensors shaped ``[w.shape[1], 8]`` and
    ``[8, w.shape[0]]``.

    Args:
        lora_id: Id handed to the ``LoRAModel`` constructor.
        model: Module whose named sub-modules expose a ``weight`` with at
            least two dimensions.
        sub_modules: Names of the sub-modules to create LoRA weights for.
        device: Device the random tensors are created on (the tests in this
            file pass the device strings from ``DEVICES``).

    Returns:
        ``LoRAModel(lora_id, 8, loras)`` with one entry per sub-module name.
    """
    loras: dict[str, LoRALayerWeights] = {}
    for name in sub_modules:
        w = model.get_submodule(name).weight
        # NOTE(review): the literals 8 and 16 are presumably rank and alpha;
        # confirm against the LoRALayerWeights signature.
        loras[name] = LoRALayerWeights(
            name,
            8,
            16,
            torch.rand([w.shape[1], 8], device=device),
            torch.rand([8, w.shape[0]], device=device),
        )
    return LoRAModel(lora_id, 8, loras)


def create_packed_lora(
    lora_id: int,
    model: nn.Module,
    module_name,
    replaced_module_names,
    device: torch.device,
    empty_replaced_module_name=None,
) -> LoRAModel:
    """Build a ``LoRAModel`` holding one random LoRA per replaced module.

    The ``weight`` of ``model.get_submodule(module_name)`` fixes the tensor
    sizes: every replaced module gets an A tensor of shape
    ``[w.shape[1], 8]`` and a B tensor whose second dimension is
    ``w.shape[0]`` divided evenly among ``replaced_module_names``.

    Args:
        lora_id: Id handed to the ``LoRAModel`` constructor.
        model: Module containing the sub-module named ``module_name``.
        module_name: Name of the sub-module whose weight is used for sizing.
        replaced_module_names: Names to create ``LoRALayerWeights`` for.
        device: Device the random tensors are created on.
        empty_replaced_module_name: Optional name from
            ``replaced_module_names`` to skip, so that no LoRA is created
            for it.

    Returns:
        ``LoRAModel(lora_id, 8, loras)`` keyed by replaced module name.
    """
    w = model.get_submodule(module_name).weight
    loras: dict[str, LoRALayerWeights] = {}
    for replaced_module_name in replaced_module_names:
        if replaced_module_name == empty_replaced_module_name:
            continue
        # NOTE(review): the literals 8 and 16 are presumably rank and alpha;
        # confirm against the LoRALayerWeights signature.
        loras[replaced_module_name] = LoRALayerWeights(
            replaced_module_name,
            8,
            16,
            torch.rand([w.shape[1], 8], device=device),
            torch.rand([8, w.shape[0] // len(replaced_module_names)],
                       device=device),
        )
    return LoRAModel(lora_id, 8, loras)


def test_replace_submodules(dist_init, dummy_model):
    """Check the LoRA wrapper types found on ``manager.model``.

    After a ``LoRAModelManager`` is constructed around the dummy model, the
    ``dense1`` sub-modules are expected to be ``ColumnParallelLinearWithLoRA``
    and the ``dense2`` sub-modules ``RowParallelLinearWithLoRA``, both at the
    top level and under ``layer1``.
    """
    model = dummy_model
    manager = LoRAModelManager(
        model, 1, 1, 1,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
        torch.device(DEVICES[0]))
    model = manager.model
    assert isinstance(model.get_submodule("dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("dense2"), RowParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense2"),
                      RowParallelLinearWithLoRA)


130
@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
    """Exercise add/activate/remove bookkeeping of ``LoRAModelManager``.

    The manager is created with ``max_loras=2`` and the two entries of
    ``manager.lora_index_to_id`` are checked after each step.
    """
    model = dummy_model
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=3,
                                          max_loras=2),
                               device=device)
    # No slot is occupied before any adapter is activated.
    assert all(x is None for x in manager.lora_index_to_id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    # Adding or activating an adapter that is already present reports False.
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    # Adapter 3 can be added, but with slots 0 and 1 holding adapters 1 and 2
    # activating it is expected to raise.
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    with pytest.raises(ValueError):
        manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    # Removing an adapter clears its slot; removing it again reports False.
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    # Re-adding adapter 1 without activating it leaves both slots empty.
    assert manager.add_adapter(model_lora1)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] is None
    assert manager.add_adapter(model_lora2)
    # Adapter 3 (added earlier, never removed) now activates into slot 0.
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] is None
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    assert manager.device == device
    assert manager.punica_wrapper.device == device
    assert hasattr(manager, "supported_lora_modules")
    assert sorted(manager.supported_lora_modules) == [
        "dense1",
        "dense2",
        "lm_head",
        "output",
    ]
194

195

196
@pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
    """Exercise activation, eviction and pinning on the LRU-cache manager.

    ``LRUCacheLoRAModelManager`` is created with ``max_loras=2``; the two
    entries of ``manager.lora_index_to_id`` are checked after each step.
    """
    model = dummy_model
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=3,
                                                  max_loras=2),
                                       device=device)
    assert all(x is None for x in manager.lora_index_to_id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    # Adding or activating an adapter that is already present reports False.
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    # With both slots in use, activating adapter 3 succeeds here and takes
    # slot 0 from adapter 1.
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    assert manager.add_adapter(model_lora2)
    # Deactivating clears the slot; the next activation fills it again.
    assert manager.deactivate_adapter(3)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    # With adapter 2 pinned, activating adapter 1 replaces adapter 3 instead.
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.deactivate_adapter(2)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    # Both slots now hold pinned adapters (3 and 1): pinning or activating
    # adapter 2 is expected to raise.
    assert manager.pin_adapter(3)
    assert manager.pin_adapter(1)
    with pytest.raises(RuntimeError):
        manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    with pytest.raises(RuntimeError):
        manager.activate_adapter(2)

    # Deactivating adapter 3 frees slot 0, where adapter 2 is then pinned.
    assert manager.deactivate_adapter(3)
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    # Once removed, adapter 3 can no longer be pinned.
    assert manager.remove_adapter(3)
    with pytest.raises(ValueError):
        manager.pin_adapter(3)

    assert manager.punica_wrapper.device == device
    assert manager.device == device

285

286
@pytest.mark.parametrize("device", DEVICES)
def test_lru_lora_model_manager(dist_init, dummy_model, device):
    """Exercise the LRU bookkeeping of ``LRUCacheLoRAModelManager``.

    The manager is created with ``max_cpu_loras=2`` and ``max_loras=2``;
    ``manager.list_adapters()`` and ``manager.lora_index_to_id`` are checked
    after adding, re-adding, removing and pinning adapters.
    """
    # This tests just the LRU cache functionality, everything else is
    # tested in test_lora_model_manager
    model = dummy_model
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora4 = create_lora(4,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=2,
                                                  max_loras=2),
                                       device=device)

    assert all(x is None for x in manager.lora_index_to_id)

    # Add up to capacity
    assert manager.add_adapter(model_lora1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(1)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Add over capacity
    assert manager.add_adapter(model_lora3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(3)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Re-adding 3 returns False because it is already present (this moves it
    # to the top of the LRU order); adding 2 afterwards pushes out 4.
    assert not manager.add_adapter(model_lora3)
    assert not manager.activate_adapter(3)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {3, 2}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {2}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 2

    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Remove the oldest adapter one at a time until nothing is left.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    # With no adapters left the call reports False.
    assert not manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    # pinning
    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    # Adapter 1 is not registered at this point, so pinning it must fail.
    with pytest.raises(ValueError):
        manager.pin_adapter(1)
    assert manager.pin_adapter(3)
    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.add_adapter(model_lora1)
    assert manager.pin_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # The pinned adapter 1 is kept; the unpinned adapter 2 is removed.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {1}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] is None

    # Only the pinned adapter 1 is left; removing the oldest must raise.
    with pytest.raises(RuntimeError):
        manager.remove_oldest_adapter()

    assert set(manager.list_adapters()) == {1}
    assert manager.punica_wrapper.device == device
    assert manager.device == device
412

413

414
@pytest.mark.skipif(envs.VLLM_USE_V1, reason="Test leverages V0 internals.")
@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                          sql_lora_files, device):
    """Check slot assignment of ``LRUCacheWorkerLoRAManager`` across requests.

    The manager is configured with ``max_loras=4``. After each
    ``set_active_adapters`` call the test checks ``list_adapters()`` and the
    four entries of ``_adapter_manager.lora_index_to_id``; adapters that are
    not part of the current request are expected to stay loaded until their
    slot is needed.
    """
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = LRUCacheWorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, device,
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    # Adapter 2 is not requested but stays loaded next to the new ones.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # All four slots were in use; adapter 5 takes slot 2 from adapter 3.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Repeated requests for an already loaded adapter change nothing.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Three new adapters replace 2, 5 and 4; adapter 1 keeps slot 0.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 6

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)
492

493

494
@pytest.mark.skipif(envs.VLLM_USE_V1, reason="Test leverages V0 internals.")
@pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                sql_lora_files, device):
    """Check slot assignment of ``WorkerLoRAManager`` across requests.

    After each ``set_active_adapters`` call the test checks that
    ``list_adapters()`` contains exactly the requested adapter ids and
    inspects the first entries of ``_adapter_manager.lora_index_to_id``.
    """
    # Should remove every LoRA not specified in the request.
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = WorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, device,
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    # Adapter 2 is dropped; 3 and 4 fill the slots after adapter 1.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4

    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5

    # Requesting only adapter 1 (three times) leaves just that adapter.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None

    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 7

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)

570

571
@pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
    """Check that ``gate_proj``/``up_proj`` LoRAs end up packed together.

    Two adapters are built for ``gate_up_proj``: one with weights for both
    replaced modules and one without ``gate_proj``. After ``add_adapter`` the
    per-module entries are expected to be gone from the adapter, replaced by
    a single ``PackedLoRALayerWeights`` under ``gate_up_proj`` whose
    ``lora_a``/``lora_b`` entries match clones taken before the adapters
    were added (``None`` where a module had no weights).
    """
    model = dummy_model_gate_up
    model_lora = create_packed_lora(
        1,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device)
    model_lora1 = create_packed_lora(
        2,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device,
        empty_replaced_module_name="gate_proj",
    )

    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=2,
                                          max_loras=2),
                               device=device)
    model = manager.model

    assert isinstance(model.get_submodule("gate_up_proj"),
                      MergedColumnParallelLinearWithLoRA)
    # Verify packed lora is correct
    # The clones are taken before add_adapter and serve as the reference
    # for the unpacked per-module weights.
    model_lora_clone = model_lora.clone(1)
    model_lora_clone1 = model_lora1.clone(1)
    assert manager.add_adapter(model_lora)
    assert manager.add_adapter(model_lora1)

    assert model_lora.get_lora("gate_proj") is None
    assert model_lora.get_lora("up_proj") is None
    assert model_lora1.get_lora("up_proj") is None
    packed_lora = model_lora.get_lora("gate_up_proj")
    assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)

    torch.testing.assert_close(packed_lora.lora_a[0],
                               model_lora_clone.get_lora("gate_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[0],
                               model_lora_clone.get_lora("gate_proj").lora_b)
    torch.testing.assert_close(packed_lora.lora_a[1],
                               model_lora_clone.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[1],
                               model_lora_clone.get_lora("up_proj").lora_b)

    packed_lora1 = model_lora1.get_lora("gate_up_proj")
    assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)

    # Index 0 corresponds to gate_proj, which the second adapter left out.
    assert packed_lora1.lora_a[0] is None
    assert packed_lora1.lora_b[0] is None
    torch.testing.assert_close(packed_lora1.lora_a[1],
                               model_lora_clone1.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora1.lora_b[1],
                               model_lora_clone1.get_lora("up_proj").lora_b)