# SPDX-License-Identifier: Apache-2.0

import os
from typing import Dict, List

import pytest
import torch
from safetensors.torch import load_file
from torch import nn

from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager)
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                      WorkerLoRAManager)
from vllm.platforms import current_platform

# Maps a module-name fragment to the key under which that module's extra
# embedding tensor is stored in the "new_embeddings" checkpoint file.
EMBEDDING_MODULES = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}

# Passed through to the LoRA loaders as `embedding_padding_modules`.
EMBEDDING_PADDING_MODULES = ["lm_head"]

# Devices every parametrized test runs on: "cuda:0" only when exactly one
# CUDA device is reported, "cuda:0" and "cuda:1" for any other count, and
# "cpu" when the current platform is not CUDA-alike.
DEVICES = ([
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] if current_platform.is_cuda_alike() else ["cpu"])
34

35

36
@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
    """Check that LoRAModel.from_lora_tensors builds consistent layer weights.

    Loads the adapter and extra-embedding tensors from the `sql_lora_files`
    fixture directory, builds a LoRAModel on `device`, and verifies each
    resulting LoRA entry: name, rank, alpha, device placement, matching
    inner dimensions, and the attached embeddings tensor (if any).
    """
    tensors = load_file(
        os.path.join(sql_lora_files, "adapter_model.safetensors"))
    new_embeddings = load_file(
        os.path.join(sql_lora_files, "new_embeddings.safetensors"))

    peft_helper = PEFTHelper.from_local_dir(sql_lora_files,
                                            max_position_embeddings=4096)
    lora_model = LoRAModel.from_lora_tensors(
        1,
        tensors,
        peft_helper=peft_helper,
        device=device,
        embeddings=new_embeddings,
        embedding_modules=EMBEDDING_MODULES,
        embedding_padding_modules=EMBEDDING_PADDING_MODULES)
    for module_name, lora in lora_model.loras.items():
        assert lora.module_name == module_name
        assert lora.rank == 8
        assert lora.lora_alpha == 16
        assert lora.lora_a is not None
        assert lora.lora_b is not None
        assert lora.lora_a.device == torch.device(device)
        assert lora.lora_b.device == torch.device(device)
        # A's second dim must equal B's first dim, and both must equal 8.
        assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
                ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
        assert lora.lora_a.shape[1] == 8
        # First EMBEDDING_MODULES key that occurs in the module name, if any.
        embeddings_module = next(
            (k for k in EMBEDDING_MODULES if k in module_name), None)
        if embeddings_module:
            # Embedding modules must carry exactly the tensor from the
            # new-embeddings file (compared on the LoRA tensor's device).
            assert torch.equal(
                lora.embeddings_tensor,
                new_embeddings[EMBEDDING_MODULES[embeddings_module]].to(
                    device=lora.embeddings_tensor.device))
        else:
            assert lora.embeddings_tensor is None


75
76
def create_lora(lora_id: int, model: nn.Module, sub_modules: List[str],
                device: torch.device) -> LoRAModel:
    """Build a LoRAModel with random weights for the given submodules.

    Args:
        lora_id: Identifier passed to the LoRAModel constructor.
        model: Model whose submodules provide the weight shapes.
        sub_modules: Submodule names (as accepted by `get_submodule`) to
            create LoRA weights for.
        device: Device the random tensors are allocated on.

    Returns:
        A LoRAModel holding one LoRALayerWeights per requested submodule.
    """
    loras: Dict[str, LoRALayerWeights] = {}
    for name in sub_modules:
        w = model.get_submodule(name).weight
        # The positional 8 and 16 are presumably rank and alpha (the tests
        # in this file assert rank == 8 and lora_alpha == 16) - confirm
        # against the LoRALayerWeights signature.
        loras[name] = LoRALayerWeights(
            name,
            8,
            16,
            torch.rand([w.shape[1], 8], device=device),
            torch.rand([8, w.shape[0]], device=device),
        )
    return LoRAModel(lora_id, 8, loras)


def create_packed_lora(
    lora_id: int,
    model: nn.Module,
    module_name,
    replaced_module_names,
    device: torch.device,
    empty_replaced_module_name=None,
) -> LoRAModel:
    """Build a LoRAModel with random per-part weights for a packed module.

    Args:
        lora_id: Identifier passed to the LoRAModel constructor.
        model: Model containing the packed module.
        module_name: Name of the packed submodule whose weight shape is used.
        replaced_module_names: Names of the parts to create LoRA weights
            for. Each part's B matrix gets an equal share
            (`w.shape[0] // len(replaced_module_names)`) of the packed
            weight's first dimension.
        device: Device the random tensors are allocated on.
        empty_replaced_module_name: Optional part name to skip, so that the
            resulting model has no LoRA for that part.

    Returns:
        A LoRAModel keyed by the (non-skipped) replaced module names.
    """
    w = model.get_submodule(module_name).weight
    loras: Dict[str, LoRALayerWeights] = {}
    for replaced_module_name in replaced_module_names:
        if replaced_module_name == empty_replaced_module_name:
            continue
        loras[replaced_module_name] = LoRALayerWeights(
            replaced_module_name,
            8,
            16,
            torch.rand([w.shape[1], 8], device=device),
            torch.rand([8, w.shape[0] // len(replaced_module_names)],
                       device=device),
        )
    return LoRAModel(lora_id, 8, loras)


def test_replace_submodules(dist_init, dummy_model):
    """Check that constructing a LoRAModelManager swaps in LoRA layers.

    After the manager is built on the first entry of DEVICES, the `dense1`
    submodules must be ColumnParallelLinearWithLoRA and the `dense2`
    submodules RowParallelLinearWithLoRA, both at top level and under
    `layer1`.
    """
    model = dummy_model
    manager = LoRAModelManager(
        model, 1, 1, 1,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
        torch.device(DEVICES[0]))
    model = manager.model
    assert isinstance(model.get_submodule("dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("dense2"), RowParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense2"),
                      RowParallelLinearWithLoRA)


130
@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
    """Walk LoRAModelManager through add / activate / remove transitions.

    The manager is configured with max_loras=2 and max_cpu_loras=3. After
    each step the test checks the truthiness of the call's return value and
    the contents of `manager.lora_index_to_id`. The steps are stateful and
    strictly order-dependent.
    """
    model = dummy_model
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=3,
                                          max_loras=2),
                               device=device)
    assert all(x is None for x in manager.lora_index_to_id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    # Repeating add/activate for a present adapter must return falsy.
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    # A third adapter can be added but not activated while 1 and 2 hold
    # both index entries; the mapping must stay untouched.
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    with pytest.raises(ValueError):
        assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    # Removing an adapter clears its index entry; a second removal of the
    # same id must return falsy.
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    # Adding alone does not activate: both index entries stay empty.
    assert manager.add_adapter(model_lora1)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] is None
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] is None
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    assert manager.device == device
    assert manager.punica_wrapper.device == device
    assert hasattr(manager, "supported_lora_modules")
    assert sorted(manager.supported_lora_modules) == [
        "dense1",
        "dense2",
        "lm_head",
        "output",
    ]
194

195

196
@pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
    """Walk LRUCacheLoRAModelManager through its adapter lifecycle.

    The manager is configured with max_loras=2 and max_cpu_loras=3. The
    test covers add, activate, deactivate, remove and pin, checking
    `manager.lora_index_to_id` after each transition. The steps are
    stateful and strictly order-dependent.
    """
    model = dummy_model
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=3,
                                                  max_loras=2),
                                       device=device)
    assert all(x is None for x in manager.lora_index_to_id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    # Repeating add/activate for a present adapter must return falsy.
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    # Unlike the plain manager, activating a third adapter succeeds here
    # and takes over index 0 (previously adapter 1).
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    assert manager.add_adapter(model_lora2)
    assert manager.deactivate_adapter(3)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    # Pinning adapter 2 keeps it at index 0 through the next activation;
    # adapter 3 at index 1 is the one that gets replaced.
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.deactivate_adapter(2)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    # With both active adapters pinned, pinning or activating another
    # adapter must raise RuntimeError and leave the mapping unchanged.
    assert manager.pin_adapter(3)
    assert manager.pin_adapter(1)
    with pytest.raises(RuntimeError):
        assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    with pytest.raises(RuntimeError):
        assert manager.activate_adapter(2)

    assert manager.deactivate_adapter(3)
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    # Pinning an adapter that has been removed must raise ValueError.
    assert manager.remove_adapter(3)
    with pytest.raises(ValueError):
        assert manager.pin_adapter(3)

    assert manager.punica_wrapper.device == device
    assert manager.device == device

285

286
@pytest.mark.parametrize("device", DEVICES)
def test_lru_lora_model_manager(dist_init, dummy_model, device):
    """Exercise eviction, manual removal and pinning on the LRU manager.

    The manager is configured with max_loras=2 and max_cpu_loras=2, so
    adding adapters beyond two must displace earlier ones. After each step
    the test checks `manager.list_adapters()` and
    `manager.lora_index_to_id`. The steps are stateful and strictly
    order-dependent.
    """
    # This tests just the LRU cache functionality, everything else is
    # tested in test_lora_model_manager
    model = dummy_model
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora4 = create_lora(4,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=2,
                                                  max_loras=2),
                                       device=device)

    assert all(x is None for x in manager.lora_index_to_id)

    # Add up to capacity
    assert manager.add_adapter(model_lora1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(1)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Add over capacity
    assert manager.add_adapter(model_lora3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(3)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Add 3 again to move it to the top (this returns False because it is
    # already present), then add 2.
    assert not manager.add_adapter(model_lora3)
    assert not manager.activate_adapter(3)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {3, 2}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {2}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 2

    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Each remove_oldest_adapter call drops one adapter; once the manager
    # is empty the call must return falsy.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    assert not manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    # pinning
    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    # Adapter 1 is not registered at this point, so pinning it must fail.
    with pytest.raises(ValueError):
        assert manager.pin_adapter(1)
    assert manager.pin_adapter(3)
    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.add_adapter(model_lora1)
    assert manager.pin_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # The pinned adapter 1 survives remove_oldest_adapter; once it is the
    # only adapter left, another call must raise RuntimeError.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {1}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] is None

    with pytest.raises(RuntimeError):
        assert manager.remove_oldest_adapter()

    assert set(manager.list_adapters()) == {1}
    assert manager.punica_wrapper.device == device
    assert manager.device == device
412

413

414
@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                          sql_lora_files, device):
    """Check that the LRU worker manager keeps adapters across requests.

    With max_loras=4 and max_cpu_loras=4, successive `set_active_adapters`
    calls must retain previously loaded adapters until capacity forces a
    replacement. Requesting five distinct adapters at once must raise
    RuntimeError. The steps are stateful and strictly order-dependent.
    """
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = LRUCacheWorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, device,
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    # Adapter 2 is not requested here but must stay loaded.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Adapter 5 is new and all four entries are taken: it replaces
    # adapter 3 at index 2.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Requesting the same adapter repeatedly must not change anything.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Three new adapters replace 2, 4 and 5; adapter 1 stays at index 0.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 6

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)
491

492

493
@pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                sql_lora_files, device):
    """Check that WorkerLoRAManager keeps only the requested adapters.

    With max_loras=4 and max_cpu_loras=4, every `set_active_adapters` call
    must leave exactly the requested adapter ids loaded. Requesting five
    distinct adapters at once must raise RuntimeError. The steps are
    stateful and strictly order-dependent.
    """
    # Should remove every LoRA not specified in the request.
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = WorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, device,
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    # Adapter 2 is not requested, so it must be gone afterwards.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4

    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5

    # Duplicate requests for adapter 1 leave only adapter 1 loaded.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None

    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 7

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)

568

569
@pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
    """Check that adding an adapter packs gate_proj/up_proj LoRAs together.

    Two adapters are built for the packed `gate_up_proj` module: one with
    both parts and one with `gate_proj` omitted. After `add_adapter`, the
    per-part entries must be gone and a PackedLoRALayerWeights entry under
    "gate_up_proj" must hold the original tensors, with None in the slot
    of the omitted part.
    """
    model = dummy_model_gate_up
    model_lora = create_packed_lora(
        1,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device)
    model_lora1 = create_packed_lora(
        2,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device,
        empty_replaced_module_name="gate_proj",
    )

    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=2,
                                          max_loras=2),
                               device=device)
    model = manager.model

    assert isinstance(model.get_submodule("gate_up_proj"),
                      MergedColumnParallelLinearWithLoRA)
    # Verify packed lora is correct
    # Clone before add_adapter: afterwards the per-part entries are no
    # longer retrievable from the originals (asserted below).
    model_lora_clone = model_lora.clone(1)
    model_lora_clone1 = model_lora1.clone(1)
    assert manager.add_adapter(model_lora)
    assert manager.add_adapter(model_lora1)

    assert model_lora.get_lora("gate_proj") is None
    assert model_lora.get_lora("up_proj") is None
    assert model_lora1.get_lora("up_proj") is None
    packed_lora = model_lora.get_lora("gate_up_proj")
    assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)

    # Slot 0 corresponds to gate_proj and slot 1 to up_proj.
    torch.testing.assert_close(packed_lora.lora_a[0],
                               model_lora_clone.get_lora("gate_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[0],
                               model_lora_clone.get_lora("gate_proj").lora_b)
    torch.testing.assert_close(packed_lora.lora_a[1],
                               model_lora_clone.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[1],
                               model_lora_clone.get_lora("up_proj").lora_b)

    packed_lora1 = model_lora1.get_lora("gate_up_proj")
    assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)

    # The adapter built without gate_proj must have an empty first slot.
    assert packed_lora1.lora_a[0] is None
    assert packed_lora1.lora_b[0] is None
    torch.testing.assert_close(packed_lora1.lora_a[1],
                               model_lora_clone1.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora1.lora_b[1],
                               model_lora_clone1.get_lora("up_proj").lora_b)