test_lora_manager.py 25.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import os
4
from typing import Dict, List
5
6
7
8
9
10
11
12

import pytest
import torch
from safetensors.torch import load_file
from torch import nn

from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
13
14
                              MergedColumnParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA)
15
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
16
17
from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager)
18
from vllm.lora.peft_helper import PEFTHelper
19
20
21
22
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                      WorkerLoRAManager)
from vllm.model_executor.layers.linear import RowParallelLinear
23
from vllm.platforms import current_platform
24

Terry's avatar
Terry committed
25
26
27
28
29
30
31
# Maps a module-name fragment to the key under which that module's tensor
# is looked up in the adapter's extra-embeddings file.
EMBEDDING_MODULES = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}

# Passed as ``embedding_padding_modules`` wherever adapters are loaded.
EMBEDDING_PADDING_MODULES = ["lm_head"]

# Device strings the parametrized tests run on. On CUDA-alike platforms
# this is a single entry when exactly one device is reported and two
# entries otherwise; on any other platform only the CPU is used.
if current_platform.is_cuda_alike():
    DEVICES = [
        f"cuda:{i}"
        for i in range(1 if torch.cuda.device_count() == 1 else 2)
    ]
else:
    DEVICES = ["cpu"]
35

36

37
@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
    """Build a LoRAModel from on-disk tensors and validate each layer.

    The adapter weights and the extra embeddings are read from the
    ``sql_lora_files`` directory, the model is constructed on ``device``
    and every resulting layer is checked for its name, rank, alpha,
    device placement, shapes and embeddings tensor.
    """
    tensors = load_file(
        os.path.join(sql_lora_files, "adapter_model.safetensors"))
    new_embeddings = load_file(
        os.path.join(sql_lora_files, "new_embeddings.safetensors"))

    helper = PEFTHelper.from_local_dir(sql_lora_files,
                                       max_position_embeddings=4096)
    lora_model = LoRAModel.from_lora_tensors(
        1,
        tensors,
        peft_helper=helper,
        device=device,
        embeddings=new_embeddings,
        embedding_modules=EMBEDDING_MODULES,
        embedding_padding_modules=EMBEDDING_PADDING_MODULES,
    )

    # NOTE: the loop variable must stay named ``lora`` because the
    # ``=``-style f-string below embeds the expression text in its message.
    for module_name, lora in lora_model.loras.items():
        assert lora.module_name == module_name
        assert lora.rank == 8
        assert lora.lora_alpha == 16
        assert lora.lora_a is not None
        assert lora.lora_b is not None
        assert lora.lora_a.device == torch.device(device)
        assert lora.lora_b.device == torch.device(device)
        # The two factors must agree on their shared dimension, which is
        # expected to be 8 here.
        assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
                ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
        assert lora.lora_a.shape[1] == 8

        # Find the first embedding-module key contained in the layer name.
        matched_key = None
        for key in EMBEDDING_MODULES:
            if key in module_name:
                matched_key = key
                break

        if not matched_key:
            # Layers that are not embedding modules carry no embeddings.
            assert lora.embeddings_tensor is None
            continue

        expected = new_embeddings[EMBEDDING_MODULES[matched_key]].to(
            device=lora.embeddings_tensor.device)
        assert torch.equal(lora.embeddings_tensor, expected)


76
77
def create_lora(lora_id: int, model: nn.Module, sub_modules: List[str],
                device: torch.device) -> LoRAModel:
    """Create a LoRAModel holding random weights for ``sub_modules``.

    Args:
        lora_id: Identifier forwarded to the ``LoRAModel`` constructor.
        model: Module whose sub-module ``weight`` shapes size the tensors.
        sub_modules: Names resolved through ``model.get_submodule``.
        device: Forwarded to ``torch.rand``; the tests in this file pass
            the device strings from ``DEVICES``.

    Returns:
        A ``LoRAModel`` with one ``LoRALayerWeights`` entry per name.
    """
    # Positional values handed to LoRALayerWeights / LoRAModel; presumably
    # the rank and lora_alpha (test_from_lora_tensors asserts the same
    # numbers for ``rank`` and ``lora_alpha``) -- TODO confirm.
    rank, alpha = 8, 16

    def random_layer(name: str) -> LoRALayerWeights:
        weight = model.get_submodule(name).weight
        # Draw the first factor before the second so the random stream is
        # consumed in a fixed order.
        first = torch.rand([weight.shape[1], rank], device=device)
        second = torch.rand([rank, weight.shape[0]], device=device)
        return LoRALayerWeights(name, rank, alpha, first, second)

    loras: Dict[str, LoRALayerWeights] = {
        name: random_layer(name)
        for name in sub_modules
    }
    return LoRAModel(lora_id, rank, loras)


def create_packed_lora(
    lora_id: int,
    model: nn.Module,
    module_name,
    replaced_module_names,
    device: torch.device,
    empty_replaced_module_name=None,
) -> LoRAModel:
    """Create random per-sub-module LoRA weights for one packed module.

    Args:
        lora_id: Identifier forwarded to the ``LoRAModel`` constructor.
        model: Module that owns the packed sub-module.
        module_name: Name of the packed module whose ``weight`` shape
            sizes the generated tensors.
        replaced_module_names: Names to create entries for; dimension 0
            of the packed weight is divided evenly between them.
        device: Forwarded to ``torch.rand``.
        empty_replaced_module_name: Optional name to leave out of the
            result entirely.

    Returns:
        A ``LoRAModel`` keyed by the replaced module names, minus the
        skipped one.
    """
    weight = model.get_submodule(module_name).weight
    # Positional values handed to LoRALayerWeights / LoRAModel; presumably
    # the rank and lora_alpha -- TODO confirm against their signatures.
    rank, alpha = 8, 16

    loras: Dict[str, LoRALayerWeights] = {}
    for sub_name in replaced_module_names:
        if sub_name != empty_replaced_module_name:
            # Each sub-module receives an equal share of dimension 0.
            share = weight.shape[0] // len(replaced_module_names)
            loras[sub_name] = LoRALayerWeights(
                sub_name,
                rank,
                alpha,
                torch.rand([weight.shape[1], rank], device=device),
                torch.rand([rank, share], device=device),
            )
    return LoRAModel(lora_id, rank, loras)


def test_replace_submodules(dist_init, dummy_model):
    """Check which sub-modules the manager swaps for LoRA wrappers.

    With ``dense1`` and ``layer1.dense2`` declared as supported, both
    ``dense1`` layers and ``layer1.dense2`` are expected to be LoRA
    wrapper types, while the top-level ``dense2`` stays a plain
    ``RowParallelLinear``.
    """
    dummy_model.supported_lora_modules = ["dense1", "layer1.dense2"]
    dummy_model.packed_modules_mapping = {}
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8)
    manager = LoRAModelManager(dummy_model, 1, 1, 1, lora_config,
                               torch.device(DEVICES[0]))
    wrapped = manager.model

    expected_types = (
        ("dense1", ColumnParallelLinearWithLoRA),
        ("layer1.dense1", ColumnParallelLinearWithLoRA),
        ("dense2", RowParallelLinear),
        ("layer1.dense2", RowParallelLinearWithLoRA),
    )
    for path, layer_type in expected_types:
        assert isinstance(wrapped.get_submodule(path), layer_type)


134
@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
    """Walk LoRAModelManager through add / activate / remove calls.

    Three adapters are registered against a manager configured with
    ``max_loras=2``. After each step the return values and the contents
    of ``lora_index_to_id`` are checked, including the ``ValueError``
    expected when a third adapter is activated while both slots are used.
    """
    dummy_model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    dummy_model.packed_modules_mapping = {}

    lora1 = create_lora(
        1, dummy_model, ["layer1.dense1", "dense2", "lm_head"], device=device)
    lora2 = create_lora(
        2, dummy_model, ["dense1", "dense2", "lm_head"], device=device)
    lora3 = create_lora(
        3, dummy_model, ["dense1", "dense2", "lm_head"], device=device)

    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2)
    manager = LoRAModelManager(dummy_model, 2, 2, 2, lora_config,
                               device=device)

    def assert_slots(*expected):
        # Compare the leading slots in order, re-reading the table each
        # time; ``None`` entries are compared by identity.
        for index, adapter_id in enumerate(expected):
            occupant = manager.lora_index_to_id[index]
            if adapter_id is None:
                assert occupant is None
            else:
                assert occupant == adapter_id

    assert all(slot is None for slot in manager.lora_index_to_id)

    # First adapter: added and activated once; repeats report False.
    assert manager.add_adapter(lora1)
    assert manager.activate_adapter(1)
    assert_slots(1)
    assert not manager.add_adapter(lora1)
    assert not manager.activate_adapter(1)

    # Second adapter takes the remaining slot.
    assert manager.add_adapter(lora2)
    assert manager.activate_adapter(2)
    assert_slots(1, 2)
    assert not manager.add_adapter(lora2)
    assert not manager.activate_adapter(2)

    # Third adapter can be added but not activated: no slot is free.
    assert manager.add_adapter(lora3)
    assert_slots(1, 2)
    with pytest.raises(ValueError):
        assert manager.activate_adapter(3)
    assert_slots(1, 2)

    # Removing adapters clears their slots; repeated removal is False.
    assert manager.remove_adapter(lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(lora2.id)
    assert manager.remove_adapter(lora1.id)
    assert not manager.remove_adapter(lora1.id)

    # Re-adding alone does not occupy a slot.
    assert manager.add_adapter(lora1)
    assert_slots(None, None)

    assert manager.add_adapter(lora2)
    assert manager.activate_adapter(3)
    assert_slots(3, None)
    assert manager.activate_adapter(2)
    assert_slots(3, 2)

    assert manager.device == device
    assert manager.punica_wrapper.device == device
193

194

195
@pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
    """Walk LRUCacheLoRAModelManager through its adapter lifecycle.

    Covers add / activate / deactivate / remove / pin on a manager with
    ``max_loras=2``, checking return values, the contents of
    ``lora_index_to_id`` after each step, and the errors expected once
    both slots hold pinned adapters or a removed adapter is pinned.
    """
    dummy_model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    dummy_model.packed_modules_mapping = {}

    lora1 = create_lora(
        1, dummy_model, ["layer1.dense1", "dense2", "lm_head"], device=device)
    lora2 = create_lora(
        2, dummy_model, ["dense1", "dense2", "lm_head"], device=device)
    lora3 = create_lora(
        3, dummy_model, ["dense1", "dense2", "lm_head"], device=device)

    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2)
    manager = LRUCacheLoRAModelManager(dummy_model, 2, 2, 2, lora_config,
                                       device=device)

    def assert_slots(*expected):
        # Compare the leading slots in order, re-reading the table each
        # time; ``None`` entries are compared by identity.
        for index, adapter_id in enumerate(expected):
            occupant = manager.lora_index_to_id[index]
            if adapter_id is None:
                assert occupant is None
            else:
                assert occupant == adapter_id

    assert all(slot is None for slot in manager.lora_index_to_id)

    assert manager.add_adapter(lora1)
    assert manager.activate_adapter(1)
    assert_slots(1)
    assert not manager.add_adapter(lora1)
    assert not manager.activate_adapter(1)

    assert manager.add_adapter(lora2)
    assert manager.activate_adapter(2)
    assert_slots(1, 2)
    assert not manager.add_adapter(lora2)
    assert not manager.activate_adapter(2)

    # Unlike the plain manager, activating a third adapter succeeds and
    # takes over slot 0.
    assert manager.add_adapter(lora3)
    assert_slots(1, 2)
    assert manager.activate_adapter(3)
    assert_slots(3, 2)

    assert manager.remove_adapter(lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(lora2.id)
    assert manager.remove_adapter(lora1.id)
    assert not manager.remove_adapter(lora1.id)

    assert manager.add_adapter(lora1)
    assert manager.activate_adapter(1)
    assert_slots(3, 1)

    assert manager.add_adapter(lora2)
    assert manager.deactivate_adapter(3)
    assert_slots(None, 1)
    assert manager.activate_adapter(2)
    assert_slots(2, 1)
    assert manager.activate_adapter(3)
    assert_slots(2, 3)

    # Pinning adapter 2 keeps it in place across later activations.
    assert manager.pin_adapter(2)
    assert_slots(2, 3)
    assert manager.activate_adapter(1)
    assert_slots(2, 1)
    assert manager.deactivate_adapter(2)
    assert_slots(None, 1)
    assert manager.activate_adapter(3)
    assert_slots(3, 1)

    # With both slots pinned, pinning or activating another adapter
    # is expected to raise.
    assert manager.pin_adapter(3)
    assert manager.pin_adapter(1)
    with pytest.raises(RuntimeError):
        assert manager.pin_adapter(2)
    assert_slots(3, 1)
    with pytest.raises(RuntimeError):
        assert manager.activate_adapter(2)

    assert manager.deactivate_adapter(3)
    assert manager.pin_adapter(2)
    assert_slots(2, 1)

    # A removed adapter can no longer be pinned.
    assert manager.remove_adapter(3)
    with pytest.raises(ValueError):
        assert manager.pin_adapter(3)

    assert manager.punica_wrapper.device == device
    assert manager.device == device

286

287
@pytest.mark.parametrize("device", DEVICES)
def test_lru_lora_model_manager(dist_init, dummy_model, device):
    """Check the LRU cache behaviour of LRUCacheLoRAModelManager.

    This tests just the LRU cache functionality; everything else is
    tested in test_lora_model_manager. The manager is configured with
    ``max_cpu_loras=2`` and ``max_loras=2``.
    """
    dummy_model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    dummy_model.packed_modules_mapping = {}

    lora1 = create_lora(
        1, dummy_model, ["layer1.dense1", "dense2", "lm_head"], device=device)
    lora2 = create_lora(
        2, dummy_model, ["dense1", "dense2", "lm_head"], device=device)
    lora3 = create_lora(
        3, dummy_model, ["dense1", "dense2", "lm_head"], device=device)
    lora4 = create_lora(
        4, dummy_model, ["dense1", "dense2", "lm_head"], device=device)

    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2)
    manager = LRUCacheLoRAModelManager(dummy_model, 2, 2, 2, lora_config,
                                       device=device)

    def assert_slots(*expected):
        # Compare the leading slots in order, re-reading the table each
        # time; ``None`` entries are compared by identity.
        for index, adapter_id in enumerate(expected):
            occupant = manager.lora_index_to_id[index]
            if adapter_id is None:
                assert occupant is None
            else:
                assert occupant == adapter_id

    assert all(slot is None for slot in manager.lora_index_to_id)

    # Add up to capacity
    assert manager.add_adapter(lora1)
    assert manager.add_adapter(lora2)
    assert manager.activate_adapter(1)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert_slots(1, 2)

    # Add over capacity
    assert manager.add_adapter(lora3)
    assert manager.add_adapter(lora4)
    assert manager.activate_adapter(3)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert_slots(3, 4)

    # Add 3 again to move it to the top and then add 2;
    # re-adding 3 should return False since it is already present.
    assert not manager.add_adapter(lora3)
    assert not manager.activate_adapter(3)
    assert manager.add_adapter(lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {3, 2}
    assert_slots(3, 2)

    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {2}
    assert_slots(None, 2)

    assert manager.add_adapter(lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(lora4)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert_slots(3, 4)

    # Evict one adapter at a time until the cache is empty.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {4}
    assert_slots(None, 4)

    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(slot is None for slot in manager.lora_index_to_id)

    assert not manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(slot is None for slot in manager.lora_index_to_id)

    # pinning
    assert manager.add_adapter(lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(lora4)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    with pytest.raises(ValueError):
        assert manager.pin_adapter(1)
    assert manager.pin_adapter(3)

    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {4}
    assert_slots(None, 4)

    assert manager.add_adapter(lora1)
    assert manager.pin_adapter(1)
    assert manager.add_adapter(lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert_slots(1, 2)

    # Only the unpinned adapter is evicted; a further eviction raises.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {1}
    assert_slots(1, None)

    with pytest.raises(RuntimeError):
        assert manager.remove_oldest_adapter()

    assert set(manager.list_adapters()) == {1}
    assert manager.punica_wrapper.device == device
    assert manager.device == device
415

416

417
@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                          sql_lora_files, device):
    """Check adapter retention in LRUCacheWorkerLoRAManager.

    Successive ``set_active_adapters`` calls are issued against a manager
    configured for four adapters. The assertions expect adapters that are
    not part of a request to stay listed until room is needed, and a
    request for five adapters to raise ``RuntimeError``.
    """
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    vocab_size = (llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
                  lora_config.lora_extra_vocab_size)
    worker_adapter_manager = LRUCacheWorkerLoRAManager(
        4, 2, vocab_size, lora_config, device, EMBEDDING_MODULES,
        EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])

    def activate(*adapter_ids):
        # One request per id, named after the id, all backed by the same
        # adapter files.
        requests = [
            LoRARequest(str(adapter_id), adapter_id, sql_lora_files)
            for adapter_id in adapter_ids
        ]
        worker_adapter_manager.set_active_adapters(requests, mapping)

    def assert_slots(*expected):
        # Compare the leading slots in order, re-reading the table each
        # time.
        for index, adapter_id in enumerate(expected):
            adapter_manager = worker_adapter_manager._adapter_manager
            assert adapter_manager.lora_index_to_id[index] == adapter_id

    activate(1, 2)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert_slots(1, 2)

    activate(1, 3, 4)
    assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
    assert_slots(1, 2, 3, 4)

    activate(1, 2, 5)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert_slots(1, 2, 5, 4)

    # Requesting the same adapter repeatedly changes nothing.
    activate(1, 1, 1)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert_slots(1, 2, 5, 4)

    activate(6, 7, 8)
    assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
    assert_slots(1, 7, 8, 6)

    # Over capacity
    with pytest.raises(RuntimeError):
        activate(10, 11, 12, 13, 14)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)
494

495

496
@pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                sql_lora_files, device):
    """Check that WorkerLoRAManager keeps only the requested adapters.

    Each ``set_active_adapters`` call should remove every LoRA not
    specified in the request; a request for five adapters on a manager
    configured for four is expected to raise ``RuntimeError``.
    """
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    vocab_size = (llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
                  lora_config.lora_extra_vocab_size)
    worker_adapter_manager = WorkerLoRAManager(
        4, 2, vocab_size, lora_config, device, EMBEDDING_MODULES,
        EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])

    def activate(*adapter_ids):
        # One request per id, named after the id, all backed by the same
        # adapter files.
        requests = [
            LoRARequest(str(adapter_id), adapter_id, sql_lora_files)
            for adapter_id in adapter_ids
        ]
        worker_adapter_manager.set_active_adapters(requests, mapping)

    def assert_slots(*expected):
        # Compare the leading slots in order, re-reading the table each
        # time; ``None`` entries are compared by identity.
        for index, adapter_id in enumerate(expected):
            adapter_manager = worker_adapter_manager._adapter_manager
            occupant = adapter_manager.lora_index_to_id[index]
            if adapter_id is None:
                assert occupant is None
            else:
                assert occupant == adapter_id

    activate(1, 2)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert_slots(1, 2)

    activate(1, 3, 4)
    assert worker_adapter_manager.list_adapters() == {1, 3, 4}
    assert_slots(1, 3, 4)

    activate(1, 2, 5)
    assert worker_adapter_manager.list_adapters() == {1, 2, 5}
    assert_slots(1, 2, 5)

    # Requesting only adapter 1 leaves the other slots empty.
    activate(1, 1, 1)
    assert worker_adapter_manager.list_adapters() == {1}
    assert_slots(1, None, None)

    activate(6, 7, 8)
    assert worker_adapter_manager.list_adapters() == {6, 7, 8}
    assert_slots(8, 6, 7)

    # Over capacity
    with pytest.raises(RuntimeError):
        activate(10, 11, 12, 13, 14)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)

571

572
@pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
    """Check that per-projection LoRAs are packed under ``gate_up_proj``.

    Two adapters are added: one with weights for both ``gate_proj`` and
    ``up_proj`` and one with ``gate_proj`` left out. After adding them,
    each adapter's ``gate_up_proj`` entry must be a
    ``PackedLoRALayerWeights`` whose per-index tensors match the
    sub-module weights, with ``None`` at the index of the missing one.
    """
    dummy_model_gate_up.supported_lora_modules = ["gate_up_proj"]
    dummy_model_gate_up.packed_modules_mapping = {
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    full_lora = create_packed_lora(
        1,
        dummy_model_gate_up,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device,
    )
    partial_lora = create_packed_lora(
        2,
        dummy_model_gate_up,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device,
        empty_replaced_module_name="gate_proj",
    )

    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2)
    manager = LoRAModelManager(dummy_model_gate_up, 2, 2, 2, lora_config,
                               device=device)
    wrapped = manager.model

    assert isinstance(wrapped.get_submodule("gate_up_proj"),
                      MergedColumnParallelLinearWithLoRA)
    assert manager.add_adapter(full_lora)
    assert manager.add_adapter(partial_lora)

    # Adapter with both projections: every packed index mirrors the
    # corresponding sub-module tensors.
    packed_full = full_lora.get_lora("gate_up_proj")
    assert packed_full and isinstance(packed_full, PackedLoRALayerWeights)
    for position, sub_name in enumerate(("gate_proj", "up_proj")):
        torch.testing.assert_close(packed_full.lora_a[position],
                                   full_lora.get_lora(sub_name).lora_a)
        torch.testing.assert_close(packed_full.lora_b[position],
                                   full_lora.get_lora(sub_name).lora_b)

    # Adapter without ``gate_proj``: index 0 is empty, index 1 is filled.
    packed_partial = partial_lora.get_lora("gate_up_proj")
    assert packed_partial and isinstance(packed_partial,
                                         PackedLoRALayerWeights)
    assert packed_partial.lora_a[0] is None
    assert packed_partial.lora_b[0] is None
    torch.testing.assert_close(packed_partial.lora_a[1],
                               partial_lora.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_partial.lora_b[1],
                               partial_lora.get_lora("up_proj").lora_b)