test_layers.py 37.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import random
from copy import deepcopy
from dataclasses import dataclass
7
from unittest.mock import patch
8

9
import pytest
10
11
12
import torch
import torch.nn.functional as F

13
from vllm.config.lora import LoRAConfig
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from vllm.lora.layers import (
    BaseLayerWithLoRA,
    ColumnParallelLinearWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    LogitsProcessorWithLoRA,
    LoRAMapping,
    MergedColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    QKVParallelLinearWithLoRA,
    QKVParallelLinearWithShardedLoRA,
    ReplicatedLinearWithLoRA,
    RowParallelLinearWithLoRA,
    RowParallelLinearWithShardedLoRA,
    VocabParallelEmbeddingWithLoRA,
)
31
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
32
from vllm.lora.punica_wrapper import get_punica_wrapper
33
34
35
36
37
38
39
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
40
from vllm.model_executor.layers.logits_processor import LogitsProcessor
41
from vllm.model_executor.layers.vocab_parallel_embedding import (
42
43
44
45
    ParallelLMHead,
    VocabParallelEmbedding,
    get_masked_input_and_mask,
)
46
from vllm.platforms import current_platform
47
from vllm.utils.torch_utils import set_random_seed
48
49
50
51
52
53
54
55

from .utils import DummyLoRAManager

# Per-dtype (rtol, atol) pairs handed to torch.testing.assert_close.
TOLERANCES = {
    torch.float16: (5e-3, 5e-3),
    torch.float32: (5e-3, 5e-3),
    torch.bfloat16: (3e-2, 2e-2),
}

# Every test in this module is skipped unless the platform is CUDA-like or CPU.
pytestmark = pytest.mark.skipif(
    not (current_platform.is_cuda_alike() or current_platform.is_cpu()),
    reason="Backend not supported",
)

# Devices to parametrize over: on CUDA-like platforms one device when exactly
# one is visible and two otherwise; on any other platform just the CPU.
if current_platform.is_cuda_alike():
    _num_test_devices = 1 if torch.cuda.device_count() == 1 else 2
    DEVICES = [f"cuda:{idx}" for idx in range(_num_test_devices)]
else:
    DEVICES = ["cpu"]

# Values for the `stage` parameter: prefill stage (True) or decode stage (False).
STAGES = [True, False]

# Number of seeds each layer test loops over.
NUM_RANDOM_SEEDS = 2

# Number of seeds test_vocab_parallel_embedding_indices is parametrized with.
VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 2


@pytest.fixture(autouse=True)
def clean_cache_reset_device(reset_default_device):
    """Empty the cached LoRA A/B pointer tables before every test.

    Releasing memory we might still be holding on to keeps CI runs from
    running out of memory. Depends on the ``reset_default_device`` fixture,
    which is defined elsewhere and not used directly in this body.
    """
    from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT

    # Clear the B table first, then the A table.
    for ptr_cache in (_LORA_B_PTR_DICT, _LORA_A_PTR_DICT):
        ptr_cache.clear()

    yield


@pytest.fixture(autouse=True)
def skip_cuda_with_stage_false(request):
    """
    On cuda-like platforms, we use the same kernels for prefill and decode
    stage, and 'stage' is generally ignored, so we only need to test once.

    Skips the current test when it is parametrized with ``stage=False`` on a
    CUDA-like platform. Tests without parametrization are left untouched.
    """
    if current_platform.is_cuda_alike():
        # Unparametrized tests have no ``callspec``. Fall back to "no params"
        # explicitly instead of wrapping the lookup in ``except Exception``,
        # which could hide real errors.
        callspec = getattr(request.node, "callspec", None)
        params = getattr(callspec, "params", None) or {}
        if params.get("stage") is False:
            pytest.skip("Skip test when stage=False")
    yield


def get_random_id_to_index(
    num_loras: int, num_slots: int, log: bool = True
108
) -> list[int | None]:
109
110
111
112
113
114
115
116
117
118
119
120
    """Creates a random lora_id_to_index mapping.

    Args:
        num_loras: The number of active loras in the mapping.
        num_slots: The number of slots in the mapping. Must be larger
            than num_loras.
        log: Whether to log the output.
    """

    if num_loras > num_slots:
        raise ValueError(
            f"num_loras is higher than num_slots: {num_loras} > {num_slots}. "
121
122
            "num_loras must be less than or equal to num_slots."
        )
123

124
    slots: list[int | None] = [None] * num_slots
125
126
127
128
129
130
131
132
133
134
135
    random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist()
    for lora_id, slot_idx in enumerate(random_slot_selections, start=1):
        slots[slot_idx] = lora_id

    if log:
        print(f"Created lora_id_to_index mapping: {slots}.")

    return slots


def populate_loras(
    id_to_index: list[int | None],
    layer: BaseLayerWithLoRA,
    layer_weights: torch.Tensor,
    repeats: int = 1,
) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
    """Fill every occupied slot of ``layer`` with freshly generated LoRA weights.

    Args:
        id_to_index: Slot table. Entry ``k`` holds the lora id stored in
            memory slot ``k``, or ``None`` if that slot is free.
        layer: The LoRA layer whose slots are written via ``set_lora``.
        layer_weights: Tensor of the layer's weights. It is passed to
            ``DummyLoRAManager.init_random_lora``. With
            ``n = layer_weights.shape[0] // repeats``, sub-LoRA ``k`` keeps
            rows ``[k * n, (k + 1) * n)`` of its ``lora_b``.
        repeats: Only set for column-parallel packed layers. It is the number
            of sub-LoRAs packed together into the LoRA written to the layer.

    Returns:
        ``(lora_dict, sublora_dict)``. The first maps each lora id to the
        (possibly packed) LoRA that was written to the layer. The second maps
        each lora id to the list of sub-LoRAs it was built from.
    """
    lora_dict: dict[int, LoRALayerWeights] = {}
    sublora_dict: dict[int, list[LoRALayerWeights]] = {}

    for slot_idx, lora_id in enumerate(id_to_index):
        if lora_id is None:
            continue  # free slot, nothing to load

        rows_per_part = layer_weights.shape[0] // repeats
        parts: list[LoRALayerWeights] = []
        for part_idx in range(repeats):
            part = DummyLoRAManager(layer_weights.device).init_random_lora(
                module_name=f"fake_{part_idx}",
                weight=layer_weights,
            )
            row_start = rows_per_part * part_idx
            part.lora_b = part.lora_b[row_start : row_start + rows_per_part, :]
            part.optimize()
            parts.append(part)

        if repeats > 1:
            packed = PackedLoRALayerWeights.pack(parts)
        else:
            packed = parts[0]

        layer.set_lora(slot_idx, lora_a=packed.lora_a, lora_b=packed.lora_b)

        lora_dict[lora_id] = packed
        sublora_dict[lora_id] = parts

    return lora_dict, sublora_dict


def create_random_inputs(
    active_lora_ids: list[int],
    num_inputs: int,
    input_size: tuple[int, ...],
    input_range: tuple[float, float],
    input_type: torch.dtype = torch.int,
    device: torch.device = "cuda",
) -> tuple[list[torch.Tensor], list[int], list[int]]:
    """Creates random inputs.

    Args:
        active_lora_ids: lora IDs of active lora weights. Each input is
            assigned one of them with ``random.choice``.
        num_inputs: the number of inputs to create.
        input_size: the size of each individual input. Its first dimension
            is the number of entries the input contributes to the index
            mapping.
        input_range: the range of values to include in the input.
            input_range[0] <= possible input values < input_range[1]
        input_type: the type of values in the input. ``torch.int`` draws
            integers with ``torch.randint``. Any other dtype draws uniform
            floats with ``torch.rand``.
        device: the device the input tensors are created on.

    Returns:
        ``(inputs, index_mapping, prompt_mapping)``:
        - ``inputs``: the input tensors.
        - ``index_mapping``: ``input_size[0]`` lora ids per input.
        - ``prompt_mapping``: one lora id per input.
    """
    low, high = input_range

    inputs: list[torch.Tensor] = []
    index_mapping: list[int] = []
    prompt_mapping: list[int] = []

    for _ in range(num_inputs):
        if input_type == torch.int:
            inputs.append(
                torch.randint(
                    low=int(low), high=int(high), size=input_size, device=device
                )
            )
        else:
            # Scale by the width of the range, not by ``high``, so values stay
            # in [low, high) as documented. ``rand * high + low`` would give
            # [low, low + high).
            inputs.append(
                torch.rand(size=input_size, dtype=input_type, device=device)
                * (high - low)
                + low
            )

        lora_id = random.choice(active_lora_ids)
        index_mapping += [lora_id] * input_size[0]
        prompt_mapping += [lora_id]

    return inputs, index_mapping, prompt_mapping


def check_punica_wrapper(punica_wrapper) -> bool:
    """Return True if ``punica_wrapper`` is exactly the platform's wrapper class.

    The GPU wrapper is expected on CUDA-like platforms and the CPU wrapper on
    CPU. Any other platform is never accepted. The comparison is on the exact
    type, so subclasses are rejected.
    """
    if current_platform.is_cuda_alike():
        from vllm.lora.punica_wrapper.punica_gpu import (
            PunicaWrapperGPU as expected_cls,
        )
    elif current_platform.is_cpu():
        from vllm.lora.punica_wrapper.punica_cpu import (
            PunicaWrapperCPU as expected_cls,
        )
    else:
        return False
    return type(punica_wrapper) is expected_cls


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings(
    default_vllm_config, dist_init, num_loras, device, vocab_size, stage
) -> None:
    """VocabParallelEmbeddingWithLoRA matches a per-input reference computation.

    For each seed, random LoRAs are loaded into random slots. The LoRA
    layer's output must match ``embedding(x) + F.embedding(x, lora_a.T) @
    lora_b.T``. After every slot is reset, it must match the plain embedding.
    """
    # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
    # device, see: https://github.com/triton-lang/triton/issues/2925
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
    assert check_punica_wrapper(punica_wrapper)

    def make_layers():
        # Random base embedding with rows past vocab_size zeroed, plus its
        # LoRA wrapper with storage for max_loras adapters.
        base = VocabParallelEmbedding(vocab_size, 256)
        base.weight.data = torch.rand_like(base.weight.data)
        base.weight.data[vocab_size:, :] = 0
        wrapped = VocabParallelEmbeddingWithLoRA(base)
        wrapped.create_lora_weights(max_loras, lora_config)
        return base, wrapped

    def make_batch(active_ids, id_to_index):
        # Random token ids, plus the matching LoRA metadata pushed to punica.
        batch, token_map, prompt_map = create_random_inputs(
            active_lora_ids=active_ids,
            num_inputs=num_loras * 3,
            input_size=(200,),
            input_range=(1, vocab_size),
            device=device,
        )
        punica_wrapper.update_metadata(
            LoRAMapping(token_map, prompt_map, is_prefill=stage),
            id_to_index,
            max_loras,
            vocab_size,
        )
        return batch, prompt_map

    for seed in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        embedding, lora_embedding = make_layers()
        lora_embedding.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=embedding.weight.T,
        )

        batch, prompt_map = make_batch(list(lora_dict.keys()), id_to_index)
        lora_result = lora_embedding(torch.cat(batch))

        reference_parts: list[torch.Tensor] = []
        for tokens, lora_id in zip(batch, prompt_map):
            lora = lora_dict[lora_id]
            out = embedding(tokens)
            out += F.embedding(tokens, lora.lora_a.T) @ lora.lora_b.T
            reference_parts.append(out)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(
            lora_result, torch.cat(reference_parts), rtol=rtol, atol=atol
        )

        # Check that resetting the lora weights succeeds
        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        batch, _ = make_batch([0], id_to_index)
        tokens = torch.cat(batch)
        lora_result = lora_embedding(tokens)
        expected_result = embedding(tokens)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])
@pytest.mark.parametrize("stage", STAGES)
def test_lm_head_logits_processor(
    default_vllm_config, dist_init, num_loras, device, vocab_size, stage
) -> None:
    """LogitsProcessorWithLoRA matches the base logits plus the LoRA delta.

    For each seed, random LoRAs are loaded into random slots. The LoRA
    logits must match ``logits(x) + x @ lora_a.T @ lora_b.T * scaling``,
    computed per input. After every slot is reset, they must match the
    base logits.
    """
    # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
    # device, see: https://github.com/triton-lang/triton/issues/2925
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
    assert check_punica_wrapper(punica_wrapper)

    def _pretest():
        linear = ParallelLMHead(
            num_embeddings=vocab_size,
            embedding_dim=1024,
            params_dtype=torch.float16,
        )
        linear.weight.data = torch.rand_like(linear.weight.data)
        # Zero any padded vocab rows, as test_embeddings does. The previous
        # ``[:, vocab_size:]`` sliced the hidden dimension instead and zeroed
        # half of the columns when vocab_size == 512.
        linear.weight.data[vocab_size:, :] = 0
        logits_processor = LogitsProcessor(vocab_size)
        lora_logits_processor = LogitsProcessorWithLoRA(
            logits_processor, 1024, linear.weight.dtype, linear.weight.device, None
        )
        lora_logits_processor.create_lora_weights(max_loras, lora_config)

        return linear, logits_processor, lora_logits_processor

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, logits_processor, lora_logits_processor = _pretest()
        lora_logits_processor.set_mapping(punica_wrapper)

        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_logits_processor,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=8 * num_loras,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
        )

        # Sliced to vocab_size for consistency with the reset check below.
        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs), lm_head=linear, embedding_bias=None
        )[:, :vocab_size]

        original_lm_head = deepcopy(linear)

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = logits_processor._get_logits(
                hidden_states=input_, lm_head=linear, embedding_bias=None
            )
            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        # Validate the LoRA-applied logits against the per-input reference.
        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)

        # Check that resetting the lora weights succeeds
        for slot_idx in range(max_loras):
            lora_logits_processor.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=8 * num_loras * 3,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
        )

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None,
        )[:, :vocab_size]
        expected_result = logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None,
        )

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_replicated(
    default_vllm_config,
    dist_init,
    num_loras,
    device,
    stage,
) -> None:
    """ReplicatedLinearWithLoRA matches ``linear(x) + x @ A.T @ B.T * scaling``.

    After every LoRA slot is reset, the wrapped layer must reproduce the base
    layer exactly (within tolerance).
    """
    # Triton on multi-GPU needs the CUDA device set explicitly, see:
    # https://github.com/triton-lang/triton/issues/2925
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        lora_dtype=torch.float16,
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
    assert check_punica_wrapper(punica_wrapper)

    def build_layers():
        base = ReplicatedLinear(4096, 4096, bias=False, params_dtype=torch.float16)
        base.weight.data = torch.rand_like(base.weight.data)
        wrapped = ReplicatedLinearWithLoRA(base)
        wrapped.create_lora_weights(max_loras, lora_config)
        # A replicated layer carries exactly one LoRA slice.
        assert wrapped.n_slices == 1
        assert len(wrapped.lora_a_stacked) == 1
        assert len(wrapped.lora_b_stacked) == 1
        return base, wrapped

    def sample_batch(active_ids, id_to_index):
        # Random fp16 rows, plus the matching LoRA metadata pushed to punica.
        batch, token_map, prompt_map = create_random_inputs(
            active_lora_ids=active_ids,
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        punica_wrapper.update_metadata(
            LoRAMapping(token_map, prompt_map, is_prefill=stage),
            id_to_index,
            max_loras,
            512,
        )
        return batch, prompt_map

    for seed in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = build_layers()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        batch, prompt_map = sample_batch(list(lora_dict.keys()), id_to_index)
        lora_result = lora_linear(torch.cat(batch))[0]

        reference: list[torch.Tensor] = []
        for x, lora_id in zip(batch, prompt_map):
            lora = lora_dict[lora_id]
            out = linear(x)[0]
            out += x @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            reference.append(out)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(
            lora_result, torch.cat(reference), rtol=rtol, atol=atol
        )

        # Check that resetting the lora weights succeeds
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        batch, _ = sample_batch([0], id_to_index)
        stacked = torch.cat(batch)
        lora_result = lora_linear(stacked)[0]
        expected_result = linear(stacked)[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_parallel(
    default_vllm_config, dist_init, num_loras, orientation, fully_shard, device, stage
) -> None:
    """Row/column parallel LoRA layers match ``linear(x) + x @ A.T @ B.T * s``.

    ``orientation`` picks the row- or column-parallel layer. ``fully_shard``
    picks the sharded LoRA wrapper over the plain one. After every slot is
    reset, the wrapped layer must reproduce the base layer.
    """
    # Triton on multi-GPU needs the CUDA device set explicitly, see:
    # https://github.com/triton-lang/triton/issues/2925
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        fully_sharded_loras=fully_shard,
        lora_dtype=torch.float16,
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
    assert check_punica_wrapper(punica_wrapper)

    def build_layers():
        if orientation == "row":
            base_cls, plain_cls, sharded_cls = (
                RowParallelLinear,
                RowParallelLinearWithLoRA,
                RowParallelLinearWithShardedLoRA,
            )
        else:
            base_cls, plain_cls, sharded_cls = (
                ColumnParallelLinear,
                ColumnParallelLinearWithLoRA,
                ColumnParallelLinearWithShardedLoRA,
            )
        base = base_cls(4096, 4096, bias=False, params_dtype=torch.float16)
        base.weight.data = torch.rand_like(base.weight.data)
        wrapped = (sharded_cls if fully_shard else plain_cls)(base)
        wrapped.create_lora_weights(max_loras, lora_config)
        # A non-packed layer carries exactly one LoRA slice.
        assert wrapped.n_slices == 1
        assert len(wrapped.lora_a_stacked) == 1
        assert len(wrapped.lora_b_stacked) == 1
        return base, wrapped

    def sample_batch(active_ids, id_to_index):
        # Random fp16 rows, plus the matching LoRA metadata pushed to punica.
        batch, token_map, prompt_map = create_random_inputs(
            active_lora_ids=active_ids,
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        punica_wrapper.update_metadata(
            LoRAMapping(token_map, prompt_map, is_prefill=stage),
            id_to_index,
            max_loras,
            512,
        )
        return batch, prompt_map

    for seed in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = build_layers()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        batch, prompt_map = sample_batch(list(lora_dict.keys()), id_to_index)
        lora_result = lora_linear(torch.cat(batch))[0]

        reference: list[torch.Tensor] = []
        for x, lora_id in zip(batch, prompt_map):
            lora = lora_dict[lora_id]
            out = linear(x)[0]
            out += x @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            reference.append(out)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(
            lora_result, torch.cat(reference), rtol=rtol, atol=atol
        )

        # Check that resetting the lora weights succeeds
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        batch, _ = sample_batch([0], id_to_index)
        stacked = torch.cat(batch)
        lora_result = lora_linear(stacked)[0]
        expected_result = linear(stacked)[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_column_parallel_packed(
    default_vllm_config, dist_init, num_loras, repeats, fully_shard, device, stage
) -> None:
    """Packed column-parallel LoRA layers apply each sub-LoRA to its own slice.

    ``repeats`` selects the layer:
    - 2: MergedColumnParallelLinear.
    - 3: QKVParallelLinear with a merged QKV LoRA wrapper.
    - 1: QKVParallelLinear with a single-slice QKV LoRA wrapper.

    Let ``w`` be a sub-LoRA's ``lora_b`` height. Sub-LoRA ``k`` adds its
    delta to output columns ``[k * w, (k + 1) * w)``. After every slot is
    reset, the wrapped layer must reproduce the base layer.
    """
    # Triton on multi-GPU needs the CUDA device set explicitly, see:
    # https://github.com/triton-lang/triton/issues/2925
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        fully_sharded_loras=fully_shard,
        lora_dtype=torch.float16,
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
    assert check_punica_wrapper(punica_wrapper)

    def build_layers():
        if repeats == 2:
            base = MergedColumnParallelLinear(
                4096, [4096] * repeats, bias=False, params_dtype=torch.float16
            )
            plain_cls = MergedColumnParallelLinearWithLoRA
            sharded_cls = MergedColumnParallelLinearWithShardedLoRA
        elif repeats == 3:
            base = QKVParallelLinear(
                4096, 64, 32, bias=False, params_dtype=torch.float16
            )
            plain_cls = MergedQKVParallelLinearWithLoRA
            sharded_cls = MergedQKVParallelLinearWithShardedLoRA
        else:
            base = QKVParallelLinear(
                4096, 64, 32, bias=False, params_dtype=torch.float16
            )
            plain_cls = QKVParallelLinearWithLoRA
            sharded_cls = QKVParallelLinearWithShardedLoRA
        base.weight.data = torch.rand_like(base.weight.data)
        wrapped = (sharded_cls if fully_shard else plain_cls)(base)

        # Minimal stand-in for the model config given to create_lora_weights.
        @dataclass
        class FakeConfig:
            hidden_size = 4096
            num_key_value_heads = 32
            num_attention_heads = 32

        wrapped.create_lora_weights(max_loras, lora_config, model_config=FakeConfig())
        # One LoRA slice per packed sub-layer.
        assert wrapped.n_slices == repeats
        assert len(wrapped.lora_a_stacked) == repeats
        assert len(wrapped.lora_b_stacked) == repeats
        return base, wrapped

    def sample_batch(active_ids, id_to_index):
        # Random fp16 rows, plus the matching LoRA metadata pushed to punica.
        batch, token_map, prompt_map = create_random_inputs(
            active_lora_ids=active_ids,
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        punica_wrapper.update_metadata(
            LoRAMapping(token_map, prompt_map, is_prefill=stage),
            id_to_index,
            max_loras,
            512,
        )
        return batch, prompt_map

    for seed in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)

        linear, lora_linear = build_layers()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, sublora_dict = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
            repeats=repeats,
        )

        batch, prompt_map = sample_batch(list(lora_dict.keys()), id_to_index)
        lora_result = lora_linear(torch.cat(batch))[0]

        reference: list[torch.Tensor] = []
        for x, lora_id in zip(batch, prompt_map):
            out = linear(x)[0]
            for slice_idx, sublora in enumerate(sublora_dict[lora_id]):
                width = sublora.lora_b.shape[0]
                out[:, width * slice_idx : width * (slice_idx + 1)] += (
                    x @ sublora.lora_a.T @ sublora.lora_b.T * sublora.scaling
                )
            reference.append(out)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(
            lora_result, torch.cat(reference), rtol=rtol, atol=atol
        )

        # After resetting every slot, the LoRA layer must match the base layer.
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        batch, _ = sample_batch([0], id_to_index)
        stacked = torch.cat(batch)
        lora_result = lora_linear(stacked)[0]
        expected_result = linear(stacked)[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)


@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize(
    "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS))
)
def test_vocab_parallel_embedding_indices(tp_size, seed, default_vllm_config):
    """Check how ``VocabParallelEmbedding`` shards a randomly sized vocab.

    One embedding is built per simulated tensor-parallel rank (rank and world
    size are patched in). Across the ranks:

    * the ``org`` and ``added`` index ranges must be contiguous from rank to
      rank, with the ``added`` ranges starting at ``org_vocab_size``;
    * per-rank sizes must sum to the padded total, ``org_vocab_size`` and
      ``added_vocab_size`` respectively;
    * no token id may appear twice, or in both the ``org`` and ``added``
      ranges;
    * ``get_sharded_to_full_mapping`` (required unless ``tp_size == 1``) must
      reorder the concatenated per-rank layout into ``0..vocab_size-1``,
      followed only by padding slots (marked ``-1``).
    """
    random.seed(seed)
    vocab_size = random.randint(4000, 64000)
    added_vocab_size = random.randint(0, 1024)
    org_vocab_size = vocab_size - added_vocab_size

    # Where the previous rank's ranges ended; rank 0 must start here.
    prev_org_end = 0
    prev_added_end = org_vocab_size
    padded_total = 0
    org_total = 0
    added_total = 0
    vocab_size_padded = -1

    org_ids: list[int] = []
    added_ids: list[int] = []
    # Concatenation of every rank's [org ids, org padding, added ids,
    # added padding]; padding slots hold -1.
    sharded_layout: list[int] = []

    for tp_rank in range(tp_size):
        with (
            patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank",
                return_value=tp_rank,
            ),
            patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size",
                return_value=tp_size,
            ),
        ):
            vocab_embedding = VocabParallelEmbedding(
                vocab_size, 1, org_num_embeddings=org_vocab_size
            )
        vocab_size_padded = vocab_embedding.num_embeddings_padded
        shard = vocab_embedding.shard_indices

        # This rank must continue exactly where the previous one stopped.
        assert shard.org_vocab_start_index == prev_org_end
        assert shard.added_vocab_start_index == prev_added_end

        # Running totals, checked against the expected sizes after the loop.
        padded_total += shard.num_elements_padded
        org_total += shard.num_org_elements
        added_total += shard.num_added_elements

        org_range = range(shard.org_vocab_start_index, shard.org_vocab_end_index)
        added_range = range(
            shard.added_vocab_start_index, shard.added_vocab_end_index
        )
        org_ids += org_range
        added_ids += added_range

        org_pad = shard.num_org_elements_padded - shard.num_org_elements
        added_pad = shard.num_added_elements_padded - shard.num_added_elements
        sharded_layout += [*org_range, *([-1] * org_pad)]
        sharded_layout += [*added_range, *([-1] * added_pad)]

        prev_org_end = shard.org_vocab_end_index
        prev_added_end = shard.added_vocab_end_index

    assert padded_total == vocab_size_padded
    assert org_total == org_vocab_size
    assert added_total == added_vocab_size

    # No id may be owned twice, nor by both the org and the added ranges.
    org_set = set(org_ids)
    added_set = set(added_ids)
    assert len(org_ids) == len(org_set)
    assert len(added_ids) == len(added_set)
    assert org_set.isdisjoint(added_set)

    layout_tensor = torch.tensor(sharded_layout, dtype=torch.long)
    # ``vocab_embedding`` is the instance built for the last rank.
    reindex_mapping = vocab_embedding.get_sharded_to_full_mapping()
    assert reindex_mapping is not None or tp_size == 1
    if reindex_mapping is None:
        return
    reindexed = layout_tensor[reindex_mapping]
    assert reindexed[:vocab_size].equal(torch.arange(vocab_size))
    assert (reindexed[vocab_size:] == -1).all()


def test_get_masked_input_and_mask():
    """Check ``get_masked_input_and_mask`` against hand-computed shardings.

    A 12-token input (ids 0-7 in the org range, 8-11 in the added range) is
    split as it would be for 1, 2 and 4 ranks. Each split is tested first
    with no org-vocab padding and then with 2 padding slots. Only the
    remapped ids are compared; the second return value is not checked here.
    """
    x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

    # (org_start, org_end, added_start, added_end, org_padding, expected ids)
    cases = [
        # tp 1, no padding: the input comes back unchanged.
        (0, 8, 8, 12, 0, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]),
        # tp 2, no padding
        (0, 4, 8, 10, 0, [0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]),
        (4, 8, 10, 12, 0, [0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]),
        # tp 4, no padding
        (0, 2, 8, 9, 0, [0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]),
        (2, 4, 9, 10, 0, [0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]),
        (4, 6, 10, 11, 0, [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]),
        (6, 8, 11, 12, 0, [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]),
        # tp 1, with padding: added ids are shifted up by the 2 padding slots.
        (0, 8, 8, 12, 2, [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]),
        # tp 2, with padding
        (0, 4, 8, 10, 2, [0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]),
        (4, 8, 10, 12, 2, [0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]),
        # tp 4, with padding
        (0, 2, 8, 9, 2, [0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]),
        (2, 4, 9, 10, 2, [0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]),
        (4, 6, 10, 11, 2, [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]),
        (6, 8, 11, 12, 2, [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]),
    ]

    for org_start, org_end, added_start, added_end, padding, expected in cases:
        remapped, _ = get_masked_input_and_mask(
            x,
            org_vocab_start_index=org_start,
            org_vocab_end_index=org_end,
            added_vocab_start_index=added_start,
            added_vocab_end_index=added_end,
            num_org_vocab_padding=padding,
        )
        assert torch.equal(remapped, torch.tensor(expected))