# test_layers.py -- tests for the vLLM LoRA layer wrappers.
import random
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import pytest
import torch
import torch.nn.functional as F

from vllm.config import LoRAConfig
from vllm.lora.fully_sharded_layers import (
    ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
                              LinearScalingRotaryEmbeddingWithLora,
                              LogitsProcessorWithLoRA, LoRAMapping,
                              MergedColumnParallelLinearWithLoRA,
                              MergedQKVParallelLinearWithLora,
                              QKVParallelLinearWithLora,
                              RowParallelLinearWithLoRA,
                              VocabParallelEmbeddingWithLoRA)
# yapf: enable
from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights,
                              PackedLoRALayerWeights, convert_mapping)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.utils import set_random_seed

from .utils import DummyLoRAManager

# Per-dtype comparison tolerances, stored as (rtol, atol); the tests unpack
# them in that order and forward them to torch.allclose.
TOLERANCES = {
    torch.float16: (5e-3, 5e-3),
    torch.float32: (5e-3, 5e-3),
    torch.bfloat16: (3e-2, 2e-2),
}

# Devices the tests are parametrized over: only "cuda:0" when exactly one
# CUDA device is reported, otherwise "cuda:0" and "cuda:1".
# NOTE(review): a device count of 0 also yields two entries -- confirm that
# this is intended for machines without a GPU.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


def get_random_id_to_index(num_loras: int,
                           num_slots: int,
                           log: bool = True) -> List[Optional[int]]:
    """Build a random slot table for LoRA ids 1..num_loras.

    Args:
        num_loras: How many LoRA ids to place. Ids are numbered from 1.
        num_slots: Length of the returned table. Must be at least as
            large as num_loras.
        log: When true, print the resulting table.

    Returns:
        A list of length num_slots whose entry at position p is the LoRA
        id stored in slot p, or None when that slot is unused.

    Raises:
        ValueError: If num_loras exceeds num_slots.
    """

    if not num_loras <= num_slots:
        raise ValueError(
            f"num_loras is higher than num_slots: {num_loras} > {num_slots}. "
            "num_loras must be less than or equal to num_slots.")

    # The first num_loras entries of a random permutation give distinct,
    # randomly chosen slot positions.
    permutation = torch.randperm(num_slots)
    chosen_positions = permutation[:num_loras].tolist()

    id_by_position = {
        position: lora_id
        for lora_id, position in enumerate(chosen_positions, start=1)
    }
    slots: List[Optional[int]] = [
        id_by_position.get(position) for position in range(num_slots)
    ]

    if log:
        print(f"Created lora_id_to_index mapping: {slots}.")

    return slots


def populate_loras(
    id_to_index: List[Optional[int]],
    layer: BaseLayerWithLoRA,
    layer_weights: torch.Tensor,
    generate_embeddings_tensor: int = 0,
    repeats: int = 1,
) -> Tuple[Dict[int, LoRALayerWeights], Dict[int, List[LoRALayerWeights]]]:
    """Fill the occupied slots of a LoRA layer with random LoRA weights.

    Args:
        id_to_index: a list of lora ids. The index of the lora id
            represents which memory slot the lora matrices are
            stored in. A None value indicates a free slot.
        layer: the LoRA layer to populate; ``set_lora`` is called on it
            once per occupied slot.
        layer_weights: the PyTorch tensor containing the layer's
            weights. It is handed to the dummy LoRA manager, and its
            first dimension divided by ``repeats`` gives the column
            width of each sub-LoRA's ``lora_b`` slice.
        generate_embeddings_tensor: forwarded unchanged to
            ``DummyLoRAManager.init_random_lora``. Callers in this file
            pass 0 (the default) or an embedding width such as 256 or
            1024 -- presumably 0 disables generation; confirm against
            DummyLoRAManager.
        repeats: must only be set for column parallel packed
            layers. Indicates the number of loras to compose
            together to create a single lora layer.

    Returns:
        A pair ``(lora_dict, sublora_dict)``. ``lora_dict`` maps each
        lora id to the weights installed in the layer (the packed
        weights when ``repeats > 1``, else the single sub-LoRA);
        ``sublora_dict`` maps each lora id to its list of sub-LoRAs.
    """

    # Maps the lora ID to the corresponding lora weights.
    lora_dict: Dict[int, LoRALayerWeights] = {}

    # Maps the lora ID to the corresponding subloras.
    sublora_dict: Dict[int, List[LoRALayerWeights]] = {}

    # Width of each sub-LoRA's lora_b column slice; the same for all slots.
    sublora_len = layer_weights.shape[0] // repeats

    for slot_idx, lora_id in enumerate(id_to_index):
        if lora_id is None:
            # Free slot: nothing to install.
            continue

        subloras: List[LoRALayerWeights] = []
        for i in range(repeats):
            sublora = DummyLoRAManager().init_random_lora(
                module_name=f"fake_{i}",
                weight=layer_weights,
                generate_embeddings_tensor=generate_embeddings_tensor,
            )
            # Keep only the i-th column band of lora_b for this sub-LoRA.
            sublora.lora_b = sublora.lora_b[:, (sublora_len *
                                                i):(sublora_len * (i + 1))]
            sublora.optimize()
            subloras.append(sublora)

        lora = PackedLoRALayerWeights.pack(
            subloras) if repeats > 1 else subloras[0]

        layer.set_lora(
            slot_idx,
            lora_a=lora.lora_a,
            lora_b=lora.lora_b,
            embeddings_tensor=lora.embeddings_tensor,
        )

        lora_dict[lora_id] = lora
        sublora_dict[lora_id] = subloras

    return lora_dict, sublora_dict


def create_random_inputs(
    active_lora_ids: List[int],
    num_inputs: int,
    input_size: Tuple[int, ...],
    input_range: Tuple[float, float],
    input_type: torch.dtype = torch.int,
) -> Tuple[List[torch.Tensor], List[int], List[int]]:
    """Creates random inputs, each tagged with a randomly chosen LoRA id.

    Args:
        active_lora_ids: lora IDs of active lora weights. One of them is
            drawn with ``random.choice`` for every input.
        num_inputs: the number of inputs to create.
        input_size: the size of each individual input.
        input_range: the range of values to include in the input.
            input_range[0] <= possible input values < input_range[1]
            (for floating-point inputs the upper bound can be reached
            through rounding).
        input_type: the type of values in the input. ``torch.int``
            selects ``torch.randint``; any other dtype is passed to
            ``torch.rand``.

    Returns:
        A triple ``(inputs, index_mapping, prompt_mapping)``:
        ``inputs`` holds one tensor per input, ``index_mapping`` repeats
        each input's lora id ``input_size[0]`` times, and
        ``prompt_mapping`` holds one lora id per input.
    """

    low, high = input_range

    inputs: List[torch.Tensor] = []
    index_mapping: List[int] = []
    prompt_mapping: List[int] = []
    for _ in range(num_inputs):
        if input_type == torch.int:
            inputs.append(
                torch.randint(low=int(low), high=int(high), size=input_size))
        else:
            # Scale the unit interval by the width of the range, not by
            # its upper bound, so the samples stay inside [low, high).
            inputs.append(
                torch.rand(size=input_size, dtype=input_type) * (high - low) +
                low)

        lora_id = random.choice(active_lora_ids)
        index_mapping += [lora_id] * input_size[0]
        prompt_mapping += [lora_id]

    return inputs, index_mapping, prompt_mapping


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
    """Compare VocabParallelEmbeddingWithLoRA against a manual reference.

    For ten seeds, random LoRAs are installed into random slots and the
    wrapped layer's output is compared with the base embedding plus the
    explicitly computed LoRA contribution. All slots are then reset and
    the wrapped layer must reproduce the base embedding alone.
    """
    torch.set_default_device(device)
    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def create_random_embedding_layer():
        """Return a randomly initialised embedding and its LoRA wrapper."""
        embedding = VocabParallelEmbedding(vocab_size, 256)
        embedding.weight.data = torch.rand_like(embedding.weight.data)
        # Zero every row past the requested vocabulary size.
        embedding.weight.data[vocab_size:, :] = 0
        lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
        lora_embedding.create_lora_weights(max_loras, lora_config)

        return embedding, lora_embedding

    for seed in range(10):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        embedding, lora_embedding = create_random_embedding_layer()

        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=embedding.weight.T,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

        mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)
        lora_embedding.set_mapping(*mapping_info)

        lora_result = lora_embedding(torch.cat(inputs))

        # Reference: base embedding plus (lookup in lora_a) @ lora_b,
        # computed separately for every input with its own LoRA.
        expected_results = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = embedding(input_)
            after_a = F.embedding(
                input_,
                lora.lora_a,
            )
            result += (after_a @ lora.lora_b)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        # Every input is now mapped to lora id 0, and the wrapped layer is
        # expected to match the plain embedding.
        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

        mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)
        lora_embedding.set_mapping(*mapping_info)

        lora_result = lora_embedding(torch.cat(inputs))
        expected_result = embedding(torch.cat(inputs))

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
                                        vocab_size) -> None:
    """Test the LoRA embedding wrapper with per-LoRA extra vocabulary rows.

    Each LoRA brings its own embeddings tensor. The reference is an
    expanded embedding whose rows past ``vocab_size`` hold the
    concatenated per-LoRA embeddings tensors, plus the explicitly
    computed LoRA contribution. The wrapped layer must match it, and
    after all slots are reset it must match the expanded embedding alone.
    """
    torch.set_default_device(device)
    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def create_random_embedding_layer():
        """Return the expanded reference embedding and its LoRA wrapper."""
        embedding = VocabParallelEmbedding(vocab_size, 256)
        embedding_data = torch.rand_like(embedding.weight.data)
        embedding.weight.data = embedding_data
        embedding.weight.data[vocab_size:, :] = 0
        expanded_embedding = VocabParallelEmbedding(
            vocab_size + lora_config.lora_extra_vocab_size * max_loras,
            256,
            org_num_embeddings=vocab_size)
        expanded_embedding.weight.data[:vocab_size, :] = embedding_data
        # We need to deepcopy the embedding as it will be modified
        # in place
        lora_embedding = VocabParallelEmbeddingWithLoRA(
            deepcopy(expanded_embedding))
        lora_embedding.create_lora_weights(max_loras, lora_config)

        return expanded_embedding, lora_embedding

    for seed in range(10):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        expanded_embedding, lora_embedding = create_random_embedding_layer()
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=torch.zeros(
                (256, vocab_size + lora_config.lora_extra_vocab_size)),
            generate_embeddings_tensor=256,
        )

        # All embeddings tensors have the same shape.
        embeddings_tensors = [
            lora_dict[lora_id].embeddings_tensor
            for lora_id in sorted(lora_dict.keys())
        ]
        embeddings_tensor_len = embeddings_tensors[0].shape[0]

        # Add empty embeddings_tensors for unoccupied lora slots.
        for _ in range(max_loras - len(embeddings_tensors)):
            embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape))

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

        original_inputs = deepcopy(inputs)

        # Force some of the inputs to be in the extended embeddings range
        # to guarantee that their behavior is tested.
        # `inputs` indexes the expanded reference table, where each LoRA
        # has its own band of extra rows; `original_inputs` is what the
        # LoRA layer receives and always uses the first band.
        for input_, original_input_, lora_id in zip(inputs, original_inputs,
                                                    prompt_mapping):
            embedding_id = lora_id - 1
            input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
            original_input_[-1] = vocab_size
            input_[-2] = vocab_size + (
                (embedding_id + 1) * embeddings_tensor_len - 1)
            original_input_[-2] = vocab_size + embeddings_tensor_len - 1

        mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)
        lora_embedding.set_mapping(*mapping_info)

        # Write the per-LoRA embeddings tensors into the extra rows of the
        # reference table.
        expanded_embedding.weight[vocab_size:vocab_size +
                                  (embeddings_tensor_len *
                                   max_loras)] = torch.cat(embeddings_tensors)

        lora_result = lora_embedding(torch.cat(original_inputs))

        expected_results = []
        for input_, original_input_, lora_id in zip(inputs, original_inputs,
                                                    prompt_mapping):
            lora = lora_dict[lora_id]
            result = expanded_embedding(input_)
            after_a = F.embedding(
                original_input_,
                lora.lora_a,
            )
            result += (after_a @ lora.lora_b)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

        original_inputs = deepcopy(inputs)

        mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)
        lora_embedding.set_mapping(*mapping_info)

        lora_result = lora_embedding(torch.cat(original_inputs))
        expected_result = expanded_embedding(torch.cat(inputs))

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
def test_lm_head_logits_processor(dist_init, num_loras, device,
                                  vocab_size) -> None:
    """Test LogitsProcessorWithLoRA against the plain LogitsProcessor.

    For ten seeds, random LoRAs (each with an embeddings tensor) are
    installed, logits are computed through the LoRA processor, and a
    manual reference is built from the base processor. All slots are
    then reset and the LoRA processor, truncated to ``vocab_size``
    columns, must match the base processor.
    """
    torch.set_default_device(device)
    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def _pretest():
        """Return the LM head, the base processor and its LoRA wrapper."""
        linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
                                1024,
                                vocab_size,
                                params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        # NOTE(review): this slices the second dimension of the weight; the
        # embedding tests zero rows (`[vocab_size:, :]`) instead. Confirm
        # which axis was intended.
        linear.weight.data[:, vocab_size:] = 0
        logits_processor = LogitsProcessor(
            vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
        lora_logits_processor = LogitsProcessorWithLoRA(
            logits_processor, 1024, linear.weight.dtype, linear.weight.device)
        lora_logits_processor.create_lora_weights(max_loras, lora_config)

        return linear, logits_processor, lora_logits_processor

    for seed in range(10):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, logits_processor, lora_logits_processor = _pretest()

        # NOTE: all the generated loras share the same embeddings tensor.
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_logits_processor,
            layer_weights=linear.weight,
            generate_embeddings_tensor=1024,
        )
        embeddings_tensor = next(iter(lora_dict.values())).embeddings_tensor
        embeddings_tensor_len = embeddings_tensor.shape[0]

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=8 * num_loras,  # * 3,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

        # NOTE(review): this tensor is never read (the loop below rebinds
        # `input_`); it is kept only because the call advances the RNG
        # state that the later random inputs depend on.
        input_ = torch.rand(20, 1024)
        mapping_info = convert_mapping(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )
        lora_logits_processor.set_mapping(*mapping_info)

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            embedding=linear.weight,
            embedding_bias=None)

        # Snapshot taken before the extra-vocabulary rows are overwritten;
        # the reset check below runs against these untouched weights.
        original_weight = linear.weight.clone()

        linear.weight[logits_processor.
                      org_vocab_size:logits_processor.org_vocab_size +
                      embeddings_tensor_len] = embeddings_tensor

        # Temporarily widen the base processor so that the reference logits
        # cover the extra vocabulary as well; restored after the loop.
        logits_processor.org_vocab_size = (vocab_size +
                                           lora_config.lora_extra_vocab_size)
        expected_results = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = logits_processor._get_logits(hidden_states=input_,
                                                  embedding=linear.weight,
                                                  embedding_bias=None)
            result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)
        logits_processor.org_vocab_size = vocab_size

        # NOTE(review): `lora_result` and `expected_result` computed above
        # are never compared -- both names are reassigned before the only
        # assertion in this test. Presumably an allclose check is missing
        # here; confirm that it passes on GPU before adding it.

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_logits_processor.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=8 * num_loras * 3,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

        mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)
        lora_logits_processor.set_mapping(*mapping_info)

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            embedding=original_weight,
            embedding_bias=None)[:, :vocab_size]
        expected_result = logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            embedding=original_weight,
            embedding_bias=None)

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
                         device) -> None:
    """Test the row/column parallel linear LoRA wrappers.

    ``orientation`` selects RowParallelLinear or ColumnParallelLinear and
    ``fully_shard`` selects the sharded wrapper variant. The wrapped
    output is compared with the base layer plus
    ``x @ lora_a @ lora_b * scaling``; after all slots are reset it must
    match the base layer alone.
    """
    torch.set_default_device(device)
    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             fully_sharded_loras=fully_shard,
                             lora_dtype=torch.float16)

    def create_random_linear_parallel_layer():
        """Return a random 4096x4096 base layer and its LoRA wrapper."""
        if orientation == "row":
            linear = RowParallelLinear(4096,
                                       4096,
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard
                           else RowParallelLinearWithShardedLoRA(linear))
        else:
            linear = ColumnParallelLinear(4096,
                                          4096,
                                          bias=False,
                                          params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (ColumnParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           ColumnParallelLinearWithShardedLoRA(linear))
        lora_linear.create_lora_weights(max_loras, lora_config)

        return linear, lora_linear

    for seed in range(10):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_parallel_layer()

        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

        # 512 is passed in the vocab-size position of convert_mapping;
        # this test involves no vocabulary, so it is presumably a
        # placeholder -- confirm against convert_mapping.
        mapping_info = convert_mapping(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )
        lora_linear.set_mapping(*mapping_info)

        # The layers return a sequence; element 0 is the output tensor
        # compared here.
        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

        mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
                                       512, lora_config.lora_extra_vocab_size)
        lora_linear.set_mapping(*mapping_info)

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
                                device) -> None:
    """Test the packed column-parallel LoRA wrappers.

    ``repeats`` selects the layer under test: 2 uses
    MergedColumnParallelLinear, 3 uses QKVParallelLinear with the merged
    QKV wrapper, and 1 uses QKVParallelLinear with the unmerged QKV
    wrapper. The wrapped output is compared with the base layer plus
    each sub-LoRA's contribution added to its own column band; after all
    slots are reset it must match the base layer alone.
    """
    torch.set_default_device(device)
    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             fully_sharded_loras=fully_shard,
                             lora_dtype=torch.float16)

    def create_column_parallel_packed_layer():
        """Return the base layer selected by `repeats` and its wrapper."""
        if repeats == 2:
            linear = MergedColumnParallelLinear(4096, [4096] * repeats,
                                                bias=False,
                                                params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (MergedColumnParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           MergedColumnParallelLinearWithShardedLoRA(linear))
        elif repeats == 3:
            linear = QKVParallelLinear(4096,
                                       64,
                                       32,
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (MergedQKVParallelLinearWithLora(linear)
                           if not fully_shard else
                           MergedQKVParallelLinearWithShardedLora(linear))
        else:
            linear = QKVParallelLinear(4096,
                                       64,
                                       32,
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            # NOTE(review): `fully_shard` does not change the wrapper in
            # this branch, so both parametrizations build the same layer
            # type here -- confirm that this is intended.
            lora_linear = QKVParallelLinearWithLora(linear)

        # Minimal stand-in for the model config handed to
        # create_lora_weights; the values are plain class attributes.
        @dataclass
        class FakeConfig:
            hidden_size = 4096
            num_key_value_heads = 32
            num_attention_heads = 32

        lora_linear.create_lora_weights(max_loras,
                                        lora_config,
                                        model_config=FakeConfig())

        return linear, lora_linear

    for seed in range(10):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)

        linear, lora_linear = create_column_parallel_packed_layer()

        lora_dict, sublora_dict = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
            repeats=repeats,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

        mapping_info = convert_mapping(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )
        lora_linear.set_mapping(*mapping_info)

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            result = linear(input_)[0]
            subloras = sublora_dict[lora_id]
            # Sub-LoRA number `band` contributes only to its own band of
            # output columns, whose width is that sub-LoRA's lora_b width.
            for band, sublora in enumerate(subloras):
                width = sublora.lora_b.shape[1]
                result[:, width * band:width *
                       (band + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b *
                                       sublora.scaling)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

        mapping_info = convert_mapping(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )
        lora_linear.set_mapping(*mapping_info)

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 8])
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("scaling_factors", [(1.0, ), (4.0, ), (4.0, 8.0),
                                             (6.0, 1.0)])
@pytest.mark.parametrize("max_position", [11, 4096, 32768])
@pytest.mark.parametrize("is_neox_style", [True, False])
@pytest.mark.parametrize("rotary_dim", [None, 32])
@pytest.mark.parametrize("head_size", [32, 108])
@pytest.mark.parametrize("seq_len", [11, 1024])
def test_rotary_embedding_long_context(dist_init, num_loras, device,
                                       scaling_factors, max_position,
                                       is_neox_style, rotary_dim, head_size,
                                       seq_len) -> None:
    """Compare the LoRA long-context rope against a linear-scaling rope.

    A plain rope is wrapped in ``LinearScalingRotaryEmbeddingWithLora`` and
    its output for random positions/query/key is compared with the output of
    a rope built through ``get_rope`` with ``{"type": "linear", "factor":
    scaling_factors}``. The test also checks that the per-scaling-factor
    offsets exposed by the wrapper are laid out back to back, each section
    being ``scaling_factor * max_position`` long.
    """
    dtype = torch.float16
    seed = 0
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             long_lora_scaling_factors=scaling_factors,
                             lora_dtype=dtype)

    # ``None`` means "rotate the whole head".
    if rotary_dim is None:
        rotary_dim = head_size
    base = 10000
    batch_size = 5 * num_loras
    num_heads = 7

    # Verify lora is equivalent to linear scaling rotary embedding.
    rope = get_rope(
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
    )
    lora_rope = LinearScalingRotaryEmbeddingWithLora(rope)
    lora_rope.create_lora_weights(max_loras, lora_config)
    linear_rope = get_rope(head_size, rotary_dim, max_position, base,
                           is_neox_style, {
                               "type": "linear",
                               "factor": scaling_factors
                           })
    linear_rope = linear_rope.to(dtype=dtype)
    id_to_index = get_random_id_to_index(num_loras, max_loras)
    # Only the mappings are needed here; the generated inputs are discarded.
    _, index_mapping, prompt_mapping = create_random_inputs(
        active_lora_ids=[0],
        num_inputs=batch_size,
        input_size=(1, max_position),
        input_range=(0, lora_config.lora_extra_vocab_size),
        input_type=torch.float16,
    )
    lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
    long_lora_context = LongContextLoRAContext(list(scaling_factors),
                                               rotary_dim)

    next_expected_offset = 0
    # Make sure the offset is correct.
    scaling_factor_to_offset = lora_rope.scaling_factor_to_offset
    for scaling_factor, offset in scaling_factor_to_offset.items():
        assert offset == next_expected_offset
        next_expected_offset += scaling_factor * max_position

    # LoRA id ``i`` is associated with the i-th scaling factor; factors the
    # wrapper does not know about fall back to offset 0.
    for lora_id, factor in enumerate(scaling_factors):
        long_lora_context.offsets_by_lora_id[lora_id] = (
            scaling_factor_to_offset.get(factor, 0))
    mapping_info = convert_mapping(
        lora_mapping,
        id_to_index,
        max_loras,
        512,
        lora_config.lora_extra_vocab_size,
        long_lora_context=long_lora_context,
    )
    lora_rope.set_mapping(*mapping_info)

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)
    # NOTE(review): if the rope forward rotates query/key in place, ref_* and
    # actual_* may alias the same buffers and the comparison below would be
    # vacuous -- confirm, and pass clones to each call if so.
    ref_q, ref_k = linear_rope(positions, query, key)
    actual_q, actual_k = lora_rope(positions, query, key)

    # The comparison results used to be discarded, so the test could never
    # fail on a mismatch; assert them with the module-wide tolerances.
    rtol, atol = TOLERANCES[actual_q.dtype]
    assert torch.allclose(ref_q, actual_q, rtol=rtol, atol=atol)
    assert torch.allclose(ref_k, actual_k, rtol=rtol, atol=atol)