"tests/models/multimodal/processing/test_llava_next.py" did not exist on "8f37be38ebfe0295a4925837c501c87149997a4d"
test_layers.py 37 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import random
from copy import deepcopy
from dataclasses import dataclass
7
from unittest.mock import patch
8

9
import pytest
10
11
12
import torch
import torch.nn.functional as F

13
from vllm.config.lora import LoRAConfig
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from vllm.lora.layers import (
    BaseLayerWithLoRA,
    ColumnParallelLinearWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    LogitsProcessorWithLoRA,
    LoRAMapping,
    MergedColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    QKVParallelLinearWithLoRA,
    QKVParallelLinearWithShardedLoRA,
    ReplicatedLinearWithLoRA,
    RowParallelLinearWithLoRA,
    RowParallelLinearWithShardedLoRA,
    VocabParallelEmbeddingWithLoRA,
)
31
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
32
from vllm.lora.punica_wrapper import get_punica_wrapper
33
34
35
36
37
38
39
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
40
from vllm.model_executor.layers.logits_processor import LogitsProcessor
41
from vllm.model_executor.layers.vocab_parallel_embedding import (
42
43
44
45
    ParallelLMHead,
    VocabParallelEmbedding,
    get_masked_input_and_mask,
)
46
from vllm.platforms import current_platform
47
from vllm.utils.torch_utils import set_random_seed
48
49
50
51
52
53
54
55

from .utils import DummyLoRAManager

# Per-dtype (rtol, atol) pairs passed to torch.testing.assert_close.
TOLERANCES = {
    torch.float16: (5e-3, 5e-3),
    torch.float32: (5e-3, 5e-3),
    torch.bfloat16: (3e-2, 2e-2),
}

pytestmark = pytest.mark.skipif(
    not (current_platform.is_cuda_alike() or current_platform.is_cpu()),
    reason="Backend not supported",
)

# On CUDA-like platforms test on a single GPU when exactly one is visible,
# otherwise on the first two; everywhere else fall back to the CPU.
DEVICES = (
    [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
    if current_platform.is_cuda_alike()
    else ["cpu"]
)

# prefill stage(True) or decode stage(False)
STAGES = [True, False]

# Number of random seeds each LoRA layer test iterates over.
NUM_RANDOM_SEEDS = 2

# Number of random seeds used by test_vocab_parallel_embedding_indices.
VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 2


@pytest.fixture(autouse=True)
def clean_cache_reset_device(reset_default_device):
    """Clear the Triton LoRA pointer caches before every test.

    Requests the ``reset_default_device`` fixture (defined outside this
    file) so that it is active for every test as well.
    """
    # Release any memory we might be holding on to. CI runs OOM otherwise.
    from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT

    _LORA_B_PTR_DICT.clear()
    _LORA_A_PTR_DICT.clear()

    yield

86

87
88
89
@pytest.fixture(autouse=True)
def skip_cuda_with_stage_false(request):
    """
    On cuda-like platforms, we use the same kernels for prefill and decode
    stage, and 'stage' is generally ignored, so we only need to test once.

    Any test parametrized with ``stage=False`` is therefore skipped on
    cuda-like platforms; tests without a ``stage`` parameter are unaffected.
    """
    if current_platform.is_cuda_alike():
        should_skip = False
        try:
            callspec = getattr(request.node, "callspec", None)
            params = getattr(callspec, "params", None)
            if params is not None:
                should_skip = "stage" in params and params["stage"] is False
        except Exception:
            # Best effort: problems introspecting the test's parameters must
            # never fail the test itself.
            should_skip = False
        # Skip outside the try block so the broad handler above can never
        # interfere with pytest's skip mechanism.
        if should_skip:
            pytest.skip("Skip test when stage=False")
    yield


106
107
def get_random_id_to_index(
    num_loras: int, num_slots: int, log: bool = True
108
) -> list[int | None]:
109
110
111
112
113
114
115
116
117
118
119
120
    """Creates a random lora_id_to_index mapping.

    Args:
        num_loras: The number of active loras in the mapping.
        num_slots: The number of slots in the mapping. Must be larger
            than num_loras.
        log: Whether to log the output.
    """

    if num_loras > num_slots:
        raise ValueError(
            f"num_loras is higher than num_slots: {num_loras} > {num_slots}. "
121
122
            "num_loras must be less than or equal to num_slots."
        )
123

124
    slots: list[int | None] = [None] * num_slots
125
126
127
128
129
130
131
132
133
134
135
    random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist()
    for lora_id, slot_idx in enumerate(random_slot_selections, start=1):
        slots[slot_idx] = lora_id

    if log:
        print(f"Created lora_id_to_index mapping: {slots}.")

    return slots


def populate_loras(
    id_to_index: list[int | None],
    layer: BaseLayerWithLoRA,
    layer_weights: torch.Tensor,
    repeats: int = 1,
) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
    """This method populates the lora layers with lora weights.

    Args:
        id_to_index: a list of lora ids. The index of the lora id
            represents which memory slot the lora matrices are
            stored in. A None value indicates a free slot.
        layer: the LoRAlayer to populate.
        layer_weights: the PyTorch tensor containing the layer's
            weights.
        repeats: must only be set for column parallel packed
            layers. Indicates the number of loras to compose
            together to create a single lora layer. Must be >= 1.

    Returns:
        A ``(lora_dict, sublora_dict)`` tuple. ``lora_dict`` maps each lora
        id to the weights written into ``layer`` (packed when
        ``repeats > 1``). ``sublora_dict`` maps each lora id to its list of
        ``repeats`` sub-loras.
    """
    # Dictionary that maps the lora ID to the corresponding lora weights.
    lora_dict: dict[int, LoRALayerWeights] = dict()
    # Dictionary that maps the lora ID to the corresponding subloras.
    sublora_dict: dict[int, list[LoRALayerWeights]] = dict()

    # Each sub-lora owns an equal, contiguous share of lora_b's rows.
    sublora_len = layer_weights.shape[0] // repeats

    for slot_idx, lora_id in enumerate(id_to_index):
        if lora_id is None:
            continue  # free slot

        subloras: list[LoRALayerWeights] = []
        for i in range(repeats):
            sublora = DummyLoRAManager(layer_weights.device).init_random_lora(
                module_name=f"fake_{i}",
                weight=layer_weights,
            )
            # Keep only the rows of lora_b that belong to the i-th slice.
            sublora.lora_b = sublora.lora_b[
                (sublora_len * i) : (sublora_len * (i + 1)), :
            ]
            sublora.optimize()
            subloras.append(sublora)

        lora = PackedLoRALayerWeights.pack(subloras) if repeats > 1 else subloras[0]

        layer.set_lora(
            slot_idx,
            lora_a=lora.lora_a,
            lora_b=lora.lora_b,
        )

        lora_dict[lora_id] = lora
        sublora_dict[lora_id] = subloras

    return lora_dict, sublora_dict


def create_random_inputs(
193
    active_lora_ids: list[int],
194
    num_inputs: int,
195
196
    input_size: tuple[int, ...],
    input_range: tuple[float, float],
197
    input_type: torch.dtype = torch.int,
198
    device: torch.device = "cuda",
199
) -> tuple[list[torch.Tensor], list[int], list[int]]:
200
201
202
203
204
205
206
207
208
209
210
211
212
    """Creates random inputs.

    Args:
        active_lora_ids: lora IDs of active lora weights.
        num_inputs: the number of inputs to create.
        input_size: the size of each individual input.
        input_range: the range of values to include in the input.
            input_range[0] <= possible input values < input_range[1]
        input_type: the type of values in the input.
    """

    low, high = input_range

213
214
215
    inputs: list[torch.Tensor] = []
    index_mapping: list[int] = []
    prompt_mapping: list[int] = []
216

217
218
219
    for _ in range(num_inputs):
        if input_type == torch.int:
            inputs.append(
220
221
222
223
                torch.randint(
                    low=int(low), high=int(high), size=input_size, device=device
                )
            )
224
225
        else:
            inputs.append(
226
227
228
                torch.rand(size=input_size, dtype=input_type, device=device) * high
                + low
            )
229
230
231
232
233
234
235
236

        lora_id = random.choice(active_lora_ids)
        index_mapping += [lora_id] * input_size[0]
        prompt_mapping += [lora_id]

    return inputs, index_mapping, prompt_mapping


237
238
239
240
241
def check_punica_wrapper(punica_wrapper) -> bool:
    """Return True if ``punica_wrapper`` is exactly the expected wrapper class.

    CUDA-like platforms expect ``PunicaWrapperGPU`` and CPU platforms expect
    ``PunicaWrapperCPU``. Any other platform returns False. The comparison
    uses exact type identity, so subclasses do not match.
    """
    if current_platform.is_cuda_alike():
        from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU

        return type(punica_wrapper) is PunicaWrapperGPU
    if current_platform.is_cpu():
        from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU

        return type(punica_wrapper) is PunicaWrapperCPU
    return False


250
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
    """Check VocabParallelEmbeddingWithLoRA against a manual reference.

    With loras populated, the output must equal
    ``embedding(x) + F.embedding(x, lora_a.T) @ lora_b.T`` for each input's
    lora. After every slot is reset, it must equal the base embedding.
    """
    # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
    # device, see: https://github.com/triton-lang/triton/issues/2925
    # Same below.
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
    assert check_punica_wrapper(punica_wrapper)

    def create_random_embedding_layer():
        embedding = VocabParallelEmbedding(vocab_size, 256)
        embedding.weight.data = torch.rand_like(embedding.weight.data)
        # Zero any rows past the real vocabulary (padding).
        embedding.weight.data[vocab_size:, :] = 0
        lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
        lora_embedding.create_lora_weights(max_loras, lora_config)

        return embedding, lora_embedding

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        embedding, lora_embedding = create_random_embedding_layer()
        lora_embedding.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=embedding.weight.T,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=num_loras * 3,
            input_size=(200,),
            input_range=(1, vocab_size),
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
        )

        lora_result = lora_embedding(torch.cat(inputs))

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = embedding(input_)
            after_a = F.embedding(
                input_,
                lora.lora_a.T,
            )
            result += after_a @ lora.lora_b.T
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)

        # Check that resetting the lora weights succeeds
        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=num_loras * 3,
            input_size=(200,),
            input_range=(1, vocab_size),
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
        )

        all_inputs = torch.cat(inputs)
        lora_result = lora_embedding(all_inputs)
        expected_result = embedding(all_inputs)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
348
349
350


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])
@pytest.mark.parametrize("stage", STAGES)
def test_lm_head_logits_processor(
    dist_init, num_loras, device, vocab_size, stage
) -> None:
    """Check LogitsProcessorWithLoRA against a manual reference.

    With loras populated, the logits must equal the base logits plus
    ``x @ lora_a.T @ lora_b.T * scaling`` for each input's lora. After every
    slot is reset, they must equal the base logits.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
    assert check_punica_wrapper(punica_wrapper)

    def _pretest():
        linear = ParallelLMHead(
            num_embeddings=vocab_size,
            embedding_dim=1024,
            params_dtype=torch.float16,
        )
        linear.weight.data = torch.rand_like(linear.weight.data)
        # Zero any rows past the real vocabulary (padding), as test_embeddings
        # does; rows are assumed to index the padded vocab.
        linear.weight.data[vocab_size:, :] = 0
        logits_processor = LogitsProcessor(vocab_size)
        lora_logits_processor = LogitsProcessorWithLoRA(
            logits_processor, 1024, linear.weight.dtype, linear.weight.device, None
        )
        lora_logits_processor.create_lora_weights(max_loras, lora_config)

        return linear, logits_processor, lora_logits_processor

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, logits_processor, lora_logits_processor = _pretest()
        lora_logits_processor.set_mapping(punica_wrapper)

        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_logits_processor,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=8 * num_loras,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
        )

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs), lm_head=linear, embedding_bias=None
        )

        # Snapshot of the base head, used for the post-reset comparison.
        original_lm_head = deepcopy(linear)

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = logits_processor._get_logits(
                hidden_states=input_, lm_head=linear, embedding_bias=None
            )
            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        # The LoRA-augmented logits must match the per-input reference.
        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)

        # Check that resetting the lora weights succeeds
        for slot_idx in range(max_loras):
            lora_logits_processor.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=8 * num_loras * 3,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
        )

        all_inputs = torch.cat(inputs)
        lora_result = lora_logits_processor._get_logits(
            hidden_states=all_inputs,
            lm_head=original_lm_head,
            embedding_bias=None,
        )[:, :vocab_size]
        expected_result = logits_processor._get_logits(
            hidden_states=all_inputs,
            lm_head=original_lm_head,
            embedding_bias=None,
        )

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
466
467


468
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_replicated(
    dist_init,
    num_loras,
    device,
    stage,
) -> None:
    """Check ReplicatedLinearWithLoRA against a manual reference.

    With loras populated, the output must equal
    ``linear(x) + x @ lora_a.T @ lora_b.T * scaling`` for each input's lora.
    After every slot is reset, it must equal the base layer's output.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        lora_dtype=torch.float16,
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
    assert check_punica_wrapper(punica_wrapper)

    def create_random_linear_replicated_layer():
        linear = ReplicatedLinear(4096, 4096, bias=False, params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        lora_linear = ReplicatedLinearWithLoRA(linear)

        lora_linear.create_lora_weights(max_loras, lora_config)
        # A replicated (unpacked) layer has exactly one lora slice.
        assert (
            lora_linear.n_slices
            == len(lora_linear.lora_a_stacked)
            == len(lora_linear.lora_b_stacked)
            == 1
        )
        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_replicated_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        # The last argument is the vocab size; presumably it is not meaningful
        # for a linear layer, so any fixed value works.
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)

        # Check that resetting the lora weights succeeds
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
        )

        all_inputs = torch.cat(inputs)
        lora_result = lora_linear(all_inputs)[0]
        expected_result = linear(all_inputs)[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
574
575


576
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_parallel(
    dist_init, num_loras, orientation, fully_shard, device, stage
) -> None:
    """Check row/column parallel linear LoRA layers against a manual reference.

    ``orientation`` selects RowParallelLinear or ColumnParallelLinear.
    ``fully_shard`` selects the ``*WithShardedLoRA`` variant. The output is
    compared to ``linear(x) + x @ lora_a.T @ lora_b.T * scaling`` and, after
    every slot is reset, to the base layer's output.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        fully_sharded_loras=fully_shard,
        lora_dtype=torch.float16,
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
    assert check_punica_wrapper(punica_wrapper)

    def create_random_linear_parallel_layer():
        if orientation == "row":
            linear = RowParallelLinear(
                4096, 4096, bias=False, params_dtype=torch.float16
            )
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (
                RowParallelLinearWithLoRA(linear)
                if not fully_shard
                else RowParallelLinearWithShardedLoRA(linear)
            )
        else:
            linear = ColumnParallelLinear(
                4096, 4096, bias=False, params_dtype=torch.float16
            )
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (
                ColumnParallelLinearWithLoRA(linear)
                if not fully_shard
                else ColumnParallelLinearWithShardedLoRA(linear)
            )
        lora_linear.create_lora_weights(max_loras, lora_config)
        # Unpacked parallel layers have exactly one lora slice.
        assert (
            lora_linear.n_slices
            == len(lora_linear.lora_a_stacked)
            == len(lora_linear.lora_b_stacked)
            == 1
        )

        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_parallel_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        # The last argument is the vocab size; presumably it is not meaningful
        # for a linear layer, so any fixed value works.
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)

        # Check that resetting the lora weights succeeds
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
        )

        all_inputs = torch.cat(inputs)
        lora_result = lora_linear(all_inputs)[0]
        expected_result = linear(all_inputs)[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
699
700
701


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_column_parallel_packed(
    dist_init, num_loras, repeats, fully_shard, device, stage
) -> None:
    """Check packed column-parallel LoRA layers against a manual reference.

    ``repeats`` selects the layer under test:
    - 1: QKVParallelLinearWithLoRA
    - 2: MergedColumnParallelLinearWithLoRA
    - 3: MergedQKVParallelLinearWithLoRA

    Each variant has a sharded counterpart when ``fully_shard`` is set.
    Each sub-lora's contribution is added to its own output column slice
    of the base layer's output.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        fully_sharded_loras=fully_shard,
        lora_dtype=torch.float16,
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
    assert check_punica_wrapper(punica_wrapper)

    def create_column_parallel_packed_layer():
        if repeats == 2:
            linear = MergedColumnParallelLinear(
                4096, [4096] * repeats, bias=False, params_dtype=torch.float16
            )
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (
                MergedColumnParallelLinearWithLoRA(linear)
                if not fully_shard
                else MergedColumnParallelLinearWithShardedLoRA(linear)
            )
        elif repeats == 3:
            linear = QKVParallelLinear(
                4096, 64, 32, bias=False, params_dtype=torch.float16
            )
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (
                MergedQKVParallelLinearWithLoRA(linear)
                if not fully_shard
                else MergedQKVParallelLinearWithShardedLoRA(linear)
            )
        else:
            linear = QKVParallelLinear(
                4096, 64, 32, bias=False, params_dtype=torch.float16
            )
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (
                QKVParallelLinearWithLoRA(linear)
                if not fully_shard
                else QKVParallelLinearWithShardedLoRA(linear)
            )

        # Minimal stand-in for a model config, passed to create_lora_weights.
        @dataclass
        class FakeConfig:
            hidden_size = 4096
            num_key_value_heads = 32
            num_attention_heads = 32

        n_slices = repeats
        lora_linear.create_lora_weights(
            max_loras, lora_config, model_config=FakeConfig()
        )
        assert (
            lora_linear.n_slices
            == len(lora_linear.lora_a_stacked)
            == len(lora_linear.lora_b_stacked)
            == n_slices
        )

        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)

        linear, lora_linear = create_column_parallel_packed_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, sublora_dict = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
            repeats=repeats,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        # The last argument is the vocab size; presumably it is not meaningful
        # for a linear layer, so any fixed value works.
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            result = linear(input_)[0]
            subloras = sublora_dict[lora_id]
            # Each sub-lora contributes only to its own slice of the output.
            for slice_idx, sublora in enumerate(subloras):
                slice_len = sublora.lora_b.shape[0]
                result[:, slice_len * slice_idx : slice_len * (slice_idx + 1)] += (
                    input_ @ sublora.lora_a.T @ sublora.lora_b.T * sublora.scaling
                )
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)

        # Check that resetting the lora weights succeeds
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
        )

        all_inputs = torch.cat(inputs)
        lora_result = lora_linear(all_inputs)[0]
        expected_result = linear(all_inputs)[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
848
849


@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize(
    "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS))
)
def test_vocab_parallel_embedding_indices(tp_size, seed):
    """Check that VocabParallelEmbedding shards partition the vocabulary.

    For every simulated TP rank, a VocabParallelEmbedding is built with the
    rank / world size patched in. The test then verifies that:

    * each rank's original-vocab and added-vocab ranges start where the
      previous rank's ranges ended,
    * the per-shard element counts sum to the expected totals,
    * no token id is owned by more than one shard, and
    * ``get_sharded_to_full_mapping()`` reorders the concatenated, padded
      per-shard token ids back into ``0..vocab_size-1`` followed only by
      padding (``-1``) entries.
    """
    # A private RNG avoids reseeding the global ``random`` module (a side
    # effect other tests could observe). ``random.Random(seed)`` produces the
    # same draws as ``random.seed(seed)`` followed by module-level calls.
    rng = random.Random(seed)
    vocab_size = rng.randint(4000, 64000)
    added_vocab_size = rng.randint(0, 1024)
    org_vocab_size = vocab_size - added_vocab_size

    # End indices of the previous rank's ranges. Rank 0 is expected to start
    # the original vocab at 0 and the added vocab right after it.
    last_org_vocab_end_index = 0
    last_added_vocab_end_index = org_vocab_size
    computed_vocab_size = 0
    computed_org_vocab_size = 0
    computed_added_vocab_size = 0
    vocab_size_padded = -1

    all_org_tokens: list[int] = []
    all_added_tokens: list[int] = []
    # Concatenation of every shard's layout: owned org ids, -1 per org padding
    # slot, owned added ids, -1 per added padding slot.
    token_ids: list[int] = []

    for tp_rank in range(tp_size):
        # Simulate running as ``tp_rank`` of ``tp_size`` by patching the
        # rank / world-size lookups used by the embedding module.
        with (
            patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank",
                return_value=tp_rank,
            ),
            patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size",
                return_value=tp_size,
            ),
        ):
            vocab_embedding = VocabParallelEmbedding(
                vocab_size, 1, org_num_embeddings=org_vocab_size
            )
        vocab_size_padded = vocab_embedding.num_embeddings_padded
        shard_indices = vocab_embedding.shard_indices

        # Assert that the ranges are contiguous across ranks.
        assert shard_indices.org_vocab_start_index == last_org_vocab_end_index
        assert shard_indices.added_vocab_start_index == last_added_vocab_end_index

        # Accumulate per-shard sizes; the totals are checked after the loop.
        computed_vocab_size += shard_indices.num_elements_padded
        computed_org_vocab_size += shard_indices.num_org_elements
        computed_added_vocab_size += shard_indices.num_added_elements

        org_range = range(
            shard_indices.org_vocab_start_index, shard_indices.org_vocab_end_index
        )
        added_range = range(
            shard_indices.added_vocab_start_index,
            shard_indices.added_vocab_end_index,
        )
        num_org_padding = (
            shard_indices.num_org_elements_padded - shard_indices.num_org_elements
        )
        num_added_padding = (
            shard_indices.num_added_elements_padded
            - shard_indices.num_added_elements
        )

        # Collected for the overlap checks after the loop.
        all_org_tokens.extend(org_range)
        all_added_tokens.extend(added_range)

        token_ids.extend(org_range)
        token_ids.extend([-1] * num_org_padding)
        token_ids.extend(added_range)
        token_ids.extend([-1] * num_added_padding)

        last_org_vocab_end_index = shard_indices.org_vocab_end_index
        last_added_vocab_end_index = shard_indices.added_vocab_end_index

    assert computed_vocab_size == vocab_size_padded
    assert computed_org_vocab_size == org_vocab_size
    assert computed_added_vocab_size == added_vocab_size

    # Ensure that the ranges are not overlapping.
    assert len(all_org_tokens) == len(set(all_org_tokens))
    assert len(all_added_tokens) == len(set(all_added_tokens))
    assert not set(all_org_tokens).intersection(all_added_tokens)

    # The reindex mapping (taken from the last constructed layer) must turn the
    # sharded, padded layout back into 0..vocab_size-1 followed only by padding.
    token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
    reindex_mapping = vocab_embedding.get_sharded_to_full_mapping()
    assert reindex_mapping is not None or tp_size == 1
    if reindex_mapping is not None:
        reindexed_token_ids = token_ids_tensor[reindex_mapping]
        expected = torch.arange(vocab_size, dtype=torch.long)
        assert reindexed_token_ids[:vocab_size].equal(expected)
        assert torch.all(reindexed_token_ids[vocab_size:] == -1)


def test_get_masked_input_and_mask():
    """Check get_masked_input_and_mask for TP sizes 1, 2 and 4.

    Each TP size is checked with and without org-vocab padding.

    The expectations below encode the per-shard id mapping:

    * ids in ``[org_vocab_start_index, org_vocab_end_index)`` map to their
      offset from the org start;
    * ids in ``[added_vocab_start_index, added_vocab_end_index)`` map to
      ``(org_end - org_start) + num_org_vocab_padding`` plus their offset from
      the added start;
    * every other id maps to 0.

    Only the remapped ids are checked; the returned mask is ignored.
    """
    x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

    def masked(org_start, org_end, added_start, added_end, padding):
        # Thin wrapper so each case below is a single table row.
        modified_x, _ = get_masked_input_and_mask(
            x,
            org_vocab_start_index=org_start,
            org_vocab_end_index=org_end,
            added_vocab_start_index=added_start,
            added_vocab_end_index=added_end,
            num_org_vocab_padding=padding,
        )
        return modified_x

    # (label,
    #  (org_start, org_end, added_start, added_end, num_org_vocab_padding),
    #  expected remapped ids)
    cases = [
        # base tp 1 case, no padding: the mapping is the identity
        ("tp1, no padding", (0, 8, 8, 12, 0),
         [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]),
        # tp 2 case, no padding
        ("tp2 rank0, no padding", (0, 4, 8, 10, 0),
         [0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]),
        ("tp2 rank1, no padding", (4, 8, 10, 12, 0),
         [0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]),
        # tp 4 case, no padding
        ("tp4 rank0, no padding", (0, 2, 8, 9, 0),
         [0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]),
        ("tp4 rank1, no padding", (2, 4, 9, 10, 0),
         [0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]),
        ("tp4 rank2, no padding", (4, 6, 10, 11, 0),
         [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]),
        ("tp4 rank3, no padding", (6, 8, 11, 12, 0),
         [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]),
        # base tp 1 case, with padding
        ("tp1, padding 2", (0, 8, 8, 12, 2),
         [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]),
        # tp 2 case, with padding
        ("tp2 rank0, padding 2", (0, 4, 8, 10, 2),
         [0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]),
        ("tp2 rank1, padding 2", (4, 8, 10, 12, 2),
         [0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]),
        # tp 4 case, with padding
        ("tp4 rank0, padding 2", (0, 2, 8, 9, 2),
         [0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]),
        ("tp4 rank1, padding 2", (2, 4, 9, 10, 2),
         [0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]),
        ("tp4 rank2, padding 2", (4, 6, 10, 11, 2),
         [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]),
        ("tp4 rank3, padding 2", (6, 8, 11, 12, 2),
         [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]),
    ]

    for label, shard, expected in cases:
        actual = masked(*shard)
        assert torch.equal(actual, torch.tensor(expected)), (
            f"{label}: got {actual.tolist()}, expected {expected}"
        )