import os
from typing import Dict, List, Optional

import pytest
import torch
from safetensors.torch import load_file
from torch import nn

from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager)
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                      WorkerLoRAManager)
from vllm.model_executor.layers.linear import RowParallelLinear

# Maps embedding module names to the key under which their extra-vocab
# embeddings are stored in ``new_embeddings.safetensors``.
EMBEDDING_MODULES = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}

# Embedding modules passed as ``embedding_padding_modules`` to the managers.
EMBEDDING_PADDING_MODULES = ["lm_head"]

# Devices every device-parametrized test runs on: only ``cuda:0`` on a
# single-GPU host, otherwise ``cuda:0`` and ``cuda:1``.
# NOTE(review): this also yields two entries when no GPU is visible, so the
# tests fail rather than skip on CPU-only hosts.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
    """Check that ``LoRAModel.from_lora_tensors`` builds well-formed layers.

    Loads the SQL adapter weights and extra embeddings from
    ``sql_lora_files``. For every resulting layer, it checks rank/alpha, the
    target device, and that the A/B shapes agree on the rank dimension. It
    also checks that only embedding modules carry an embeddings tensor.
    """
    rank = 8
    lora_alpha = 16
    tensors = load_file(
        os.path.join(sql_lora_files, "adapter_model.safetensors"))
    new_embeddings = load_file(
        os.path.join(sql_lora_files, "new_embeddings.safetensors"))

    lora_model = LoRAModel.from_lora_tensors(
        1,
        rank,
        lora_alpha,
        tensors,
        device,
        embeddings=new_embeddings,
        embedding_modules=EMBEDDING_MODULES,
        embedding_padding_modules=EMBEDDING_PADDING_MODULES)

    # Guard against the per-layer checks below passing vacuously.
    assert lora_model.loras

    for module_name, lora in lora_model.loras.items():
        assert lora.module_name == module_name
        assert lora.rank == rank
        assert lora.lora_alpha == lora_alpha
        assert lora.lora_a is not None
        assert lora.lora_b is not None
        assert lora.lora_a.device == torch.device(device)
        assert lora.lora_b.device == torch.device(device)
        # lora_a is (in, rank) and lora_b is (rank, out): inner dims match.
        assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
                ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
        assert lora.lora_a.shape[1] == rank

        # Only modules whose name contains an embedding module key should
        # carry the matching extra-vocab embeddings tensor.
        embeddings_module = next(
            (k for k in EMBEDDING_MODULES if k in module_name), None)
        if embeddings_module is not None:
            assert torch.equal(
                lora.embeddings_tensor,
                new_embeddings[EMBEDDING_MODULES[embeddings_module]].to(
                    device=lora.embeddings_tensor.device))
        else:
            assert lora.embeddings_tensor is None


def create_lora(lora_id: int,
                model: nn.Module,
                sub_modules: List[str],
                device: torch.device,
                *,
                rank: int = 8,
                lora_alpha: int = 16) -> LoRAModel:
    """Build a ``LoRAModel`` with random weights for ``sub_modules``.

    Args:
        lora_id: Id given to the returned ``LoRAModel``.
        model: Model whose sub-modules' ``weight`` shapes size the LoRA.
        sub_modules: Dotted names of the sub-modules to attach LoRAs to.
        device: Device the random A/B tensors are allocated on (callers in
            this file pass device strings such as ``"cuda:0"``).
        rank: LoRA rank used for the A/B shapes and the model.
        lora_alpha: LoRA alpha stored on every layer.

    Returns:
        A ``LoRAModel`` with one ``LoRALayerWeights`` per sub-module.
        ``lora_a`` is ``(w.shape[1], rank)`` and ``lora_b`` is
        ``(rank, w.shape[0])``, where ``w`` is that sub-module's weight.
    """
    loras: Dict[str, LoRALayerWeights] = {}
    for name in sub_modules:
        weight = model.get_submodule(name).weight
        loras[name] = LoRALayerWeights(
            name,
            rank,
            lora_alpha,
            torch.rand([weight.shape[1], rank], device=device),
            torch.rand([rank, weight.shape[0]], device=device),
        )
    return LoRAModel(lora_id, rank, loras)


def create_packed_lora(
    lora_id: int,
    model: nn.Module,
    module_name: str,
    replaced_module_names: List[str],
    device: torch.device,
    empty_replaced_module_name: Optional[str] = None,
    *,
    rank: int = 8,
    lora_alpha: int = 16,
) -> LoRAModel:
    """Build a ``LoRAModel`` for the sub-modules packed into ``module_name``.

    Args:
        lora_id: Id given to the returned ``LoRAModel``.
        model: Model containing the packed module ``module_name``.
        module_name: Dotted name of the packed module, e.g. ``gate_up_proj``.
        replaced_module_names: Names of the sub-modules packed into
            ``module_name``; each gets its own LoRA layer.
        device: Device the random A/B tensors are allocated on.
        empty_replaced_module_name: Optional sub-module to skip. No weights
            are created for it, so the LoRA for that slice is absent.
        rank: LoRA rank used for the A/B shapes and the model.
        lora_alpha: LoRA alpha stored on every layer.

    Returns:
        A ``LoRAModel`` keyed by sub-module name. ``lora_a`` is
        ``(w.shape[1], rank)``. ``lora_b`` is
        ``(rank, w.shape[0] // len(replaced_module_names))``, i.e. each
        sub-module owns an equal slice of the packed output dimension.
    """
    weight = model.get_submodule(module_name).weight
    loras: Dict[str, LoRALayerWeights] = {}
    for replaced_module_name in replaced_module_names:
        if replaced_module_name == empty_replaced_module_name:
            continue
        loras[replaced_module_name] = LoRALayerWeights(
            replaced_module_name,
            rank,
            lora_alpha,
            torch.rand([weight.shape[1], rank], device=device),
            # The split is by the full list length, even when one entry is
            # skipped, so every present slice keeps its packed width.
            torch.rand([rank, weight.shape[0] // len(replaced_module_names)],
                       device=device),
        )
    return LoRAModel(lora_id, rank, loras)


def test_replace_submodules(dist_init, dummy_model):
    """Check which sub-modules ``LoRAModelManager`` wraps with LoRA layers.

    The test expects the ``"dense1"`` entry to wrap both the top-level
    ``dense1`` and the nested ``layer1.dense1``. It expects
    ``"layer1.dense2"`` to wrap only the nested ``dense2``, leaving the
    top-level ``dense2`` a plain ``RowParallelLinear``.
    """
    model = dummy_model
    model.supported_lora_modules = ["dense1", "layer1.dense2"]
    model.packed_modules_mapping = {}
    manager = LoRAModelManager(
        model, 1, 1, 1,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
        torch.device("cuda"))
    model = manager.model

    assert isinstance(model.get_submodule("dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("dense2"), RowParallelLinear)
    assert isinstance(model.get_submodule("layer1.dense2"),
                      RowParallelLinearWithLoRA)


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
    """Exercise add/activate/remove on the non-LRU ``LoRAModelManager``.

    The manager is configured with ``max_loras=2`` and ``max_cpu_loras=3``.
    Re-adding or re-activating a known adapter returns False. Activating a
    third adapter while both ``lora_index_to_id`` slots are taken raises
    ``ValueError``. Removing an adapter frees its slot.
    """
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=3,
                                          max_loras=2),
                               device=device)
    assert all(x is None for x in manager.lora_index_to_id)

    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1

    # Adding / activating an adapter that is already present is a no-op.
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Both slots are occupied: this manager refuses instead of evicting.
    with pytest.raises(ValueError):
        manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)

    # Re-adding an adapter does not activate it.
    assert manager.add_adapter(model_lora1)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] is None

    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] is None
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    assert manager.device == device
    assert manager.punica_wrapper.device == device

@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
    """Exercise slot eviction and pinning on ``LRUCacheLoRAModelManager``.

    The manager is configured with ``max_loras=2`` and ``max_cpu_loras=3``.
    Unlike ``LoRAModelManager``, activating a third adapter evicts the
    least-recently-activated one instead of raising. Pinned adapters are
    never evicted.
    """
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=3,
                                                  max_loras=2),
                                       device=device)
    assert all(x is None for x in manager.lora_index_to_id)

    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1

    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Slots are full: adapter 3 takes over adapter 1's slot (oldest).
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1

    assert manager.add_adapter(model_lora2)
    assert manager.deactivate_adapter(3)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1

    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1

    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3

    # Once 2 is pinned, activating 1 must evict 3 instead of 2.
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1

    assert manager.deactivate_adapter(2)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1

    # With every slot pinned, neither pinning nor activating another
    # adapter can find a slot to evict.
    assert manager.pin_adapter(3)
    assert manager.pin_adapter(1)
    with pytest.raises(RuntimeError):
        manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    with pytest.raises(RuntimeError):
        manager.activate_adapter(2)

    assert manager.deactivate_adapter(3)
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1

    # Pinning an adapter that has been removed is an error.
    assert manager.remove_adapter(3)
    with pytest.raises(ValueError):
        manager.pin_adapter(3)

    assert manager.punica_wrapper.device == device
    assert manager.device == device

@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lru_lora_model_manager(dist_init, dummy_model, device):
    """Exercise the CPU-side LRU cache of ``LRUCacheLoRAModelManager``.

    This tests just the LRU cache functionality; everything else is tested
    in ``test_lora_model_manager``. The manager is configured with
    ``max_cpu_loras=2`` and ``max_loras=2``.
    """
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora4 = create_lora(4,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=2,
                                                  max_loras=2),
                                       device=device)

    assert all(x is None for x in manager.lora_index_to_id)

    # Add up to capacity
    assert manager.add_adapter(model_lora1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(1)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Add over capacity: 1 and 2 are evicted.
    assert manager.add_adapter(model_lora3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(3)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Re-adding / re-activating 3 returns False because it is already
    # present, but it refreshes 3's recency, so adding 2 evicts 4, not 3.
    assert not manager.add_adapter(model_lora3)
    assert not manager.activate_adapter(3)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {3, 2}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {2}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 2

    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    # Nothing left to remove.
    assert not manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    # pinning
    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    # Adapter 1 is not registered, so it cannot be pinned.
    with pytest.raises(ValueError):
        manager.pin_adapter(1)
    assert manager.pin_adapter(3)
    # Remove manually: explicit removal still works on a pinned adapter.
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.add_adapter(model_lora1)
    assert manager.pin_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # 1 is older but pinned, so remove_oldest_adapter removes 2 instead.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {1}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] is None

    # Only a pinned adapter is left, so nothing can be removed.
    with pytest.raises(RuntimeError):
        manager.remove_oldest_adapter()

    assert set(manager.list_adapters()) == {1}
    assert manager.punica_wrapper.device == device
    assert manager.device == device

@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                          sql_lora_files, device):
    """Check ``LRUCacheWorkerLoRAManager.set_active_adapters`` caching.

    The test expects adapters loaded by earlier requests to stay listed
    (up to ``max_cpu_loras=4``) and new adapters to take the least-recently
    used slots. A single request with more distinct adapters than
    ``max_loras=4`` raises ``RuntimeError``.
    """
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = LRUCacheWorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, device,
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    # 2 is not requested but stays cached; 3 and 4 fill the free slots.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Cache is full: 3 is evicted to make room for 5.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Duplicate requests for a cached adapter change nothing.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 6

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)

@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                sql_lora_files, device):
    """Check that ``WorkerLoRAManager`` keeps only the requested adapters.

    Unlike the LRU variant, every adapter not named in the latest
    ``set_active_adapters`` call is removed and its slot freed. A request
    with more distinct adapters than ``max_loras=4`` raises ``RuntimeError``.
    """
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = WorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, device,
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    # 2 is dropped and 3 reuses its slot.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4

    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5

    # Duplicates collapse to one adapter; every other slot is cleared.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None

    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 7

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)

@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
    """Check that per-sub-module LoRAs are packed for a merged layer.

    ``gate_proj`` and ``up_proj`` are packed into ``gate_up_proj``. After
    ``add_adapter`` the model's ``gate_up_proj`` entry must be a
    ``PackedLoRALayerWeights``. Its i-th A/B tensors must equal the
    i-th sub-module's LoRA, and a skipped sub-module becomes a ``None`` slot.
    """
    model = dummy_model_gate_up
    model.supported_lora_modules = ["gate_up_proj"]
    model.packed_modules_mapping = {
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    model_lora = create_packed_lora(
        1,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device)
    # Same packing, but with no LoRA weights for gate_proj.
    model_lora1 = create_packed_lora(
        2,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device,
        empty_replaced_module_name="gate_proj",
    )

    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=2,
                                          max_loras=2),
                               device=device)
    model = manager.model

    assert isinstance(model.get_submodule("gate_up_proj"),
                      MergedColumnParallelLinearWithLoRA)
    assert manager.add_adapter(model_lora)
    assert manager.add_adapter(model_lora1)

    packed_lora = model_lora.get_lora("gate_up_proj")
    assert isinstance(packed_lora, PackedLoRALayerWeights)

    # Slot i of the packed layer holds the LoRA of the i-th sub-module.
    for idx, sub_name in enumerate(("gate_proj", "up_proj")):
        sub_lora = model_lora.get_lora(sub_name)
        torch.testing.assert_close(packed_lora.lora_a[idx], sub_lora.lora_a)
        torch.testing.assert_close(packed_lora.lora_b[idx], sub_lora.lora_b)

    packed_lora1 = model_lora1.get_lora("gate_up_proj")
    assert isinstance(packed_lora1, PackedLoRALayerWeights)

    # gate_proj had no weights, so its slot is empty.
    assert packed_lora1.lora_a[0] is None
    assert packed_lora1.lora_b[0] is None
    torch.testing.assert_close(packed_lora1.lora_a[1],
                               model_lora1.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora1.lora_b[1],
                               model_lora1.get_lora("up_proj").lora_b)