"""Tests for vLLM's LoRA model managers and worker-level LoRA managers."""
import os
from typing import Dict, List

import pytest
import torch
from safetensors.torch import load_file
from torch import nn

from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager)
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                      WorkerLoRAManager)
from vllm.model_executor.layers.linear import RowParallelLinear

# Embedding-module names (matched as substrings of LoRA module names) mapped
# to the key of the corresponding tensor in ``new_embeddings.safetensors``.
EMBEDDING_MODULES = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}

# Passed to the LoRA loaders/managers as ``embedding_padding_modules``.
EMBEDDING_PADDING_MODULES = ["lm_head"]


def test_from_lora_tensors(sql_lora_files):
    """Check that ``LoRAModel.from_lora_tensors`` parses the SQL adapter.

    Every parsed layer must report rank 8 and alpha 16 and have A/B
    matrices whose inner dimensions agree. Layers whose module name contains
    an ``EMBEDDING_MODULES`` key must carry the matching tensor from
    ``new_embeddings.safetensors``; all other layers must have no
    embeddings tensor.
    """
    tensors = load_file(
        os.path.join(sql_lora_files, "adapter_model.safetensors"))
    new_embeddings = load_file(
        os.path.join(sql_lora_files, "new_embeddings.safetensors"))

    lora_model = LoRAModel.from_lora_tensors(
        1,
        8,
        16,
        tensors,
        "cuda",
        embeddings=new_embeddings,
        embedding_modules=EMBEDDING_MODULES,
        embedding_padding_modules=EMBEDDING_PADDING_MODULES)
    for module_name, lora in lora_model.loras.items():
        assert lora.module_name == module_name
        assert lora.rank == 8
        assert lora.lora_alpha == 16
        assert lora.lora_a is not None
        assert lora.lora_b is not None
        assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
                ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
        assert lora.lora_a.shape[1] == 8
        # First embedding-module key occurring in this layer's name, if any.
        embeddings_module = next(
            (k for k in EMBEDDING_MODULES if k in module_name), None)
        if embeddings_module is not None:
            # Compare on the LoRA tensor's device; the file loads on CPU.
            assert torch.equal(
                lora.embeddings_tensor,
                new_embeddings[EMBEDDING_MODULES[embeddings_module]].to(
                    device=lora.embeddings_tensor.device))
        else:
            assert lora.embeddings_tensor is None


def create_lora(lora_id: int, model: nn.Module,
                sub_modules: List[str]) -> LoRAModel:
    """Build a ``LoRAModel`` with random rank-8 weights on CUDA.

    For each name in ``sub_modules``, the submodule's ``weight`` ``w`` is
    looked up. A LoRA layer is created with rank 8, alpha 16, A of shape
    ``(w.shape[1], 8)`` and B of shape ``(8, w.shape[0])``.

    Args:
        lora_id: id assigned to the returned ``LoRAModel``.
        model: model whose submodules size the LoRA matrices.
        sub_modules: dotted submodule names to attach LoRA layers to.

    Returns:
        A ``LoRAModel`` with one ``LoRALayerWeights`` per entry in
        ``sub_modules``.
    """
    loras: Dict[str, LoRALayerWeights] = {}
    for name in sub_modules:
        w = model.get_submodule(name).weight
        loras[name] = LoRALayerWeights(
            name,
            8,
            16,
            torch.rand([w.shape[1], 8], device="cuda"),
            torch.rand([8, w.shape[0]], device="cuda"),
        )
    return LoRAModel(lora_id, 8, loras)


def create_packed_lora(
    lora_id: int,
    model: nn.Module,
    module_name: str,
    replaced_module_names: List[str],
    empty_replaced_module_name=None,
) -> LoRAModel:
    """Build a ``LoRAModel`` with one LoRA per sub-projection of a packed module.

    ``module_name`` names the packed submodule (e.g. ``gate_up_proj``). Its
    ``weight`` ``w`` sizes the adapters. ``replaced_module_names`` are the
    per-projection names the packed module stands in for. Each gets a random
    rank-8, alpha-16 LoRA on CUDA. A has shape ``(w.shape[1], 8)`` and B has
    shape ``(8, w.shape[0] // len(replaced_module_names))``.

    Args:
        lora_id: id assigned to the returned ``LoRAModel``.
        model: model containing ``module_name``.
        module_name: dotted name of the packed submodule.
        replaced_module_names: names of the projections packed into it.
        empty_replaced_module_name: optional name from
            ``replaced_module_names`` to skip, producing a model with no
            weights for that projection.

    Returns:
        A ``LoRAModel`` keyed by the (non-skipped) replaced module names.
    """
    w = model.get_submodule(module_name).weight
    loras: Dict[str, LoRALayerWeights] = {}
    for replaced_module_name in replaced_module_names:
        if replaced_module_name == empty_replaced_module_name:
            continue
        loras[replaced_module_name] = LoRALayerWeights(
            replaced_module_name,
            8,
            16,
            torch.rand([w.shape[1], 8], device="cuda"),
            # Each projection owns an equal share of the packed output rows.
            torch.rand([8, w.shape[0] // len(replaced_module_names)],
                       device="cuda"),
        )
    return LoRAModel(lora_id, 8, loras)


def test_replace_submodules(dist_init, dummy_model):
    """Only submodules matching ``supported_lora_modules`` get LoRA wrappers.

    ``"dense1"`` matches both the top-level ``dense1`` and ``layer1.dense1``.
    ``"layer1.dense2"`` wraps only the nested ``dense2`` and leaves the
    top-level ``dense2`` a plain ``RowParallelLinear``.
    """
    model = dummy_model
    model.supported_lora_modules = ["dense1", "layer1.dense2"]
    model.packed_modules_mapping = {}
    manager = LoRAModelManager(
        model, 1, 1, 1,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8))
    model = manager.model

    assert isinstance(model.get_submodule("dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("dense2"), RowParallelLinear)
    assert isinstance(model.get_submodule("layer1.dense2"),
                      RowParallelLinearWithLoRA)


def test_lora_model_manager(dist_init, dummy_model):
    """Exercise add/activate/remove on a plain (non-LRU) ``LoRAModelManager``.

    With both slots occupied, activating another adapter raises
    ``ValueError`` instead of evicting one.
    """
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
    manager = LoRAModelManager(
        model, 2, 2, 2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
    assert all(x is None for x in manager.lora_index_to_id)

    # Adding/activating an already present adapter returns False.
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Both slots are occupied (by 1 and 2): activation must fail, not evict.
    with pytest.raises(ValueError):
        manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Removing frees the slot; removing twice returns False.
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)

    # Adding alone does not place an adapter in a slot.
    assert manager.add_adapter(model_lora1)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] is None
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] is None
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2


def test_lora_lru_cache_model_manager(dist_init, dummy_model):
    """Exercise slot assignment, eviction and pinning on the LRU manager.

    Unlike the plain manager, activating into full slots evicts an
    unpinned adapter. Pinned adapters are never evicted. Pinning or
    activating when every slot is pinned raises ``RuntimeError``. Pinning an
    unknown adapter raises ``ValueError``.
    """
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
    manager = LRUCacheLoRAModelManager(
        model, 2, 2, 2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
    assert all(x is None for x in manager.lora_index_to_id)

    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Slots are full: activating 3 evicts 1 from slot 0.
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1

    # Deactivation frees the slot without removing the adapter.
    assert manager.add_adapter(model_lora2)
    assert manager.deactivate_adapter(3)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3

    # Pinning 2 keeps it resident: activating 1 evicts 3 instead.
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.deactivate_adapter(2)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1

    # With every slot pinned, nothing else can be pinned or activated.
    assert manager.pin_adapter(3)
    assert manager.pin_adapter(1)
    with pytest.raises(RuntimeError):
        manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    with pytest.raises(RuntimeError):
        manager.activate_adapter(2)

    assert manager.deactivate_adapter(3)
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1

    # Pinning an adapter that has been removed is an error.
    assert manager.remove_adapter(3)
    with pytest.raises(ValueError):
        manager.pin_adapter(3)


def test_lru_lora_model_manager(dist_init, dummy_model):
    """Exercise only the LRU cache behaviour of ``LRUCacheLoRAModelManager``.

    Everything else is covered by ``test_lora_model_manager``. Here both
    the CPU cache and the slot count are 2, so adding a third adapter evicts
    the least recently used one.
    """
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
    model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"])
    manager = LRUCacheLoRAModelManager(
        model, 2, 2, 2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))

    assert all(x is None for x in manager.lora_index_to_id)

    # Add up to capacity
    assert manager.add_adapter(model_lora1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(1)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Add over capacity: 3 and 4 evict 1 and 2.
    assert manager.add_adapter(model_lora3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(3)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Re-adding/re-activating 3 returns False because it is already present,
    # but refreshes it in the LRU order, so adding 2 then evicts 4.
    assert not manager.add_adapter(model_lora3)
    assert not manager.activate_adapter(3)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {3, 2}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {2}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 2

    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)

    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # remove_oldest_adapter drops adapters in LRU order, then returns False.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    assert not manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    # Pinning: an adapter that is not present cannot be pinned.
    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    with pytest.raises(ValueError):
        manager.pin_adapter(1)
    assert manager.pin_adapter(3)
    # A pinned adapter can still be removed manually.
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)

    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    # pin_adapter(1) places 1 in slot 0; adding and activating 2 evicts 4.
    assert manager.add_adapter(model_lora1)
    assert manager.pin_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)

    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # The pinned adapter 1 is skipped, so 2 is removed as the oldest.
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {1}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] is None

    # Only a pinned adapter is left, so there is nothing evictable.
    with pytest.raises(RuntimeError):
        manager.remove_oldest_adapter()

    assert set(manager.list_adapters()) == {1}


def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                          sql_lora_files):
    """Check ``LRUCacheWorkerLoRAManager`` slot assignment across requests.

    Adapters missing from a later request stay loaded while capacity (4)
    allows. Once it is full, new adapters evict older, unrequested ones. A
    single request naming more adapters than slots raises ``RuntimeError``.
    """
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = LRUCacheWorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    # 2 is not requested but stays loaded: there is still room for it.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Full: 5 takes the slot of 3, which is evicted.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Duplicate requests for a loaded adapter change nothing.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    # Three new adapters evict 2, 4 and 5; the recently used 1 survives.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 6

    # Over capacity: five adapters in one request with four slots.
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)


def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                sql_lora_files):
    """Check that ``WorkerLoRAManager`` keeps exactly the requested adapters.

    Every LoRA not named in a request is removed. A single request naming
    more adapters than slots (4) raises ``RuntimeError``.
    """
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = WorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    # 2 is not requested, so it is removed and its slot reused.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4

    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5

    # Duplicate requests collapse to a single adapter; the others are removed.
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None

    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 7

    # Over capacity: five adapters in one request with four slots.
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)


def test_packed_loras(dist_init, dummy_model_gate_up):
    """Check that per-projection LoRAs are merged for a packed module.

    The model packs ``gate_proj`` and ``up_proj`` into ``gate_up_proj``.
    After the adapters are added to the manager, each ``LoRAModel`` exposes
    a ``PackedLoRALayerWeights`` under ``gate_up_proj``. Its slot ``i``
    holds the ``i``-th projection's A/B matrices, or ``None`` where that
    projection had no LoRA.
    """
    model = dummy_model_gate_up
    model.supported_lora_modules = ["gate_up_proj"]
    model.packed_modules_mapping = {
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    model_lora = create_packed_lora(
        1,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"])
    # Same packed module, but with no weights for gate_proj.
    model_lora1 = create_packed_lora(
        2,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        empty_replaced_module_name="gate_proj",
    )

    manager = LoRAModelManager(
        model, 2, 2, 2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))
    model = manager.model

    assert isinstance(model.get_submodule("gate_up_proj"),
                      MergedColumnParallelLinearWithLoRA)
    assert manager.add_adapter(model_lora)
    assert manager.add_adapter(model_lora1)

    packed_lora = model_lora.get_lora("gate_up_proj")
    assert isinstance(packed_lora, PackedLoRALayerWeights)

    torch.testing.assert_close(packed_lora.lora_a[0],
                               model_lora.get_lora("gate_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[0],
                               model_lora.get_lora("gate_proj").lora_b)
    torch.testing.assert_close(packed_lora.lora_a[1],
                               model_lora.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[1],
                               model_lora.get_lora("up_proj").lora_b)

    packed_lora1 = model_lora1.get_lora("gate_up_proj")
    assert isinstance(packed_lora1, PackedLoRALayerWeights)

    # The missing gate_proj leaves an empty slot 0.
    assert packed_lora1.lora_a[0] is None
    assert packed_lora1.lora_b[0] is None
    torch.testing.assert_close(packed_lora1.lora_a[1],
                               model_lora1.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora1.lora_b[1],
                               model_lora1.get_lora("up_proj").lora_b)