"""Tests for vLLM LoRA support: PEFT config parsing, ``LoRAModel`` loading,
and the plain / LRU-cache LoRA model and worker managers."""
import json
import math
import os
from typing import Dict, List

import pytest
import torch
from safetensors.torch import load_file
from torch import nn

from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager)
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                      WorkerLoRAManager)
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.platforms import current_platform

# Embedding module name -> key of its tensor in the adapter's
# ``new_embeddings.safetensors`` file.
EMBEDDING_MODULES = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}

# Passed as ``embedding_padding_modules`` to the LoRA model/worker managers.
EMBEDDING_PADDING_MODULES = ["lm_head"]

# Devices that every ``device``-parametrized test runs on: on CUDA-like
# platforms "cuda:0" (plus "cuda:1" unless exactly one device is visible),
# otherwise just the CPU.
DEVICES = ([
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] if current_platform.is_cuda_alike() else ["cpu"])


def test_peft_helper(sql_lora_files):
    """Check PEFTHelper config parsing, scaling and unsupported features.

    Covers the ``lora_alpha / r`` scaling of the bundled SQL adapter, the
    ``lora_alpha / sqrt(r)`` scaling used with RSLoRA, and the ``ValueError``
    raised for configs that set ``modules_to_save`` or enable DoRA.
    """
    lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
    with open(lora_config_path, encoding="utf-8") as f:
        config = json.load(f)
    peft_helper = PEFTHelper.from_dict(config)
    assert peft_helper.r == 8
    assert peft_helper.lora_alpha == 16
    assert peft_helper.target_modules == [
        "q_proj",
        "v_proj",
        "k_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    scaling = peft_helper.lora_alpha / peft_helper.r
    assert peft_helper.vllm_lora_scaling_factor == pytest.approx(
        scaling, abs=1e-3)

    # test RSLoRA: the scaling factor uses sqrt(r) instead of r.
    config = dict(r=8,
                  lora_alpha=16,
                  target_modules=["gate_proj"],
                  use_rslora=True)
    peft_helper = PEFTHelper.from_dict(config)

    scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
    assert peft_helper.vllm_lora_scaling_factor == pytest.approx(
        scaling, abs=1e-3)

    # Each invalid config is built outside ``pytest.raises`` so that only the
    # ``PEFTHelper.from_dict`` call is expected to raise.
    config = dict(
        r=8,
        lora_alpha=16,
        target_modules=["gate_proj"],
        modules_to_save=["lm_head"],
    )
    expected_error = "vLLM only supports modules_to_save being None."
    with pytest.raises(ValueError, match=expected_error):
        PEFTHelper.from_dict(config)

    config = dict(r=8,
                  lora_alpha=16,
                  target_modules=["gate_proj"],
                  use_dora=True)
    expected_error = "vLLM does not yet support DoRA."
    with pytest.raises(ValueError, match=expected_error):
        PEFTHelper.from_dict(config)


@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
    """Load the SQL adapter's tensors into a ``LoRAModel`` on ``device``.

    Every loaded layer must have rank 8 and alpha 16 and A/B matrices on the
    requested device with matching inner dimensions. It must carry an
    embeddings tensor exactly when its name contains one of
    ``EMBEDDING_MODULES``, and that tensor must equal the one in
    ``new_embeddings.safetensors``.
    """
    tensors = load_file(
        os.path.join(sql_lora_files, "adapter_model.safetensors"))
    new_embeddings = load_file(
        os.path.join(sql_lora_files, "new_embeddings.safetensors"))

    lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
    with open(lora_config_path, encoding="utf-8") as f:
        config = json.load(f)

    peft_helper = PEFTHelper.from_dict(config)
    lora_model = LoRAModel.from_lora_tensors(
        1,
        tensors,
        peft_helper=peft_helper,
        device=device,
        embeddings=new_embeddings,
        embedding_modules=EMBEDDING_MODULES,
        embedding_padding_modules=EMBEDDING_PADDING_MODULES)
    for module_name, lora in lora_model.loras.items():
        assert lora.module_name == module_name
        assert lora.rank == 8
        assert lora.lora_alpha == 16
        assert lora.lora_a is not None
        assert lora.lora_b is not None
        assert lora.lora_a.device == torch.device(device)
        assert lora.lora_b.device == torch.device(device)
        assert lora.lora_a.shape[1] == lora.lora_b.shape[0], (
            f"{lora.lora_a.shape=}, {lora.lora_b.shape=}")
        assert lora.lora_a.shape[1] == 8
        # First EMBEDDING_MODULES key contained in this module's name, if any.
        embeddings_module = next(
            (k for k in EMBEDDING_MODULES if k in module_name), None)
        if embeddings_module:
            assert torch.equal(
                lora.embeddings_tensor,
                new_embeddings[EMBEDDING_MODULES[embeddings_module]].to(
                    device=lora.embeddings_tensor.device))
        else:
            assert lora.embeddings_tensor is None


def create_lora(lora_id: int,
                model: nn.Module,
                sub_modules: List[str],
                device: torch.device,
                *,
                rank: int = 8,
                lora_alpha: int = 16) -> LoRAModel:
    """Build a ``LoRAModel`` with random A/B matrices for ``sub_modules``.

    Args:
        lora_id: ID given to the returned ``LoRAModel``.
        model: Model whose sub-module ``weight`` shapes size the matrices.
        sub_modules: Dotted names of the sub-modules to create LoRAs for.
        device: Device the random matrices are allocated on.
        rank: LoRA rank of every layer and of the model (default 8).
        lora_alpha: LoRA alpha of every layer (default 16).

    Returns:
        A ``LoRAModel`` whose layer for sub-module ``name`` (weight ``w``) has
        ``lora_a`` of shape ``(w.shape[1], rank)`` and ``lora_b`` of shape
        ``(rank, w.shape[0])``.
    """
    loras: Dict[str, LoRALayerWeights] = {}
    for name in sub_modules:
        w = model.get_submodule(name).weight
        loras[name] = LoRALayerWeights(
            name,
            rank,
            lora_alpha,
            torch.rand([w.shape[1], rank], device=device),
            torch.rand([rank, w.shape[0]], device=device),
        )
    return LoRAModel(lora_id, rank, loras)


def create_packed_lora(
    lora_id: int,
    model: nn.Module,
    module_name: str,
    replaced_module_names: List[str],
    device: torch.device,
    empty_replaced_module_name=None,
    *,
    rank: int = 8,
    lora_alpha: int = 16,
) -> LoRAModel:
    """Build a ``LoRAModel`` with random LoRAs for a packed module's parts.

    One LoRA is created per entry of ``replaced_module_names``, each sized from
    the weight of the packed module ``module_name``.

    Args:
        lora_id: ID given to the returned ``LoRAModel``.
        model: Model containing the packed module ``module_name``.
        module_name: Dotted name of the packed module (e.g. ``gate_up_proj``).
        replaced_module_names: Names of the modules packed into
            ``module_name``; each gets its own LoRA.
        device: Device the random matrices are allocated on.
        empty_replaced_module_name: Optional entry of
            ``replaced_module_names`` for which no LoRA is created.
        rank: LoRA rank of every layer and of the model (default 8).
        lora_alpha: LoRA alpha of every layer (default 16).

    Returns:
        A ``LoRAModel`` where each part gets ``lora_a`` of shape
        ``(w.shape[1], rank)`` and ``lora_b`` of shape
        ``(rank, w.shape[0] // len(replaced_module_names))``, with ``w`` the
        packed module's weight.
    """
    w = model.get_submodule(module_name).weight
    # NOTE(review): the output dim is split evenly by floor division; this
    # assumes it is divisible by the number of packed parts.
    slice_size = w.shape[0] // len(replaced_module_names)
    loras: Dict[str, LoRALayerWeights] = {}
    for replaced_module_name in replaced_module_names:
        if replaced_module_name == empty_replaced_module_name:
            continue
        loras[replaced_module_name] = LoRALayerWeights(
            replaced_module_name,
            rank,
            lora_alpha,
            torch.rand([w.shape[1], rank], device=device),
            torch.rand([rank, slice_size], device=device),
        )
    return LoRAModel(lora_id, rank, loras)


def test_replace_submodules(dist_init, dummy_model):
    """Check which sub-modules the manager wraps with LoRA layers.

    With ``supported_lora_modules = ["dense1", "layer1.dense2"]``, both
    ``dense1`` and ``layer1.dense1`` must become
    ``ColumnParallelLinearWithLoRA``. ``layer1.dense2`` must become
    ``RowParallelLinearWithLoRA``, and the top-level ``dense2`` must stay a
    plain ``RowParallelLinear``.
    """
    model = dummy_model
    model.supported_lora_modules = ["dense1", "layer1.dense2"]
    model.packed_modules_mapping = {}
    manager = LoRAModelManager(
        model, 1, 1, 1,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
        torch.device(DEVICES[0]))
    # Inspect the model as held by the manager.
    model = manager.model

    assert isinstance(model.get_submodule("dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("dense2"), RowParallelLinear)
    assert isinstance(model.get_submodule("layer1.dense2"),
                      RowParallelLinearWithLoRA)


@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
    """Exercise add/activate/remove on a plain ``LoRAModelManager``.

    The manager has two adapter slots. Activating a third adapter while both
    are occupied must raise ``ValueError`` and leave the slots untouched.
    Removing adapters frees their slots.
    """
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=3,
                                          max_loras=2),
                               device=device)
    assert all(x is None for x in manager.lora_index_to_id)

    # Adding/activating an adapter that is already present returns False.
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)

    # Both slots are taken: activating a third adapter raises and leaves the
    # slots unchanged.
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    with pytest.raises(ValueError):
        manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Removing frees the slot; removing the same adapter twice returns False.
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)

    # Adding an adapter does not activate it.
    assert manager.add_adapter(model_lora1)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] is None
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] is None
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    assert manager.device == device
    assert manager.punica_wrapper.device == device


@pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
    """Exercise activation, eviction and pinning on an LRU-cache manager.

    Unlike the plain manager, activating an adapter while both slots are full
    evicts an active adapter instead of raising. Pinned adapters are never
    evicted, and running out of unpinned slots raises ``RuntimeError``.
    """
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=3,
                                                  max_loras=2),
                                       device=device)
    assert all(x is None for x in manager.lora_index_to_id)

    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Both slots full: activating 3 evicts the least recently used one (1).
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1

    # Deactivating frees a slot; the adapter can be re-activated later
    # without being added again.
    assert manager.add_adapter(model_lora2)
    assert manager.deactivate_adapter(3)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3

    # A pinned adapter (2) is not the one evicted.
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1

    # Deactivating a pinned adapter still frees its slot.
    assert manager.deactivate_adapter(2)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1

    # With every slot pinned, pinning or activating another adapter raises.
    assert manager.pin_adapter(3)
    assert manager.pin_adapter(1)
    with pytest.raises(RuntimeError):
        manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    with pytest.raises(RuntimeError):
        manager.activate_adapter(2)

    assert manager.deactivate_adapter(3)
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1

    # Pinning an adapter that has been removed raises ValueError.
    assert manager.remove_adapter(3)
    with pytest.raises(ValueError):
        manager.pin_adapter(3)

    assert manager.punica_wrapper.device == device
    assert manager.device == device


@pytest.mark.parametrize("device", DEVICES)
def test_lru_lora_model_manager(dist_init, dummy_model, device):
    """Exercise the LRU cache of ``LRUCacheLoRAModelManager``.

    Covers capacity-driven eviction, recency updates, manual and oldest-first
    removal, and pinning.
    """
    # This tests just the LRU cache functionality, everything else is
    # tested in test_lora_model_manager
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora4 = create_lora(4,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=2,
                                                  max_loras=2),
                                       device=device)

    assert all(x is None for x in manager.lora_index_to_id)

    # Add up to capacity
    assert manager.add_adapter(model_lora1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(1)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Add over capacity: 1 and 2 are evicted in favour of 3 and 4.
    assert manager.add_adapter(model_lora3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(3)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Re-adding / re-activating 3 returns False because it is already
    # present, but it makes 3 the most recently used adapter. Adding 2
    # therefore evicts 4 rather than 3.
    assert not manager.add_adapter(model_lora3)
    assert not manager.activate_adapter(3)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {3, 2}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {2}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 2

    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # remove_oldest_adapter returns False once nothing is left.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    assert not manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    # Pinning: an adapter that is not registered cannot be pinned.
    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    with pytest.raises(ValueError):
        manager.pin_adapter(1)
    assert manager.pin_adapter(3)
    # A pinned adapter can still be removed manually.
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.add_adapter(model_lora1)
    assert manager.pin_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # remove_oldest_adapter skips the pinned adapter 1 and removes 2...
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {1}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] is None

    # ...and raises once only pinned adapters remain.
    with pytest.raises(RuntimeError):
        manager.remove_oldest_adapter()

    assert set(manager.list_adapters()) == {1}

    assert manager.punica_wrapper.device == device
    assert manager.device == device


@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                          sql_lora_files, device):
    """Check that ``LRUCacheWorkerLoRAManager`` caches adapters across requests.

    Adapters not named in a request stay loaded until room is needed (up to
    four). A single request for more adapters than slots raises
    ``RuntimeError``.
    """
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = LRUCacheWorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, device,
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    # 2 is not requested but stays cached while there is room.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Cache is full: 3 is evicted and 5 takes its slot.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Duplicate requests for an already loaded adapter change nothing.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Three new adapters: only 1, the most recently requested, is kept.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 6

    # Over capacity: five adapters in one request exceed max_loras=4.
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)


@pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                sql_lora_files, device):
    """Check that ``WorkerLoRAManager`` keeps only the requested adapters.

    Each ``set_active_adapters`` call leaves exactly the requested adapters
    loaded. A single request for more adapters than slots raises
    ``RuntimeError``.
    """
    # Should remove every LoRA not specified in the request.
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = WorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, device,
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    # 2 is not requested, so it is removed.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4

    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5

    # Duplicate requests collapse to one adapter; the other slots are freed.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None

    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 7

    # Over capacity: five adapters in one request exceed max_loras=4.
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)


@pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
    """Check packing of ``gate_proj`` / ``up_proj`` LoRAs into ``gate_up_proj``.

    The model maps ``gate_up_proj`` to ``[gate_proj, up_proj]``. After
    ``add_adapter``, each adapter exposes a ``PackedLoRALayerWeights`` under
    ``gate_up_proj``:
    - its slots must match the per-part LoRAs;
    - a part the adapter has no LoRA for (``gate_proj`` of ``model_lora1``)
      must appear as ``None``.
    """
    model = dummy_model_gate_up
    model.supported_lora_modules = ["gate_up_proj"]
    model.packed_modules_mapping = {
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    model_lora = create_packed_lora(
        1,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device)
    # Same layout, but without a LoRA for gate_proj.
    model_lora1 = create_packed_lora(
        2,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device,
        empty_replaced_module_name="gate_proj",
    )

    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=2,
                                          max_loras=2),
                               device=device)
    model = manager.model

    assert isinstance(model.get_submodule("gate_up_proj"),
                      MergedColumnParallelLinearWithLoRA)
    assert manager.add_adapter(model_lora)
    assert manager.add_adapter(model_lora1)

    packed_lora = model_lora.get_lora("gate_up_proj")
    assert isinstance(packed_lora, PackedLoRALayerWeights)

    torch.testing.assert_close(packed_lora.lora_a[0],
                               model_lora.get_lora("gate_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[0],
                               model_lora.get_lora("gate_proj").lora_b)
    torch.testing.assert_close(packed_lora.lora_a[1],
                               model_lora.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[1],
                               model_lora.get_lora("up_proj").lora_b)

    packed_lora1 = model_lora1.get_lora("gate_up_proj")
    assert isinstance(packed_lora1, PackedLoRALayerWeights)

    # The missing gate_proj LoRA leaves an empty slot in the packed weights.
    assert packed_lora1.lora_a[0] is None
    assert packed_lora1.lora_b[0] is None
    torch.testing.assert_close(packed_lora1.lora_a[1],
                               model_lora1.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora1.lora_b[1],
                               model_lora1.get_lora("up_proj").lora_b)