test_lora_manager.py 26.8 KB
Newer Older
1
import json
2
import os
3
from typing import Dict, List
4
5
6
7
8
9
10
11

import pytest
import torch
from safetensors.torch import load_file
from torch import nn

from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
12
13
                              MergedColumnParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA)
14
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
15
16
from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager)
17
from vllm.lora.peft_helper import PEFTHelper
18
19
20
21
22
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                      WorkerLoRAManager)
from vllm.model_executor.layers.linear import RowParallelLinear

Terry's avatar
Terry committed
23
24
25
26
27
28
29
# Module-name fragment -> key used to look the matching tensor up in the
# "new_embeddings" safetensors dict (see test_from_lora_tensors).
EMBEDDING_MODULES = dict(
    embed_tokens="input_embeddings",
    lm_head="output_embeddings",
)

# Forwarded as `embedding_padding_modules` to the LoRA loaders/managers.
# NOTE(review): presumably the modules whose embeddings get padded -- confirm.
EMBEDDING_PADDING_MODULES = ["lm_head"]

30
31
32
# Device strings the parametrized tests run on. Exactly one visible GPU
# yields only "cuda:0"; any other count (including zero) yields both
# "cuda:0" and "cuda:1".
CUDA_DEVICES = (["cuda:0"] if torch.cuda.device_count() == 1 else
                ["cuda:0", "cuda:1"])
33

34

35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
def test_peft_helper(sql_lora_files):
    """Check PEFTHelper parsing of a real adapter config and its rejections.

    First loads ``adapter_config.json`` from the ``sql_lora_files`` fixture
    directory and checks the parsed rank, alpha and target modules. Then
    verifies that configs using ``modules_to_save``, ``use_rslora`` or
    ``use_dora`` are rejected with a ``ValueError`` carrying the expected
    message.
    """
    lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
    with open(lora_config_path, encoding="utf-8") as f:
        config = json.load(f)
    peft_helper = PEFTHelper.from_dict(config)
    assert peft_helper.r == 8
    assert peft_helper.lora_alpha == 16
    assert peft_helper.target_modules == [
        "q_proj",
        "v_proj",
        "k_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]

    # (expected message, extra config entries that should trigger it)
    unsupported_cases = [
        ("vLLM only supports modules_to_save being None.",
         dict(modules_to_save=["lm_head"])),
        ("vLLM does not yet support RSLoRA.", dict(use_rslora=True)),
        ("vLLM does not yet support DoRA.", dict(use_dora=True)),
    ]
    for expected_error, extra_fields in unsupported_cases:
        # Build the config outside the `raises` block so only the call under
        # test can satisfy the expectation.
        bad_config = dict(r=8,
                          lora_alpha=16,
                          target_modules=["gate_proj"],
                          **extra_fields)
        with pytest.raises(ValueError) as exc_info:
            PEFTHelper.from_dict(bad_config)
        # Literal containment: `match=` would treat the trailing "." as a
        # regex wildcard.
        assert expected_error in str(exc_info.value)


80
81
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
    """Build a LoRAModel from checkpoint tensors and validate every layer.

    Loads the adapter weights, the extra embeddings and the adapter config
    from the ``sql_lora_files`` fixture directory, constructs the model on
    ``device`` and checks each per-module LoRA for name, rank, alpha,
    device placement, matching A/B shapes and the attached embeddings
    tensor.
    """
    tensors = load_file(
        os.path.join(sql_lora_files, "adapter_model.safetensors"))
    new_embeddings = load_file(
        os.path.join(sql_lora_files, "new_embeddings.safetensors"))

    lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
    with open(lora_config_path, encoding="utf-8") as f:
        config = json.load(f)

    peft_helper = PEFTHelper.from_dict(config)
    lora_model = LoRAModel.from_lora_tensors(
        1,
        tensors,
        peft_helper=peft_helper,
        device=device,
        embeddings=new_embeddings,
        embedding_modules=EMBEDDING_MODULES,
        embedding_padding_modules=EMBEDDING_PADDING_MODULES)
    for module_name, lora in lora_model.loras.items():
        assert lora.module_name == module_name
        assert lora.rank == 8
        assert lora.lora_alpha == 16
        assert lora.lora_a is not None
        assert lora.lora_b is not None
        assert lora.lora_a.device == torch.device(device)
        assert lora.lora_b.device == torch.device(device)
        # A's second dim and B's first dim must agree and equal the rank.
        assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
                ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
        assert lora.lora_a.shape[1] == 8
        # Only modules whose name contains an EMBEDDING_MODULES key are
        # expected to carry an embeddings tensor.
        embeddings_module = next(
            (k for k in EMBEDDING_MODULES if k in module_name), None)
        if embeddings_module:
            assert torch.equal(
                lora.embeddings_tensor,
                new_embeddings[EMBEDDING_MODULES[embeddings_module]].to(
                    device=lora.embeddings_tensor.device))
        else:
            assert lora.embeddings_tensor is None


122
123
def create_lora(lora_id: int, model: nn.Module, sub_modules: List[str],
                device: torch.device) -> LoRAModel:
    """Create a LoRAModel with random rank-8 weights for the given modules.

    Args:
        lora_id: Identifier passed through to the returned ``LoRAModel``.
        model: Model whose submodules' ``weight`` shapes size the tensors.
        sub_modules: Names resolvable via ``model.get_submodule``.
        device: Device the random tensors are allocated on.

    Returns:
        A ``LoRAModel`` holding one ``LoRALayerWeights`` per submodule name.
    """
    rank = 8
    lora_alpha = 16
    loras: Dict[str, LoRALayerWeights] = {}
    for name in sub_modules:
        w = model.get_submodule(name).weight
        loras[name] = LoRALayerWeights(
            name,
            rank,
            lora_alpha,
            # A is [w.shape[1], rank], B is [rank, w.shape[0]].
            torch.rand([w.shape[1], rank], device=device),
            torch.rand([rank, w.shape[0]], device=device),
        )
    return LoRAModel(lora_id, rank, loras)


def create_packed_lora(
    lora_id: int,
    model: nn.Module,
    module_name,
    replaced_module_names,
    device: torch.device,
    empty_replaced_module_name=None,
) -> LoRAModel:
    """Create a LoRAModel with random rank-8 weights for a packed module.

    Args:
        lora_id: Identifier passed through to the returned ``LoRAModel``.
        model: Model containing the packed submodule.
        module_name: Name of the packed submodule; its ``weight`` shape
            sizes the generated tensors.
        replaced_module_names: Names of the sub-projections; each gets its
            own ``LoRALayerWeights`` entry with an equal share of
            ``weight.shape[0]`` as the B-matrix width.
        device: Device the random tensors are allocated on.
        empty_replaced_module_name: Optional entry of
            ``replaced_module_names`` to leave out of the result.

    Returns:
        A ``LoRAModel`` keyed by the replaced module names, minus the
        skipped one if any.
    """
    rank = 8
    lora_alpha = 16
    w = model.get_submodule(module_name).weight
    # The share is computed over *all* names, including a skipped one.
    out_per_module = w.shape[0] // len(replaced_module_names)
    loras: Dict[str, LoRALayerWeights] = {}
    for replaced_module_name in replaced_module_names:
        if replaced_module_name == empty_replaced_module_name:
            continue
        loras[replaced_module_name] = LoRALayerWeights(
            replaced_module_name,
            rank,
            lora_alpha,
            torch.rand([w.shape[1], rank], device=device),
            torch.rand([rank, out_per_module], device=device),
        )
    return LoRAModel(lora_id, rank, loras)


def test_replace_submodules(dist_init, dummy_model):
    """Check which submodules LoRAModelManager swaps for LoRA wrappers.

    With ``supported_lora_modules = ["dense1", "layer1.dense2"]`` the test
    expects ``dense1``, ``layer1.dense1`` and ``layer1.dense2`` to be
    wrapped, while the top-level ``dense2`` stays a plain
    ``RowParallelLinear``.
    """
    model = dummy_model
    model.supported_lora_modules = ["dense1", "layer1.dense2"]
    model.packed_modules_mapping = {}
    manager = LoRAModelManager(
        model, 1, 1, 1,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
        torch.device("cuda"))
    # Inspect the model as held by the manager after construction.
    model = manager.model

    assert isinstance(model.get_submodule("dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("dense2"), RowParallelLinear)
    assert isinstance(model.get_submodule("layer1.dense2"),
                      RowParallelLinearWithLoRA)


180
181
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
    """Exercise add/activate/remove on a plain LoRAModelManager.

    The manager is configured with ``max_loras=2`` and ``max_cpu_loras=3``.
    The test tracks ``lora_index_to_id`` after each operation and expects
    activating a third adapter while both slots are taken to raise
    ``ValueError``.
    """
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=3,
                                          max_loras=2),
                               device=device)
    assert all(x is None for x in manager.lora_index_to_id)

    # First adapter takes slot 0; repeating add/activate reports False.
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)

    # Second adapter takes slot 1.
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)

    # A third adapter can be added, but activation must fail while both
    # slots are occupied, leaving the slots unchanged.
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    with pytest.raises(ValueError):
        manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Removing adapters frees their slots; removing twice reports False.
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)

    # Re-adding alone does not occupy a slot; activation does.
    assert manager.add_adapter(model_lora1)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] is None
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] is None
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    assert manager.device == device
    assert manager.punica_wrapper.device == device
239

240
241
242

@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
    """Exercise add/activate/deactivate/pin on LRUCacheLoRAModelManager.

    The manager is configured with ``max_loras=2`` and ``max_cpu_loras=3``.
    Unlike the plain manager test, activating a third adapter is expected
    to succeed by replacing an occupied slot, unless every slot holds a
    pinned adapter, in which case ``RuntimeError`` is expected.
    """
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=3,
                                                  max_loras=2),
                                       device=device)
    assert all(x is None for x in manager.lora_index_to_id)

    # Fill both slots; repeated add/activate reports False.
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)

    # Activating a third adapter replaces the one in slot 0.
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    # Removal frees the slot; removing twice reports False.
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1

    # Deactivation frees a slot without removing the adapter.
    assert manager.add_adapter(model_lora2)
    assert manager.deactivate_adapter(3)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3

    # A pinned adapter keeps its slot when another one is activated.
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.deactivate_adapter(2)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1

    # With both slots pinned, pinning or activating another must fail.
    assert manager.pin_adapter(3)
    assert manager.pin_adapter(1)
    with pytest.raises(RuntimeError):
        manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    with pytest.raises(RuntimeError):
        manager.activate_adapter(2)

    assert manager.deactivate_adapter(3)
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1

    # Pinning an adapter that has been removed must fail.
    assert manager.remove_adapter(3)
    with pytest.raises(ValueError):
        manager.pin_adapter(3)

    assert manager.punica_wrapper.device == device
    assert manager.device == device

332

333
334
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lru_lora_model_manager(dist_init, dummy_model, device):
    """Exercise the LRU cache behaviour of LRUCacheLoRAModelManager.

    The manager is configured with ``max_loras=2`` and ``max_cpu_loras=2``,
    so adding a third adapter is expected to push an earlier one out of
    ``list_adapters()``.
    """
    # This tests just the LRU cache functionality, everything else is
    # tested in test_lora_model_manager
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora4 = create_lora(4,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=2,
                                                  max_loras=2),
                                       device=device)

    assert all(x is None for x in manager.lora_index_to_id)

    # Add up to capacity
    assert manager.add_adapter(model_lora1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(1)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Add over capacity
    assert manager.add_adapter(model_lora3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(3)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Re-adding/activating 3 reports False because it is already present;
    # adding 2 afterwards is then expected to push out 4 rather than 3.
    assert not manager.add_adapter(model_lora3)
    assert not manager.activate_adapter(3)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {3, 2}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {2}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 2

    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Drain the cache one entry at a time; a third call reports False.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    assert not manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    # pinning
    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    # Adapter 1 is no longer registered, so pinning it must fail.
    with pytest.raises(ValueError):
        manager.pin_adapter(1)
    assert manager.pin_adapter(3)
    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    # Pinning 1 (without an explicit activate) is expected to place it in
    # slot 0; adding 2 then pushes out 4.
    assert manager.add_adapter(model_lora1)
    assert manager.pin_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # The pinned adapter survives remove_oldest_adapter().
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {1}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] is None

    # Only the pinned adapter is left, so there is nothing to evict.
    with pytest.raises(RuntimeError):
        manager.remove_oldest_adapter()

    assert set(manager.list_adapters()) == {1}
    assert manager.punica_wrapper.device == device
    assert manager.device == device
461

462

463
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                          sql_lora_files, device):
    """Check adapter retention in LRUCacheWorkerLoRAManager.

    The manager is configured with ``max_loras=4`` and ``max_cpu_loras=4``.
    Adapters not named in a later request are expected to stay listed until
    space is needed; requesting five distinct adapters at once is expected
    to raise ``RuntimeError``.
    """
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = LRUCacheWorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, device,
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])

    def set_active(*lora_ids):
        # All requests load the same checkpoint; only the id/name differ.
        worker_adapter_manager.set_active_adapters(
            [LoRARequest(str(i), i, sql_lora_files) for i in lora_ids],
            mapping)

    def assert_slots(*expected_ids):
        # Re-read the table each time rather than caching a reference.
        index_to_id = worker_adapter_manager._adapter_manager.lora_index_to_id
        for slot, expected_id in enumerate(expected_ids):
            assert index_to_id[slot] == expected_id

    set_active(1, 2)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert_slots(1, 2)

    # Adapter 2 is not requested but stays cached.
    set_active(1, 3, 4)
    assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
    assert_slots(1, 2, 3, 4)

    # Adapter 5 takes over the slot previously held by 3.
    set_active(1, 2, 5)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert_slots(1, 2, 5, 4)

    # Duplicate requests for one adapter leave the cache unchanged.
    set_active(1, 1, 1)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert_slots(1, 2, 5, 4)

    set_active(6, 7, 8)
    assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
    assert_slots(1, 7, 8, 6)

    # Over capacity
    with pytest.raises(RuntimeError):
        set_active(10, 11, 12, 13, 14)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)
540

541
542

@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                sql_lora_files, device):
    """Check that WorkerLoRAManager keeps only the requested adapters.

    The manager is configured with ``max_loras=4`` and ``max_cpu_loras=4``.
    After every ``set_active_adapters`` call, ``list_adapters()`` is
    expected to equal exactly the set of requested ids; requesting five
    distinct adapters at once is expected to raise ``RuntimeError``.
    """
    # Should remove every LoRA not specified in the request.
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = WorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, device,
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])

    def set_active(*lora_ids):
        # All requests load the same checkpoint; only the id/name differ.
        worker_adapter_manager.set_active_adapters(
            [LoRARequest(str(i), i, sql_lora_files) for i in lora_ids],
            mapping)

    def assert_slots(*expected_ids):
        # Re-read the table each time rather than caching a reference.
        index_to_id = worker_adapter_manager._adapter_manager.lora_index_to_id
        for slot, expected_id in enumerate(expected_ids):
            if expected_id is None:
                assert index_to_id[slot] is None
            else:
                assert index_to_id[slot] == expected_id

    set_active(1, 2)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert_slots(1, 2)

    set_active(1, 3, 4)
    assert worker_adapter_manager.list_adapters() == {1, 3, 4}
    assert_slots(1, 3, 4)

    set_active(1, 2, 5)
    assert worker_adapter_manager.list_adapters() == {1, 2, 5}
    assert_slots(1, 2, 5)

    # Duplicate requests collapse to a single adapter; other slots empty.
    set_active(1, 1, 1)
    assert worker_adapter_manager.list_adapters() == {1}
    assert_slots(1, None, None)

    set_active(6, 7, 8)
    assert worker_adapter_manager.list_adapters() == {6, 7, 8}
    assert_slots(8, 6, 7)

    # Over capacity
    with pytest.raises(RuntimeError):
        set_active(10, 11, 12, 13, 14)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)

617

618
619
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
    """Check that per-projection LoRAs are packed for a merged module.

    Two adapters are created for ``gate_up_proj`` (mapped to ``gate_proj``
    and ``up_proj``): one with weights for both projections and one with
    ``gate_proj`` left out. After ``add_adapter``, each is expected to
    expose a ``PackedLoRALayerWeights`` under ``gate_up_proj`` whose
    entries match the individual LoRAs, with ``None`` in the position of
    the missing projection.
    """
    model = dummy_model_gate_up
    model.supported_lora_modules = ["gate_up_proj"]
    model.packed_modules_mapping = {
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    model_lora = create_packed_lora(
        1,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device)
    model_lora1 = create_packed_lora(
        2,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device,
        empty_replaced_module_name="gate_proj",
    )

    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=2,
                                          max_loras=2),
                               device=device)
    model = manager.model

    assert isinstance(model.get_submodule("gate_up_proj"),
                      MergedColumnParallelLinearWithLoRA)
    assert manager.add_adapter(model_lora)
    assert manager.add_adapter(model_lora1)

    # Fully populated adapter: index 0 is gate_proj, index 1 is up_proj.
    packed_lora = model_lora.get_lora("gate_up_proj")
    assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)

    torch.testing.assert_close(packed_lora.lora_a[0],
                               model_lora.get_lora("gate_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[0],
                               model_lora.get_lora("gate_proj").lora_b)
    torch.testing.assert_close(packed_lora.lora_a[1],
                               model_lora.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[1],
                               model_lora.get_lora("up_proj").lora_b)

    # Adapter without gate_proj: its packed position must be None.
    packed_lora1 = model_lora1.get_lora("gate_up_proj")
    assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)

    assert packed_lora1.lora_a[0] is None
    assert packed_lora1.lora_b[0] is None
    torch.testing.assert_close(packed_lora1.lora_a[1],
                               model_lora1.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora1.lora_b[1],
                               model_lora1.get_lora("up_proj").lora_b)