test_lora_manager.py 26.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
import os

import pytest
import torch
from safetensors.torch import load_file
from torch import nn

from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
13
14
                              MergedColumnParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA)
15
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
16
17
from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager)
18
from vllm.lora.peft_helper import PEFTHelper
19
20
21
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                      WorkerLoRAManager)
22
from vllm.platforms import current_platform
23

Terry's avatar
Terry committed
24
25
26
27
28
29
30
# Module-name fragment -> key of the matching tensor in the adapter's extra
# embeddings file. test_from_lora_tensors matches the keys as substrings of
# each LoRA module name and uses the values to index the loaded embeddings.
EMBEDDING_MODULES = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}

# Passed through as ``embedding_padding_modules`` wherever a LoRA is loaded.
EMBEDDING_PADDING_MODULES = ["lm_head"]

# Devices the parametrized tests run on: "cuda:0" alone when exactly one
# CUDA device is reported, otherwise "cuda:0" and "cuda:1"; plain "cpu" on
# platforms that are not CUDA-like.
if current_platform.is_cuda_alike():
    DEVICES = [
        f"cuda:{idx}"
        for idx in range(1 if torch.cuda.device_count() == 1 else 2)
    ]
else:
    DEVICES = ["cpu"]

# Captured once at import time and used as ``lora_dtype`` in every config.
DEFAULT_DTYPE = torch.get_default_dtype()

37

38
39
40
41
42
43
44
45
46
47
48
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch: pytest.MonkeyPatch):
    """
    Run every test in this module with ``VLLM_USE_V1`` set to ``'0'``.

    Some tests rely on V0 internals, and V0 and V1 share the same
    LoRAModelManager, so exercising V0 alone is sufficient. The variable
    is set inside a monkeypatch context, so the change is undone when the
    test finishes.
    """
    with monkeypatch.context() as patched_env:
        patched_env.setenv('VLLM_USE_V1', '0')
        yield


49
@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
    """Load the SQL LoRA from raw tensors and validate every layer.

    Checks, per LoRA layer: name, rank 8, alpha 16, that both matrices
    exist on the requested device with matching inner dimension, and that
    an embeddings tensor is attached exactly for the modules listed in
    EMBEDDING_MODULES.
    """
    adapter_file = os.path.join(sql_lora_files, "adapter_model.safetensors")
    tensors = load_file(adapter_file)
    embeddings_file = os.path.join(sql_lora_files,
                                   "new_embeddings.safetensors")
    new_embeddings = load_file(embeddings_file)

    peft_helper = PEFTHelper.from_local_dir(sql_lora_files,
                                            max_position_embeddings=4096)
    lora_model = LoRAModel.from_lora_tensors(
        1,
        tensors,
        peft_helper=peft_helper,
        device=device,
        embeddings=new_embeddings,
        embedding_modules=EMBEDDING_MODULES,
        embedding_padding_modules=EMBEDDING_PADDING_MODULES,
    )

    for module_name, lora in lora_model.loras.items():
        # Basic metadata of the layer.
        assert lora.module_name == module_name
        assert lora.rank == 8
        assert lora.lora_alpha == 16

        # Both low-rank factors must exist and live on the target device.
        assert lora.lora_a is not None
        assert lora.lora_b is not None
        assert lora.lora_a.device == torch.device(device)
        assert lora.lora_b.device == torch.device(device)

        # lora_a's second dim and lora_b's first dim are both the rank.
        assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
                ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
        assert lora.lora_a.shape[1] == 8

        # Find the first embedding-module fragment contained in the name.
        matched_key = None
        for candidate in EMBEDDING_MODULES:
            if candidate in module_name:
                matched_key = candidate
                break

        if not matched_key:
            assert lora.embeddings_tensor is None
            continue

        expected = new_embeddings[EMBEDDING_MODULES[matched_key]]
        assert torch.equal(
            lora.embeddings_tensor,
            expected.to(device=lora.embeddings_tensor.device))


88
def create_lora(lora_id: int,
                model: nn.Module,
                sub_modules: list[str],
                device: torch.device,
                rank: int = 8,
                lora_alpha: int = 16) -> LoRAModel:
    """Build a LoRAModel with random weights for the given submodules.

    Args:
        lora_id: Id given to the resulting LoRAModel.
        model: Model whose submodules supply the weight shapes.
        sub_modules: Dotted submodule paths (resolved with
            ``model.get_submodule``) that receive a LoRA layer.
        device: Forwarded to ``torch.rand`` for both factors; the tests in
            this file pass device strings such as ``"cpu"``.
        rank: LoRA rank used for both factors and for the model. Defaults
            to 8, the value previously hard-coded.
        lora_alpha: Alpha stored on every layer. Defaults to 16, the value
            previously hard-coded.

    Returns:
        A LoRAModel holding one LoRALayerWeights per entry of
        ``sub_modules``.
    """
    loras: dict[str, LoRALayerWeights] = {}
    for name in sub_modules:
        weight = model.get_submodule(name).weight
        # lora_a is (weight.shape[1], rank) and lora_b is
        # (rank, weight.shape[0]); lora_a is drawn first so the random
        # stream is consumed in the same order as before.
        lora_a = torch.rand([weight.shape[1], rank], device=device)
        lora_b = torch.rand([rank, weight.shape[0]], device=device)
        loras[name] = LoRALayerWeights(name, rank, lora_alpha, lora_a,
                                       lora_b)
    return LoRAModel(lora_id, rank, loras)


def create_packed_lora(
    lora_id: int,
    model: nn.Module,
    module_name,
    replaced_module_names,
    device: torch.device,
    empty_replaced_module_name=None,
    rank: int = 8,
    lora_alpha: int = 16,
) -> LoRAModel:
    """Build a LoRAModel with one random LoRA per replaced sub-module.

    The shapes are derived from the weight of ``module_name``: each
    sub-module gets a ``lora_b`` covering an equal share
    (``weight.shape[0] // len(replaced_module_names)``) of the output
    dimension.

    Args:
        lora_id: Id given to the resulting LoRAModel.
        model: Model that contains ``module_name``.
        module_name: Submodule path whose weight defines the shapes.
        replaced_module_names: Names under which the individual LoRA
            layers are stored.
        device: Forwarded to ``torch.rand`` for both factors.
        empty_replaced_module_name: Optional entry of
            ``replaced_module_names`` that is skipped, leaving that
            sub-module without a LoRA.
        rank: LoRA rank used for both factors and for the model. Defaults
            to 8, the value previously hard-coded.
        lora_alpha: Alpha stored on every layer. Defaults to 16, the value
            previously hard-coded.

    Returns:
        A LoRAModel keyed by the non-skipped replaced module names.
    """
    weight = model.get_submodule(module_name).weight
    loras: dict[str, LoRALayerWeights] = {}
    for replaced_module_name in replaced_module_names:
        if replaced_module_name == empty_replaced_module_name:
            continue
        # lora_a is drawn before lora_b so the random stream is consumed
        # in the same order as before.
        lora_a = torch.rand([weight.shape[1], rank], device=device)
        lora_b = torch.rand(
            [rank, weight.shape[0] // len(replaced_module_names)],
            device=device)
        loras[replaced_module_name] = LoRALayerWeights(
            replaced_module_name, rank, lora_alpha, lora_a, lora_b)
    return LoRAModel(lora_id, rank, loras)


def test_replace_submodules(dist_init, dummy_model):
    """Check that creating a LoRAModelManager swaps in LoRA layer wrappers.

    After construction, the dense1/dense2 submodules (top level and under
    layer1) of ``manager.model`` must be the column-/row-parallel LoRA
    wrappers respectively.
    """
    lora_config = LoRAConfig(max_lora_rank=8,
                             max_cpu_loras=8,
                             max_loras=8,
                             lora_dtype=DEFAULT_DTYPE)
    manager = LoRAModelManager(dummy_model, 1, 1, 1, lora_config,
                               torch.device(DEVICES[0]))
    wrapped_model = manager.model

    expected_wrappers = (
        ("dense1", ColumnParallelLinearWithLoRA),
        ("layer1.dense1", ColumnParallelLinearWithLoRA),
        ("dense2", RowParallelLinearWithLoRA),
        ("layer1.dense2", RowParallelLinearWithLoRA),
    )
    for path, wrapper_cls in expected_wrappers:
        assert isinstance(wrapped_model.get_submodule(path), wrapper_cls)


145
@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
    """Exercise add/activate/remove on a LoRAModelManager with max_loras=2.

    Three adapters are pushed through the manager; after each step the
    test checks which adapter id ``lora_index_to_id`` reports for slots 0
    and 1. Activating a third adapter while both slots are taken is
    expected to raise ValueError.
    """

    def check_slots(first, second):
        # Compare slots 0 and 1 of lora_index_to_id; None means "empty".
        for index, wanted in enumerate((first, second)):
            held = manager.lora_index_to_id[index]
            if wanted is None:
                assert held is None
            else:
                assert held == wanted

    model = dummy_model
    lora1 = create_lora(1,
                        model, ["layer1.dense1", "dense2", "lm_head"],
                        device=device)
    lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"],
                        device=device)
    lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"],
                        device=device)

    lora_config = LoRAConfig(max_lora_rank=8,
                             max_cpu_loras=3,
                             max_loras=2,
                             lora_dtype=DEFAULT_DTYPE)
    manager = LoRAModelManager(model, 2, 2, 2, lora_config, device=device)
    assert all(x is None for x in manager.lora_index_to_id)

    # First adapter lands in slot 0.
    assert manager.add_adapter(lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1

    # Re-adding / re-activating reports False; adapter 2 takes slot 1.
    assert not manager.add_adapter(lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(lora2)
    assert manager.activate_adapter(2)
    check_slots(1, 2)

    # Adapter 3 can be added, but not activated while both slots are used.
    assert not manager.add_adapter(lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(lora3)
    check_slots(1, 2)
    with pytest.raises(ValueError):
        assert manager.activate_adapter(3)
    check_slots(1, 2)

    # Removing adapters frees their slots; a second removal reports False.
    assert manager.remove_adapter(lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(lora2.id)
    assert manager.remove_adapter(lora1.id)
    assert not manager.remove_adapter(lora1.id)

    # Adding alone does not occupy a slot.
    assert manager.add_adapter(lora1)
    check_slots(None, None)

    # Activation fills the free slots in order.
    assert manager.add_adapter(lora2)
    assert manager.activate_adapter(3)
    check_slots(3, None)
    assert manager.activate_adapter(2)
    check_slots(3, 2)

    assert manager.device == device
    assert manager.punica_wrapper.device == device
    assert hasattr(manager, "supported_lora_modules")
    assert sorted(manager.supported_lora_modules) == [
        "dense1",
        "dense2",
        "lm_head",
        "output",
    ]
210

211

212
@pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
    """Exercise LRUCacheLoRAModelManager (max_loras=2) incl. pinning.

    Mirrors the plain manager test, but here activating an adapter when
    both slots are taken is expected to replace an existing one instead of
    raising. Also covers deactivate_adapter and pin_adapter, including the
    RuntimeError expected once both slots hold pinned adapters.
    """

    def check_slots(first, second):
        # Compare slots 0 and 1 of lora_index_to_id; None means "empty".
        for index, wanted in enumerate((first, second)):
            held = manager.lora_index_to_id[index]
            if wanted is None:
                assert held is None
            else:
                assert held == wanted

    model = dummy_model
    lora1 = create_lora(1,
                        model, ["layer1.dense1", "dense2", "lm_head"],
                        device=device)
    lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"],
                        device=device)
    lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"],
                        device=device)

    lora_config = LoRAConfig(max_lora_rank=8,
                             max_cpu_loras=3,
                             max_loras=2,
                             lora_dtype=DEFAULT_DTYPE)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       lora_config,
                                       device=device)
    assert all(x is None for x in manager.lora_index_to_id)

    # Fill both slots.
    assert manager.add_adapter(lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_adapter(lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(lora2)
    assert manager.activate_adapter(2)
    check_slots(1, 2)

    # A third activation replaces the adapter in slot 0.
    assert not manager.add_adapter(lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(lora3)
    check_slots(1, 2)
    assert manager.activate_adapter(3)
    check_slots(3, 2)

    # Remove adapters 2 and 1, then bring 1 back into the freed slot.
    assert manager.remove_adapter(lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(lora2.id)
    assert manager.remove_adapter(lora1.id)
    assert not manager.remove_adapter(lora1.id)
    assert manager.add_adapter(lora1)
    assert manager.activate_adapter(1)
    check_slots(3, 1)

    # Deactivation empties a slot that the next activation reuses.
    assert manager.add_adapter(lora2)
    assert manager.deactivate_adapter(3)
    check_slots(None, 1)
    assert manager.activate_adapter(2)
    check_slots(2, 1)
    assert manager.activate_adapter(3)
    check_slots(2, 3)

    # With adapter 2 pinned, a new activation must displace adapter 3.
    assert manager.pin_adapter(2)
    check_slots(2, 3)
    assert manager.activate_adapter(1)
    check_slots(2, 1)
    assert manager.deactivate_adapter(2)
    check_slots(None, 1)
    assert manager.activate_adapter(3)
    check_slots(3, 1)

    # Both slots pinned: pinning or activating another adapter must fail.
    assert manager.pin_adapter(3)
    assert manager.pin_adapter(1)
    with pytest.raises(RuntimeError):
        assert manager.pin_adapter(2)
    check_slots(3, 1)
    with pytest.raises(RuntimeError):
        assert manager.activate_adapter(2)

    # Deactivating a pinned adapter makes room again.
    assert manager.deactivate_adapter(3)
    assert manager.pin_adapter(2)
    check_slots(2, 1)

    # Pinning a removed adapter is expected to raise ValueError.
    assert manager.remove_adapter(3)
    with pytest.raises(ValueError):
        assert manager.pin_adapter(3)

    assert manager.punica_wrapper.device == device
    assert manager.device == device

302

303
@pytest.mark.parametrize("device", DEVICES)
def test_lru_lora_model_manager(dist_init, dummy_model, device):
    """Exercise the LRU cache behaviour of LRUCacheLoRAModelManager.

    Only the cache functionality is covered here (capacity of two
    adapters, eviction order, remove_oldest_adapter, pinning); everything
    else is tested in test_lora_model_manager.
    """

    def check_slots(first, second):
        # Compare slots 0 and 1 of lora_index_to_id; None means "empty".
        for index, wanted in enumerate((first, second)):
            held = manager.lora_index_to_id[index]
            if wanted is None:
                assert held is None
            else:
                assert held == wanted

    model = dummy_model
    lora1 = create_lora(1,
                        model, ["layer1.dense1", "dense2", "lm_head"],
                        device=device)
    lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"],
                        device=device)
    lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"],
                        device=device)
    lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"],
                        device=device)

    lora_config = LoRAConfig(max_lora_rank=8,
                             max_cpu_loras=2,
                             max_loras=2,
                             lora_dtype=DEFAULT_DTYPE)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       lora_config,
                                       device=device)

    assert all(x is None for x in manager.lora_index_to_id)

    # Fill the cache up to its capacity.
    assert manager.add_adapter(lora1)
    assert manager.add_adapter(lora2)
    assert manager.activate_adapter(1)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    check_slots(1, 2)

    # Go over capacity: adapters 3 and 4 push out 1 and 2.
    assert manager.add_adapter(lora3)
    assert manager.add_adapter(lora4)
    assert manager.activate_adapter(3)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    check_slots(3, 4)

    # Touch 3 again (reports False because it is already present) so it
    # becomes the most recent entry, then add 2.
    assert not manager.add_adapter(lora3)
    assert not manager.activate_adapter(3)
    assert manager.add_adapter(lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {3, 2}
    check_slots(3, 2)

    # Manual removal.
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {2}
    check_slots(None, 2)

    assert manager.add_adapter(lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(lora4)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    check_slots(3, 4)

    # Drain the cache one oldest entry at a time.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {4}
    check_slots(None, 4)

    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    # Nothing left to remove.
    assert not manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    # Pinning: an unknown adapter is rejected, a cached one is accepted.
    assert manager.add_adapter(lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(lora4)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    with pytest.raises(ValueError):
        assert manager.pin_adapter(1)
    assert manager.pin_adapter(3)

    # A pinned adapter can still be removed manually.
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {4}
    check_slots(None, 4)

    assert manager.add_adapter(lora1)
    assert manager.pin_adapter(1)
    assert manager.add_adapter(lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    check_slots(1, 2)

    # remove_oldest_adapter skips the pinned adapter 1 and drops 2 ...
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {1}
    check_slots(1, None)

    # ... and is expected to raise once only pinned adapters remain.
    with pytest.raises(RuntimeError):
        assert manager.remove_oldest_adapter()

    assert set(manager.list_adapters()) == {1}
    assert manager.punica_wrapper.device == device
    assert manager.device == device
430

431

432
@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                          sql_lora_files, device):
    """Exercise LRUCacheWorkerLoRAManager.set_active_adapters.

    The assertions expect adapters that are not part of a request to stay
    listed until room is needed for new ones, and a request with more
    adapters than max_loras (4) to raise RuntimeError.
    """

    def activate(*lora_ids):
        # One request per id, all backed by the same adapter files; the
        # request name is the id rendered as a string.
        worker_adapter_manager.set_active_adapters(
            [LoRARequest(str(i), i, sql_lora_files) for i in lora_ids],
            mapping)

    def check_slots(*expected_ids):
        # expected_ids[k] must equal the id held by slot k.
        slots = worker_adapter_manager._adapter_manager.lora_index_to_id
        for index, wanted in enumerate(expected_ids):
            assert slots[index] == wanted

    lora_config = LoRAConfig(max_lora_rank=8,
                             max_cpu_loras=4,
                             max_loras=4,
                             lora_dtype=DEFAULT_DTYPE)
    base_vocab_size = (llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
                       lora_config.lora_extra_vocab_size)
    worker_adapter_manager = LRUCacheWorkerLoRAManager(
        4, 2, base_vocab_size, lora_config, device, EMBEDDING_MODULES,
        EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])

    activate(1, 2)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    check_slots(1, 2)

    # Adapter 2 is not requested, yet it is expected to remain cached.
    activate(1, 3, 4)
    assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
    check_slots(1, 2, 3, 4)

    # Adapter 5 takes over the slot previously held by adapter 3.
    activate(1, 2, 5)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    check_slots(1, 2, 5, 4)

    # Repeating an already active adapter changes nothing.
    activate(1, 1, 1)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    check_slots(1, 2, 5, 4)

    activate(6, 7, 8)
    assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
    check_slots(1, 7, 8, 6)

    # Over capacity
    with pytest.raises(RuntimeError):
        activate(10, 11, 12, 13, 14)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)
512

513

514
@pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                sql_lora_files, device):
    """Exercise WorkerLoRAManager.set_active_adapters.

    Unlike the LRU variant, every LoRA that is not named in the request
    should be removed, so list_adapters() must match the request exactly.
    A request with more adapters than max_loras (4) is expected to raise
    RuntimeError.
    """

    def activate(*lora_ids):
        # One request per id, all backed by the same adapter files; the
        # request name is the id rendered as a string.
        worker_adapter_manager.set_active_adapters(
            [LoRARequest(str(i), i, sql_lora_files) for i in lora_ids],
            mapping)

    def check_slots(*expected_ids):
        # expected_ids[k] must equal the id held by slot k.
        slots = worker_adapter_manager._adapter_manager.lora_index_to_id
        for index, wanted in enumerate(expected_ids):
            assert slots[index] == wanted

    lora_config = LoRAConfig(max_lora_rank=8,
                             max_cpu_loras=4,
                             max_loras=4,
                             lora_dtype=DEFAULT_DTYPE)
    base_vocab_size = (llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
                       lora_config.lora_extra_vocab_size)
    worker_adapter_manager = WorkerLoRAManager(
        4, 2, base_vocab_size, lora_config, device, EMBEDDING_MODULES,
        EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])

    activate(1, 2)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    check_slots(1, 2)

    activate(1, 3, 4)
    assert worker_adapter_manager.list_adapters() == {1, 3, 4}
    check_slots(1, 3, 4)

    activate(1, 2, 5)
    assert worker_adapter_manager.list_adapters() == {1, 2, 5}
    check_slots(1, 2, 5)

    # Duplicated requests collapse to one adapter; other slots are emptied.
    activate(1, 1, 1)
    assert worker_adapter_manager.list_adapters() == {1}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None

    activate(6, 7, 8)
    assert worker_adapter_manager.list_adapters() == {6, 7, 8}
    check_slots(8, 6, 7)

    # Over capacity
    with pytest.raises(RuntimeError):
        activate(10, 11, 12, 13, 14)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)

592

593
@pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
    """Check that add_adapter packs gate_proj/up_proj LoRAs together.

    Two adapters are used: one with LoRAs for both projections and one
    with the gate_proj LoRA left out. After add_adapter, each adapter must
    expose a single PackedLoRALayerWeights under "gate_up_proj" whose
    entries equal the original per-projection weights (or None for the
    missing projection).
    """
    model = dummy_model_gate_up
    sub_projections = ["gate_proj", "up_proj"]

    full_lora = create_packed_lora(1,
                                   model,
                                   module_name="gate_up_proj",
                                   replaced_module_names=["gate_proj",
                                                          "up_proj"],
                                   device=device)
    partial_lora = create_packed_lora(
        2,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device,
        empty_replaced_module_name="gate_proj",
    )

    lora_config = LoRAConfig(max_lora_rank=8,
                             max_cpu_loras=2,
                             max_loras=2,
                             lora_dtype=DEFAULT_DTYPE)
    manager = LoRAModelManager(model, 2, 2, 2, lora_config, device=device)
    model = manager.model

    assert isinstance(model.get_submodule("gate_up_proj"),
                      MergedColumnParallelLinearWithLoRA)

    # Snapshot the unpacked adapters before add_adapter rewrites them.
    full_snapshot = full_lora.clone(1)
    partial_snapshot = partial_lora.clone(1)
    assert manager.add_adapter(full_lora)
    assert manager.add_adapter(partial_lora)

    # The per-projection entries are gone from the added adapters ...
    assert full_lora.get_lora("gate_proj") is None
    assert full_lora.get_lora("up_proj") is None
    assert partial_lora.get_lora("up_proj") is None

    # ... and replaced by one packed entry holding both projections.
    packed_full = full_lora.get_lora("gate_up_proj")
    assert packed_full and isinstance(packed_full, PackedLoRALayerWeights)

    for position, projection in enumerate(sub_projections):
        reference = full_snapshot.get_lora(projection)
        torch.testing.assert_close(packed_full.lora_a[position],
                                   reference.lora_a)
        torch.testing.assert_close(packed_full.lora_b[position],
                                   reference.lora_b)

    packed_partial = partial_lora.get_lora("gate_up_proj")
    assert packed_partial and isinstance(packed_partial,
                                         PackedLoRALayerWeights)

    # The skipped gate_proj stays empty; up_proj carries the real weights.
    assert packed_partial.lora_a[0] is None
    assert packed_partial.lora_b[0] is None
    up_reference = partial_snapshot.get_lora("up_proj")
    torch.testing.assert_close(packed_partial.lora_a[1], up_reference.lora_a)
    torch.testing.assert_close(packed_partial.lora_b[1], up_reference.lora_b)