test_lora_manager.py 25.4 KB
Newer Older
1
import os
2
from typing import Dict, List
3
4
5
6
7
8
9
10

import pytest
import torch
from safetensors.torch import load_file
from torch import nn

from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
11
12
                              MergedColumnParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA)
13
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
14
15
from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager)
16
from vllm.lora.peft_helper import PEFTHelper
17
18
19
20
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                      WorkerLoRAManager)
from vllm.model_executor.layers.linear import RowParallelLinear
21
from vllm.platforms import current_platform
22

Terry's avatar
Terry committed
23
24
25
26
27
28
29
# Maps the name of an embedding module to the key under which its tensor is
# looked up in the ``new_embeddings.safetensors`` dict (see
# ``test_from_lora_tensors``).
EMBEDDING_MODULES = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}

# Forwarded as ``embedding_padding_modules`` to the LoRA loaders/managers.
EMBEDDING_PADDING_MODULES = ["lm_head"]

# Devices every parametrized test runs on: "cuda:0" alone when exactly one
# CUDA device is reported, "cuda:0" and "cuda:1" for any other count, and
# plain "cpu" when the current platform is not CUDA-alike.
DEVICES = ([
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] if current_platform.is_cuda_alike() else ["cpu"])
33

34

35
@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
    """Check the layers built by ``LoRAModel.from_lora_tensors``.

    Loads ``adapter_model.safetensors`` and ``new_embeddings.safetensors``
    from the directory given by the ``sql_lora_files`` fixture, builds a
    ``LoRAModel`` on ``device`` and verifies, for every resulting layer:
    its name, rank (8), alpha (16), that both weight matrices exist on the
    requested device with compatible shapes, and that the embeddings tensor
    is set only for modules whose name contains a key of
    ``EMBEDDING_MODULES``.
    """
    tensors = load_file(
        os.path.join(sql_lora_files, "adapter_model.safetensors"))
    new_embeddings = load_file(
        os.path.join(sql_lora_files, "new_embeddings.safetensors"))

    peft_helper = PEFTHelper.from_local_dir(sql_lora_files,
                                            max_position_embeddings=4096)
    lora_model = LoRAModel.from_lora_tensors(
        1,
        tensors,
        peft_helper=peft_helper,
        device=device,
        embeddings=new_embeddings,
        embedding_modules=EMBEDDING_MODULES,
        embedding_padding_modules=EMBEDDING_PADDING_MODULES)

    expected_device = torch.device(device)
    for module_name, lora in lora_model.loras.items():
        assert lora.module_name == module_name
        assert lora.rank == 8
        assert lora.lora_alpha == 16
        assert lora.lora_a is not None
        assert lora.lora_b is not None
        assert lora.lora_a.device == expected_device
        assert lora.lora_b.device == expected_device
        # A's second dimension must match B's first one, and equal the rank.
        assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
                ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
        assert lora.lora_a.shape[1] == 8

        # First EMBEDDING_MODULES key occurring in the module name, if any.
        embeddings_module = next(
            (k for k in EMBEDDING_MODULES if k in module_name), None)
        if embeddings_module:
            # Compare on the device the loaded tensor lives on.
            assert torch.equal(
                lora.embeddings_tensor,
                new_embeddings[EMBEDDING_MODULES[embeddings_module]].to(
                    device=lora.embeddings_tensor.device))
        else:
            assert lora.embeddings_tensor is None


74
75
def create_lora(lora_id: int, model: nn.Module, sub_modules: List[str],
                device: torch.device) -> LoRAModel:
    """Build a ``LoRAModel`` with random rank-8 weights for ``sub_modules``.

    Args:
        lora_id: Identifier given to the returned ``LoRAModel``.
        model: Module whose submodules' ``weight`` shapes size the LoRA
            matrices.
        sub_modules: Dotted submodule names (resolved through
            ``model.get_submodule``) that receive a LoRA layer.
        device: Passed to ``torch.rand`` for both matrices; the tests in
            this file pass device strings such as the entries of
            ``DEVICES``.

    Returns:
        A ``LoRAModel`` of rank 8 holding one ``LoRALayerWeights`` (rank 8,
        alpha 16) per entry of ``sub_modules``.
    """
    rank = 8
    lora_alpha = 16
    loras: Dict[str, LoRALayerWeights] = {}
    for name in sub_modules:
        w = model.get_submodule(name).weight
        loras[name] = LoRALayerWeights(
            name,
            rank,
            lora_alpha,
            # lora_a: [w.shape[1], rank]; lora_b: [rank, w.shape[0]].
            torch.rand([w.shape[1], rank], device=device),
            torch.rand([rank, w.shape[0]], device=device),
        )
    return LoRAModel(lora_id, rank, loras)


def create_packed_lora(
    lora_id: int,
    model: nn.Module,
    module_name,
    replaced_module_names,
    device: torch.device,
    empty_replaced_module_name=None,
) -> LoRAModel:
    """Build a ``LoRAModel`` with one random LoRA per replaced sub-module.

    Args:
        lora_id: Identifier given to the returned ``LoRAModel``.
        model: Module that contains ``module_name``.
        module_name: Name of the packed module; its ``weight`` shape sizes
            the LoRA matrices.
        replaced_module_names: Names of the sub-modules packed into
            ``module_name``. Each one gets an equal share
            (``w.shape[0] // len(replaced_module_names)``) of the output
            dimension.
        device: Passed to ``torch.rand`` for both matrices.
        empty_replaced_module_name: Optional entry of
            ``replaced_module_names`` to skip, so that no LoRA is created
            for it.

    Returns:
        A ``LoRAModel`` of rank 8 keyed by replaced sub-module name.
    """
    rank = 8
    lora_alpha = 16
    w = model.get_submodule(module_name).weight
    # Output width of a single replaced sub-module (loop-invariant).
    out_per_module = w.shape[0] // len(replaced_module_names)
    loras: Dict[str, LoRALayerWeights] = {}
    for replaced_module_name in replaced_module_names:
        if replaced_module_name == empty_replaced_module_name:
            continue
        loras[replaced_module_name] = LoRALayerWeights(
            replaced_module_name,
            rank,
            lora_alpha,
            torch.rand([w.shape[1], rank], device=device),
            torch.rand([rank, out_per_module], device=device),
        )
    return LoRAModel(lora_id, rank, loras)


def test_replace_submodules(dist_init, dummy_model):
    """Check which submodules ``LoRAModelManager`` wraps with LoRA layers.

    With ``supported_lora_modules = ["dense1", "layer1.dense2"]`` the test
    expects ``dense1`` and ``layer1.dense1`` to become
    ``ColumnParallelLinearWithLoRA``, ``layer1.dense2`` to become
    ``RowParallelLinearWithLoRA``, and the top-level ``dense2`` to remain a
    ``RowParallelLinear`` instance.
    """
    model = dummy_model
    model.supported_lora_modules = ["dense1", "layer1.dense2"]
    model.packed_modules_mapping = {}
    manager = LoRAModelManager(
        model, 1, 1, 1,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
        torch.device(DEVICES[0]))
    # Inspect the model held by the manager from here on.
    model = manager.model

    assert isinstance(model.get_submodule("dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("dense2"), RowParallelLinear)
    assert isinstance(model.get_submodule("layer1.dense2"),
                      RowParallelLinearWithLoRA)


132
@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
    """Exercise add/activate/remove on ``LoRAModelManager``.

    The manager is configured with ``max_loras=2`` and ``max_cpu_loras=3``.
    The test follows ``manager.lora_index_to_id`` after every step to check
    which adapter id occupies which index, that repeated add/activate/remove
    calls return a falsy value, and that activating a third adapter while
    both indices are occupied raises ``ValueError``.
    """
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=3,
                                          max_loras=2),
                               device=device)
    assert all(x is None for x in manager.lora_index_to_id)

    # First adapter lands at index 0.
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1

    # Re-adding / re-activating is reported as a no-op.
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    # Adding a third adapter does not touch the active indices ...
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    # ... and activating it with both indices occupied must raise.
    with pytest.raises(ValueError):
        assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Removing an adapter frees its index; a second removal returns falsy.
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)

    # Adding alone does not activate anything.
    assert manager.add_adapter(model_lora1)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] is None
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] is None
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    assert manager.device == device
    assert manager.punica_wrapper.device == device
191

192

193
@pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
    """Exercise activate/deactivate/pin on ``LRUCacheLoRAModelManager``.

    The manager is configured with ``max_loras=2`` and ``max_cpu_loras=3``.
    Unlike ``LoRAModelManager``, activating a third adapter is expected to
    succeed by replacing an active one. The test follows
    ``manager.lora_index_to_id`` after every step and checks that pinned
    adapters are not replaced, that pinning or activating with every index
    pinned raises ``RuntimeError``, and that pinning a removed adapter
    raises ``ValueError``.
    """
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=3,
                                                  max_loras=2),
                                       device=device)
    assert all(x is None for x in manager.lora_index_to_id)

    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    # Re-adding / re-activating is reported as a no-op.
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    # Activating a third adapter replaces adapter 1 at index 0.
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1

    # Deactivating frees the index without removing the adapter.
    assert manager.add_adapter(model_lora2)
    assert manager.deactivate_adapter(3)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3

    # Once adapter 2 is pinned, activating 1 replaces 3 instead of 2.
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.deactivate_adapter(2)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1

    # With both active adapters pinned, pinning or activating another one
    # must raise.
    assert manager.pin_adapter(3)
    assert manager.pin_adapter(1)
    with pytest.raises(RuntimeError):
        assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    with pytest.raises(RuntimeError):
        assert manager.activate_adapter(2)

    assert manager.deactivate_adapter(3)
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    # A removed adapter can no longer be pinned.
    assert manager.remove_adapter(3)
    with pytest.raises(ValueError):
        assert manager.pin_adapter(3)

    assert manager.punica_wrapper.device == device
    assert manager.device == device

284

285
@pytest.mark.parametrize("device", DEVICES)
def test_lru_lora_model_manager(dist_init, dummy_model, device):
    """Exercise the cache behaviour of ``LRUCacheLoRAModelManager``.

    The manager is configured with ``max_loras=2`` and ``max_cpu_loras=2``.
    The test checks ``list_adapters()`` and ``lora_index_to_id`` while
    adding adapters up to and beyond capacity, removing them manually and
    through ``remove_oldest_adapter()``, and pinning them.
    """
    # This tests just the LRU cache functionality, everything else is
    # tested in test_lora_model_manager
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora4 = create_lora(4,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=2,
                                                  max_loras=2),
                                       device=device)

    assert all(x is None for x in manager.lora_index_to_id)

    # Add up to capacity
    assert manager.add_adapter(model_lora1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(1)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Add over capacity
    assert manager.add_adapter(model_lora3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(3)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Adding/activating 3 again returns a falsy value because it is already
    # present; adding 2 afterwards is expected to replace 4, not 3.
    assert not manager.add_adapter(model_lora3)
    assert not manager.activate_adapter(3)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {3, 2}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {2}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 2

    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Drain the cache through remove_oldest_adapter(); it returns a falsy
    # value once nothing is left.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    assert not manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    # pinning
    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    # Adapter 1 is not registered at this point, so pinning it must raise.
    with pytest.raises(ValueError):
        assert manager.pin_adapter(1)
    assert manager.pin_adapter(3)
    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.add_adapter(model_lora1)
    assert manager.pin_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # The pinned adapter 1 is skipped; adapter 2 is removed instead.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {1}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] is None

    # Only the pinned adapter is left, so nothing can be removed.
    with pytest.raises(RuntimeError):
        assert manager.remove_oldest_adapter()

    assert set(manager.list_adapters()) == {1}
    assert manager.punica_wrapper.device == device
    assert manager.device == device
413

414

415
@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                          sql_lora_files, device):
    """Exercise ``set_active_adapters`` on ``LRUCacheWorkerLoRAManager``.

    With ``max_loras=4`` and ``max_cpu_loras=4`` the test expects adapters
    that are not part of a later request to stay loaded until their place is
    needed, and a request for five adapters to raise ``RuntimeError``.
    Every request loads from the same ``sql_lora_files`` path under a
    different adapter id.
    """
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = LRUCacheWorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, device,
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    # Adapter 2 is not requested but stays loaded; 3 and 4 take free indices.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # All four indices are in use: adapter 5 replaces adapter 3 at index 2.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Requesting the same adapter repeatedly changes nothing.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Three new adapters replace 2, 4 and 5; adapter 1 is kept.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 6

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)
492

493

494
@pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                sql_lora_files, device):
    """Exercise ``set_active_adapters`` on ``WorkerLoRAManager``.

    With ``max_loras=4`` and ``max_cpu_loras=4`` the test expects each call
    to leave exactly the requested adapters loaded, and a request for five
    adapters to raise ``RuntimeError``. Every request loads from the same
    ``sql_lora_files`` path under a different adapter id.
    """
    # Should remove every LoRA not specified in the request.
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = WorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, device,
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    # Adapter 2 is dropped; 3 and 4 are loaded.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4

    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5

    # Duplicate requests for adapter 1 leave only adapter 1 loaded.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None

    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 7

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)

569

570
@pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
    """Check that per-module LoRAs are packed for ``gate_up_proj``.

    Two adapters are created for the packed module ``gate_up_proj``
    (``gate_proj`` + ``up_proj``): one with weights for both sub-modules and
    one without ``gate_proj``. After ``add_adapter`` each adapter is
    expected to expose a ``PackedLoRALayerWeights`` under ``gate_up_proj``
    whose index 0 / 1 entries match the ``gate_proj`` / ``up_proj`` weights,
    with ``None`` at index 0 for the adapter lacking ``gate_proj``.
    """
    model = dummy_model_gate_up
    model.supported_lora_modules = ["gate_up_proj"]
    model.packed_modules_mapping = {
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    model_lora = create_packed_lora(
        1,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device)
    # Same layout, but no LoRA weights for gate_proj.
    model_lora1 = create_packed_lora(
        2,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device,
        empty_replaced_module_name="gate_proj",
    )

    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=2,
                                          max_loras=2),
                               device=device)
    model = manager.model

    assert isinstance(model.get_submodule("gate_up_proj"),
                      MergedColumnParallelLinearWithLoRA)
    assert manager.add_adapter(model_lora)
    assert manager.add_adapter(model_lora1)

    packed_lora = model_lora.get_lora("gate_up_proj")
    assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)

    torch.testing.assert_close(packed_lora.lora_a[0],
                               model_lora.get_lora("gate_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[0],
                               model_lora.get_lora("gate_proj").lora_b)
    torch.testing.assert_close(packed_lora.lora_a[1],
                               model_lora.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[1],
                               model_lora.get_lora("up_proj").lora_b)

    packed_lora1 = model_lora1.get_lora("gate_up_proj")
    assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)

    # The skipped gate_proj leaves an empty first entry in the packed weights.
    assert packed_lora1.lora_a[0] is None
    assert packed_lora1.lora_b[0] is None
    torch.testing.assert_close(packed_lora1.lora_a[1],
                               model_lora1.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora1.lora_b[1],
                               model_lora1.get_lora("up_proj").lora_b)