import json
import math
import os
from typing import Dict, List

import pytest
import torch
from safetensors.torch import load_file
from torch import nn

from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager)
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                      WorkerLoRAManager)
from vllm.model_executor.layers.linear import RowParallelLinear

# Maps LoRA target-module names to the keys under which their extra
# embedding tensors are stored in ``new_embeddings.safetensors``.
EMBEDDING_MODULES = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}

# Passed to the LoRA loaders/managers as ``embedding_padding_modules``.
EMBEDDING_PADDING_MODULES = ["lm_head"]

# Devices that the device-aware tests are parametrized over: only "cuda:0"
# when exactly one GPU is visible, otherwise "cuda:0" and "cuda:1".
# NOTE(review): with zero visible GPUs this still yields two CUDA devices,
# so these tests assume a CUDA host.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


def test_peft_helper(sql_lora_files):
    """Check PEFTHelper parsing, scaling factors and rejected options.

    Parses the SQL LoRA ``adapter_config.json`` and checks rank, alpha and
    target modules. Then verifies the scaling factor for plain LoRA
    (``alpha / r``) and RSLoRA (``alpha / sqrt(r)``). Finally checks that
    configs using ``modules_to_save`` or DoRA raise ``ValueError``.
    """
    lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
    with open(lora_config_path) as f:
        config = json.load(f)
    peft_helper = PEFTHelper.from_dict(config)
    assert peft_helper.r == 8
    assert peft_helper.lora_alpha == 16
    assert peft_helper.target_modules == [
        "q_proj",
        "v_proj",
        "k_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    scaling = peft_helper.lora_alpha / peft_helper.r
    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3

    # test RSLoRA: the scaling factor divides by sqrt(r) instead of r.
    config = dict(r=8,
                  lora_alpha=16,
                  target_modules=["gate_proj"],
                  use_rslora=True)
    peft_helper = PEFTHelper.from_dict(config)

    scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3

    # Unsupported options. Each config is built outside ``pytest.raises`` so
    # that only ``PEFTHelper.from_dict`` is expected to raise.
    config = dict(
        r=8,
        lora_alpha=16,
        target_modules=["gate_proj"],
        modules_to_save=["lm_head"],
    )
    expected_error = "vLLM only supports modules_to_save being None."
    with pytest.raises(ValueError, match=expected_error):
        PEFTHelper.from_dict(config)

    config = dict(r=8,
                  lora_alpha=16,
                  target_modules=["gate_proj"],
                  use_dora=True)
    expected_error = "vLLM does not yet support DoRA."
    with pytest.raises(ValueError, match=expected_error):
        PEFTHelper.from_dict(config)


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
    """Load the SQL LoRA checkpoint into a ``LoRAModel`` on ``device``.

    Every loaded layer must satisfy all of the following:
    - it has rank 8 and alpha 16;
    - both weight tensors live on the requested device;
    - its inner dimensions agree.

    Layers whose name contains an ``EMBEDDING_MODULES`` key must carry the
    matching tensor from ``new_embeddings.safetensors``. All other layers
    must have no embeddings tensor.
    """
    tensors = load_file(
        os.path.join(sql_lora_files, "adapter_model.safetensors"))
    new_embeddings = load_file(
        os.path.join(sql_lora_files, "new_embeddings.safetensors"))

    lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
    with open(lora_config_path) as f:
        config = json.load(f)

    peft_helper = PEFTHelper.from_dict(config)
    lora_model = LoRAModel.from_lora_tensors(
        1,
        tensors,
        peft_helper=peft_helper,
        device=device,
        embeddings=new_embeddings,
        embedding_modules=EMBEDDING_MODULES,
        embedding_padding_modules=EMBEDDING_PADDING_MODULES)

    for module_name, lora in lora_model.loras.items():
        assert lora.module_name == module_name
        assert lora.rank == 8
        assert lora.lora_alpha == 16
        assert lora.lora_a is not None
        assert lora.lora_b is not None
        assert lora.lora_a.device == torch.device(device)
        assert lora.lora_b.device == torch.device(device)
        # lora_a's second dim and lora_b's first dim must both equal the rank.
        assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
                ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
        assert lora.lora_a.shape[1] == 8
        # Substring match: an embedding layer's module_name only needs to
        # contain the EMBEDDING_MODULES key.
        embeddings_module = next(
            (k for k in EMBEDDING_MODULES if k in module_name), None)
        if embeddings_module:
            assert torch.equal(
                lora.embeddings_tensor,
                new_embeddings[EMBEDDING_MODULES[embeddings_module]].to(
                    device=lora.embeddings_tensor.device))
        else:
            assert lora.embeddings_tensor is None


def create_lora(lora_id: int, model: nn.Module, sub_modules: List[str],
                device: torch.device) -> LoRAModel:
    """Build a rank-8 ``LoRAModel`` with random weights for ``sub_modules``.

    For each name in ``sub_modules``, the weight of that submodule of
    ``model`` fixes the shapes. ``lora_a`` is ``[weight.shape[1], 8]`` and
    ``lora_b`` is ``[8, weight.shape[0]]``. Both are filled with
    ``torch.rand`` on ``device``.

    Args:
        lora_id: Id given to the returned LoRAModel.
        model: Model whose submodules provide the weight shapes.
        sub_modules: Submodule names (as accepted by ``get_submodule``).
        device: Device on which the random LoRA tensors are allocated.

    Returns:
        A LoRAModel holding one LoRALayerWeights per entry of
        ``sub_modules``.
    """
    loras: Dict[str, LoRALayerWeights] = {}
    for name in sub_modules:
        w = model.get_submodule(name).weight
        loras[name] = LoRALayerWeights(
            name,
            8,  # rank
            16,  # lora_alpha
            torch.rand([w.shape[1], 8], device=device),
            torch.rand([8, w.shape[0]], device=device),
        )
    return LoRAModel(lora_id, 8, loras)


def create_packed_lora(
    lora_id: int,
    model: nn.Module,
    module_name,
    replaced_module_names,
    device: torch.device,
    empty_replaced_module_name=None,
) -> LoRAModel:
    """Build a ``LoRAModel`` for the sub-layers of a packed (merged) module.

    The weight of ``model.<module_name>`` fixes the shapes. Each entry of
    ``replaced_module_names`` gets a random ``lora_a`` of shape
    ``[weight.shape[1], 8]`` and a random ``lora_b`` covering an equal slice
    of the output dim, ``[8, weight.shape[0] // len(replaced_module_names)]``.

    Args:
        lora_id: Id given to the returned LoRAModel.
        model: Model containing the packed module.
        module_name: Name of the packed module (e.g. ``"gate_up_proj"``).
        replaced_module_names: Names of the sub-layers packed into it.
        device: Device on which the random LoRA tensors are allocated.
        empty_replaced_module_name: Optional sub-layer name to omit, which
            leaves that slot without LoRA weights.

    Returns:
        A LoRAModel keyed by the (non-omitted) sub-layer names.
    """
    w = model.get_submodule(module_name).weight
    loras: Dict[str, LoRALayerWeights] = {}
    for replaced_module_name in replaced_module_names:
        if replaced_module_name == empty_replaced_module_name:
            continue
        loras[replaced_module_name] = LoRALayerWeights(
            replaced_module_name,
            8,  # rank
            16,  # lora_alpha
            torch.rand([w.shape[1], 8], device=device),
            # Each sub-layer owns an equal share of the packed output dim.
            torch.rand([8, w.shape[0] // len(replaced_module_names)],
                       device=device),
        )
    return LoRAModel(lora_id, 8, loras)


def test_replace_submodules(dist_init, dummy_model):
    """LoRAModelManager wraps the supported linear layers with LoRA layers.

    The assertions below expect ``"dense1"`` to wrap both ``dense1`` and
    ``layer1.dense1``. They expect ``"layer1.dense2"`` to wrap only
    ``layer1.dense2``, while the top-level ``dense2`` stays a plain
    ``RowParallelLinear``.
    """
    model = dummy_model
    model.supported_lora_modules = ["dense1", "layer1.dense2"]
    model.packed_modules_mapping = {}
    manager = LoRAModelManager(
        model, 1, 1, 1,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
        torch.device("cuda"))
    model = manager.model

    assert isinstance(model.get_submodule("dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("dense2"), RowParallelLinear)
    assert isinstance(model.get_submodule("layer1.dense2"),
                      RowParallelLinearWithLoRA)


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
    """Exercise add/activate/remove on the non-LRU ``LoRAModelManager``.

    The manager has two GPU slots (``max_loras=2``). Once both are taken,
    activating a third adapter raises ``ValueError`` rather than evicting
    one. Removing adapters frees their slots.
    """
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=3,
                                          max_loras=2),
                               device=device)
    assert all(x is None for x in manager.lora_index_to_id)

    # Fill both slots; re-adding or re-activating returns False.
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # No free slot and no eviction: activating a third adapter fails and
    # leaves the slots untouched.
    with pytest.raises(ValueError):
        assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Removing adapters frees their slots; removing twice returns False.
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    assert manager.add_adapter(model_lora1)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] is None

    # Freed slots are reused in order.
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] is None
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    assert manager.device == device
    assert manager.punica_wrapper.device == device


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
    """Exercise slot handling, LRU eviction and pinning on the LRU manager.

    Unlike ``LoRAModelManager``, activating an adapter when both GPU slots
    are full evicts the least recently used one. Pinned adapters are never
    evicted. When every slot is pinned, pinning or activating another
    adapter raises ``RuntimeError``. Pinning an adapter that is not
    registered raises ``ValueError``.
    """
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=3,
                                                  max_loras=2),
                                       device=device)
    assert all(x is None for x in manager.lora_index_to_id)

    # Fill both slots; re-adding or re-activating returns False.
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Activating a third adapter evicts the least recently used one (1).
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1

    # Deactivation frees a slot without unregistering the adapter.
    assert manager.add_adapter(model_lora2)
    assert manager.deactivate_adapter(3)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3

    # A pinned adapter (2) is skipped by eviction.
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.deactivate_adapter(2)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1

    # With every slot pinned, nothing else can be pinned or activated.
    assert manager.pin_adapter(3)
    assert manager.pin_adapter(1)
    with pytest.raises(RuntimeError):
        assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    with pytest.raises(RuntimeError):
        assert manager.activate_adapter(2)

    assert manager.deactivate_adapter(3)
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    # Pinning an adapter that has been removed is an error.
    assert manager.remove_adapter(3)
    with pytest.raises(ValueError):
        assert manager.pin_adapter(3)

    assert manager.punica_wrapper.device == device
    assert manager.device == device


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lru_lora_model_manager(dist_init, dummy_model, device):
    """Check the LRU cache of ``LRUCacheLoRAModelManager`` (capacity 2).

    Covers eviction of the oldest adapter on overflow, manual removal,
    ``remove_oldest_adapter`` and the interaction with pinned adapters.
    """
    # This tests just the LRU cache functionality, everything else is
    # tested in test_lora_model_manager
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora4 = create_lora(4,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=2,
                                                  max_loras=2),
                                       device=device)

    assert all(x is None for x in manager.lora_index_to_id)

    # Add up to capacity
    assert manager.add_adapter(model_lora1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(1)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Add over capacity: the oldest adapters (1, then 2) are evicted.
    assert manager.add_adapter(model_lora3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(3)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Touch 3 again. Both calls return False because it is already present,
    # but 3 becomes the most recently used. Adding 2 therefore evicts 4.
    assert not manager.add_adapter(model_lora3)
    assert not manager.activate_adapter(3)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {3, 2}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {2}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 2

    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # remove_oldest_adapter drains the cache, then returns False once empty.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    assert not manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    # pinning
    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    # Adapter 1 is not registered at this point, so it cannot be pinned.
    with pytest.raises(ValueError):
        assert manager.pin_adapter(1)
    assert manager.pin_adapter(3)
    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.add_adapter(model_lora1)
    assert manager.pin_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # The oldest *unpinned* adapter (2) is removed; pinned 1 stays.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {1}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] is None

    # Only a pinned adapter remains, so nothing can be removed.
    with pytest.raises(RuntimeError):
        assert manager.remove_oldest_adapter()

    assert set(manager.list_adapters()) == {1}

    assert manager.punica_wrapper.device == device
    assert manager.device == device


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                          sql_lora_files, device):
    """``LRUCacheWorkerLoRAManager`` keeps adapters cached across requests.

    Adapters missing from a later request stay loaded until capacity (4)
    forces LRU eviction. A single request needing more adapters than
    capacity raises ``RuntimeError``.
    """
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = LRUCacheWorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, device,
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    # 2 is not requested but stays cached; 3 and 4 fill the free slots.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Cache is full: loading 5 evicts the least recently used adapter (3).
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Duplicate requests for an already-loaded adapter change nothing.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Three new adapters evict the three least recently used ones.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 6

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                sql_lora_files, device):
    """``WorkerLoRAManager`` keeps exactly the adapters of the last request.

    Should remove every LoRA not specified in the request. A request for
    more adapters than capacity raises ``RuntimeError``.
    """
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = WorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, device,
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    # 2 is dropped; 3 takes its slot and 4 the next free one.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4

    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5

    # Duplicate requests collapse to a single adapter; the others are freed.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None

    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 7

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
    """Per-layer LoRAs are packed into the merged ``gate_up_proj`` module.

    After ``add_adapter``, each LoRAModel exposes a ``PackedLoRALayerWeights``
    under ``"gate_up_proj"``. Slot 0 of that entry holds the ``gate_proj``
    weights and slot 1 holds the ``up_proj`` weights. A sub-layer without
    LoRA weights leaves its slot as ``None``.
    """
    model = dummy_model_gate_up
    model.supported_lora_modules = ["gate_up_proj"]
    model.packed_modules_mapping = {
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    model_lora = create_packed_lora(
        1,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device)
    # Same packed module, but with no LoRA weights for gate_proj.
    model_lora1 = create_packed_lora(
        2,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device,
        empty_replaced_module_name="gate_proj",
    )

    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=2,
                                          max_loras=2),
                               device=device)
    model = manager.model

    assert isinstance(model.get_submodule("gate_up_proj"),
                      MergedColumnParallelLinearWithLoRA)
    assert manager.add_adapter(model_lora)
    assert manager.add_adapter(model_lora1)

    packed_lora = model_lora.get_lora("gate_up_proj")
    assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)

    torch.testing.assert_close(packed_lora.lora_a[0],
                               model_lora.get_lora("gate_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[0],
                               model_lora.get_lora("gate_proj").lora_b)
    torch.testing.assert_close(packed_lora.lora_a[1],
                               model_lora.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[1],
                               model_lora.get_lora("up_proj").lora_b)

    packed_lora1 = model_lora1.get_lora("gate_up_proj")
    assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)

    # The omitted gate_proj slot stays empty.
    assert packed_lora1.lora_a[0] is None
    assert packed_lora1.lora_b[0] is None
    torch.testing.assert_close(packed_lora1.lora_a[1],
                               model_lora1.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora1.lora_b[1],
                               model_lora1.get_lora("up_proj").lora_b)