test_layers.py 61 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import random
from copy import deepcopy
from dataclasses import dataclass
7
from unittest.mock import patch
8

9
import pytest
10
11
12
import torch
import torch.nn.functional as F

13
from vllm.config.lora import LoRAConfig
14
15
16
17
18
19
from vllm.lora.layers import (
    BaseLayerWithLoRA,
    ColumnParallelLinearWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    LogitsProcessorWithLoRA,
    LoRAMapping,
20
    MergedColumnParallelLinearVariableSliceWithLoRA,
21
22
23
24
25
26
27
28
29
30
31
    MergedColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    QKVParallelLinearWithLoRA,
    QKVParallelLinearWithShardedLoRA,
    ReplicatedLinearWithLoRA,
    RowParallelLinearWithLoRA,
    RowParallelLinearWithShardedLoRA,
    VocabParallelEmbeddingWithLoRA,
)
32
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
33
from vllm.lora.punica_wrapper import get_punica_wrapper
34
35
36
37
38
39
40
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
41
from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
from vllm.model_executor.layers.vocab_parallel_embedding import (
43
44
45
46
    ParallelLMHead,
    VocabParallelEmbedding,
    get_masked_input_and_mask,
)
47
from vllm.model_executor.models.deepseek_v2 import DeepSeekV2FusedQkvAProjLinear
48
from vllm.platforms import current_platform
49
from vllm.utils.torch_utils import set_random_seed
50
51
52
53
54
55
56
57

from .utils import DummyLoRAManager

# (rtol, atol) pairs passed to torch.testing.assert_close, keyed by the dtype
# of the tensor under test.
TOLERANCES = {
    torch.float16: (5e-3, 5e-3),
    torch.float32: (5e-3, 5e-3),
    torch.bfloat16: (3e-2, 2e-2),
}

# Every test in this module is skipped unless the platform is CUDA-like or CPU.
pytestmark = pytest.mark.skipif(
    not (current_platform.is_cuda_alike() or current_platform.is_cpu()),
    reason="Backend not supported",
)

DEVICE_TYPE = current_platform.device_type
# On CUDA-like platforms: only "<type>:0" when exactly one accelerator is
# visible, otherwise "<type>:0" and "<type>:1". On every other platform: "cpu".
DEVICES = (
    [
        f"{DEVICE_TYPE}:{i}"
        for i in range(1 if torch.accelerator.device_count() == 1 else 2)
    ]
    if current_platform.is_cuda_alike()
    else ["cpu"]
)

# prefill stage(True) or decode stage(False)
STAGES = [True, False]

# Number of seeds (0 .. NUM_RANDOM_SEEDS - 1) each test loops over.
NUM_RANDOM_SEEDS = 2

VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 2


@pytest.fixture(autouse=True)
def clean_cache_reset_device(reset_default_device):
    """Clear the Triton LoRA pointer caches before every test in this module.

    Requests the ``reset_default_device`` fixture, which is defined outside
    this file.
    """
    # Release any memory we might be holding on to. CI runs OOMs otherwise.
    from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT

    _LORA_B_PTR_DICT.clear()
    _LORA_A_PTR_DICT.clear()

    yield


@pytest.fixture(autouse=True)
def skip_cuda_with_stage_false(request):
    """Skip the ``stage=False`` parametrization on CUDA-like platforms.

    On cuda-like platforms, we use the same kernels for prefill and decode
    stage, and 'stage' is generally ignored, so we only need to test once.
    """
    if current_platform.is_cuda_alike():
        # Only parametrized test items carry a ``callspec``. Look it up
        # defensively instead of wrapping the check in a blanket
        # ``except Exception: pass``, which would hide genuine errors.
        callspec = getattr(request.node, "callspec", None)
        params = getattr(callspec, "params", None) or {}
        if params.get("stage") is False:
            pytest.skip("Skip test when stage=False")
    yield


def get_random_id_to_index(
    num_loras: int, num_slots: int, log: bool = True
114
) -> list[int | None]:
115
116
117
118
119
120
121
122
123
124
125
126
    """Creates a random lora_id_to_index mapping.

    Args:
        num_loras: The number of active loras in the mapping.
        num_slots: The number of slots in the mapping. Must be larger
            than num_loras.
        log: Whether to log the output.
    """

    if num_loras > num_slots:
        raise ValueError(
            f"num_loras is higher than num_slots: {num_loras} > {num_slots}. "
127
128
            "num_loras must be less than or equal to num_slots."
        )
129

130
    slots: list[int | None] = [None] * num_slots
131
132
133
134
135
136
137
138
139
140
141
    random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist()
    for lora_id, slot_idx in enumerate(random_slot_selections, start=1):
        slots[slot_idx] = lora_id

    if log:
        print(f"Created lora_id_to_index mapping: {slots}.")

    return slots


def populate_loras(
    id_to_index: list[int | None],
    layer: BaseLayerWithLoRA,
    layer_weights: torch.Tensor,
    repeats: int = 1,
) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
    """This method populates the lora layers with lora weights.

    Args:
        id_to_index: a list of lora ids. The index of the lora id
            represents which memory slot the lora matrices are
            stored in. A None value indicates a free slot.
        layer: the LoRAlayer to populate.
        layer_weights: the PyTorch tensor containing the layer's
            weights.
        repeats: must only be set for column parallel packed
            layers. Indicates the number of loras to compose
            together to create a single lora layer.

    Returns:
        A tuple ``(lora_dict, sublora_dict)``. ``lora_dict`` maps each lora id
        to the weights passed to ``layer.set_lora``: a
        ``PackedLoRALayerWeights`` when ``repeats > 1``, otherwise the single
        sub-lora. ``sublora_dict`` maps each lora id to its ``repeats``
        sub-loras.
    """

    # Dictionary that maps the lora ID to the
    # corresponding lora weights.
    lora_dict: dict[int, LoRALayerWeights] = dict()

    # Dictionary that maps the lora ID to the
    # corresponding subloras.
    sublora_dict: dict[int, list[LoRALayerWeights]] = dict()

    for slot_idx, lora_id in enumerate(id_to_index):
        if lora_id is not None:
            subloras: list[LoRALayerWeights] = []
            # Sub-lora i keeps only rows [i * sublora_len, (i + 1) * sublora_len)
            # of its lora_b.
            sublora_len = layer_weights.shape[0] // repeats
            for i in range(repeats):
                sublora = DummyLoRAManager(layer_weights.device).init_random_lora(
                    module_name=f"fake_{i}",
                    weight=layer_weights,
                )
                sublora.lora_b = sublora.lora_b[
                    (sublora_len * i) : (sublora_len * (i + 1)), :
                ]
                sublora.optimize()
                subloras.append(sublora)

            lora = PackedLoRALayerWeights.pack(subloras) if repeats > 1 else subloras[0]

            layer.set_lora(
                slot_idx,
                lora_a=lora.lora_a,
                lora_b=lora.lora_b,
            )

            lora_dict[lora_id] = lora
            sublora_dict[lora_id] = subloras

    return lora_dict, sublora_dict


def create_random_inputs(
    active_lora_ids: list[int],
    num_inputs: int,
    input_size: tuple[int, ...],
    input_range: tuple[float, float],
    input_type: torch.dtype = torch.int,
    device: torch.device = DEVICE_TYPE,
) -> tuple[list[torch.Tensor], list[int], list[int]]:
    """Creates random inputs.

    Args:
        active_lora_ids: lora IDs of active lora weights. Each input is
            assigned one of them with ``random.choice``.
        num_inputs: the number of inputs to create.
        input_size: the size of each individual input.
        input_range: the range of values to include in the input.
            input_range[0] <= possible input values < input_range[1]
        input_type: the type of values in the input. ``torch.int`` produces
            integer tensors via ``torch.randint``; any other dtype produces
            uniform floating-point tensors of that dtype via ``torch.rand``.
        device: the device on which the input tensors are created.

    Returns:
        A tuple ``(inputs, index_mapping, prompt_mapping)``. ``inputs`` holds
        ``num_inputs`` tensors. ``index_mapping`` repeats each input's lora id
        ``input_size[0]`` times (one entry per element of the first dimension).
        ``prompt_mapping`` holds one lora id per input.
    """

    low, high = input_range

    inputs: list[torch.Tensor] = []
    index_mapping: list[int] = []
    prompt_mapping: list[int] = []

    for _ in range(num_inputs):
        if input_type == torch.int:
            inputs.append(
                torch.randint(
                    low=int(low), high=int(high), size=input_size, device=device
                )
            )
        else:
            # Scale by the width of the range, not by ``high``, so the values
            # really lie in [low, high) as documented. (Previously this was
            # ``* high + low``, which yields [low, low + high).)
            inputs.append(
                torch.rand(size=input_size, dtype=input_type, device=device)
                * (high - low)
                + low
            )

        lora_id = random.choice(active_lora_ids)
        index_mapping += [lora_id] * input_size[0]
        prompt_mapping += [lora_id]

    return inputs, index_mapping, prompt_mapping


def check_punica_wrapper(punica_wrapper) -> bool:
    """Return True if ``punica_wrapper`` is exactly the current platform's wrapper.

    CUDA-like platforms expect ``PunicaWrapperGPU`` and the CPU platform
    expects ``PunicaWrapperCPU``. Any other platform returns False. The check
    uses ``type(...) is``, so subclasses do not match.
    """
    if current_platform.is_cuda_alike():
        from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU

        return type(punica_wrapper) is PunicaWrapperGPU
    elif current_platform.is_cpu():
        from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU

        return type(punica_wrapper) is PunicaWrapperCPU
    else:
        return False


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings(
    default_vllm_config, dist_init, num_loras, device, vocab_size, stage
) -> None:
    """Check VocabParallelEmbeddingWithLoRA against a reference computation.

    For each seed, loras are loaded into random slots. The batched LoRA output
    is compared with a per-input reference
    ``embedding(x) + F.embedding(x, lora_a.T) @ lora_b.T``; note that this
    reference applies no ``lora.scaling`` factor. Then every slot is reset
    and, with all inputs mapped to lora id 0, the output must equal the base
    embedding.
    """
    # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
    # device, see: https://github.com/triton-lang/triton/issues/2925
    # Same below.
    if current_platform.is_cuda_alike():
        torch.accelerator.set_device_index(device)

    torch.set_default_device(device)
    max_loras = 8
    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
    assert check_punica_wrapper(punica_wrapper)

    def create_random_embedding_layer():
        embedding = VocabParallelEmbedding(vocab_size, 256)
        embedding.weight.data = torch.rand_like(embedding.weight.data)
        # Zero any weight rows beyond vocab_size (if the layer has any).
        embedding.weight.data[vocab_size:, :] = 0

        lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
        lora_embedding.create_lora_weights(max_loras, lora_config)

        return embedding, lora_embedding

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        embedding, lora_embedding = create_random_embedding_layer()
        lora_embedding.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=embedding.weight.T,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=num_loras * 3,
            input_size=(200,),
            input_range=(1, vocab_size),
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
        )

        lora_result = lora_embedding(torch.cat(inputs))

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = embedding(input_)
            after_a = F.embedding(
                input_,
                lora.lora_a.T,
            )
            result += after_a @ lora.lora_b.T
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=num_loras * 3,
            input_size=(200,),
            input_range=(1, vocab_size),
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
        )

        lora_result = lora_embedding(torch.cat(inputs))
        expected_result = embedding(torch.cat(inputs))

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [64000, 256512, 258048])
@pytest.mark.parametrize("stage", STAGES)
def test_lm_head_logits_processor(
    default_vllm_config, dist_init, num_loras, device, vocab_size, stage
) -> None:
    """Check LogitsProcessorWithLoRA against a reference computation.

    For each seed, loras are loaded into random slots. The batched LoRA logits
    are compared with a per-input reference
    ``base_logits(x) + x @ lora_a.T @ lora_b.T * scaling``. Then every slot is
    reset and, with all inputs mapped to lora id 0, the LoRA logits must equal
    the base logits.
    """
    if current_platform.is_cuda_alike():
        torch.accelerator.set_device_index(device)

    torch.set_default_device(device)
    max_loras = 8
    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
    assert check_punica_wrapper(punica_wrapper)

    def _pretest():
        linear = ParallelLMHead(
            num_embeddings=vocab_size,
            embedding_dim=1024,
            params_dtype=torch.float16,
        )
        linear.weight.data = torch.rand_like(linear.weight.data)
        # Zero any weight rows beyond vocab_size, as test_embeddings does. The
        # previous ``[:, vocab_size:]`` indexed the 1024-wide embedding dim and
        # was therefore always an empty slice.
        linear.weight.data[vocab_size:, :] = 0
        logits_processor = LogitsProcessor(vocab_size)
        lora_logits_processor = LogitsProcessorWithLoRA(
            logits_processor, 1024, linear.weight.dtype, linear.weight.device, None
        )
        lora_logits_processor.create_lora_weights(max_loras, lora_config)

        return linear, logits_processor, lora_logits_processor

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, logits_processor, lora_logits_processor = _pretest()
        lora_logits_processor.set_mapping(punica_wrapper)

        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_logits_processor,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=8 * num_loras,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
        )

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs), lm_head=linear, embedding_bias=None
        )

        original_lm_head = deepcopy(linear)

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = logits_processor._get_logits(
                hidden_states=input_, lm_head=linear, embedding_bias=None
            )
            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        # Without this comparison the LoRA logits path above was never
        # verified. The slice mirrors the reset check below.
        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(
            lora_result[:, :vocab_size], expected_result, rtol=rtol, atol=atol
        )

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_logits_processor.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=8 * num_loras * 3,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
        )

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None,
        )[:, :vocab_size]
        expected_result = logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None,
        )

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("vocab_size", [258049, 300000])
@pytest.mark.parametrize("device", DEVICES)
def test_lm_head_logits_processor_invalid_vocab_size(
    default_vllm_config, dist_init, vocab_size, device
) -> None:
    """Test that LogitsProcessorWithLoRA raises ValueError for invalid vocab sizes."""
    if current_platform.is_cuda_alike():
        torch.accelerator.set_device_index(device)

    torch.set_default_device(device)
    max_loras = 8
    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )

    logits_processor = LogitsProcessor(vocab_size)
    lora_logits_processor = LogitsProcessorWithLoRA(
        logits_processor, 1024, torch.float16, device, None
    )

    # Both parametrized sizes exceed the 258048 limit named in the message.
    with pytest.raises(ValueError, match="vocab size must be <= 258048"):
        lora_logits_processor.create_lora_weights(max_loras, lora_config)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_replicated(
    default_vllm_config,
    dist_init,
    num_loras,
    device,
    stage,
) -> None:
    """Check ReplicatedLinearWithLoRA against a reference computation.

    The batched LoRA output must match ``linear(x) + x @ lora_a.T @ lora_b.T *
    scaling`` per input. After every slot is reset, with all inputs mapped to
    lora id 0, it must match the base layer.
    """
    if current_platform.is_cuda_alike():
        torch.accelerator.set_device_index(device)

    max_loras = 8
    torch.set_default_device(device)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        lora_dtype=torch.float16,
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
    assert check_punica_wrapper(punica_wrapper)

    def create_random_linear_replicated_layer(idx: int = 0):
        linear = ReplicatedLinear(
            4096, 4096, bias=False, params_dtype=torch.float16, prefix=f"layer_{idx}"
        )
        linear.weight.data = torch.rand_like(linear.weight.data)
        lora_linear = ReplicatedLinearWithLoRA(linear)

        lora_linear.create_lora_weights(max_loras, lora_config)
        assert (
            lora_linear.n_slices
            == len(lora_linear.lora_a_stacked)
            == len(lora_linear.lora_b_stacked)
            == 1
        )
        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_replicated_layer(i)
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_parallel(
    default_vllm_config, dist_init, num_loras, orientation, fully_shard, device, stage
) -> None:
    """Check row/column parallel linear LoRA layers against a reference.

    Uses ``RowParallelLinear`` or ``ColumnParallelLinear`` depending on
    ``orientation``, wrapped in the ``*WithShardedLoRA`` class when
    ``fully_shard`` is set and the ``*WithLoRA`` class otherwise. The batched
    LoRA output must match ``linear(x) + x @ lora_a.T @ lora_b.T * scaling``.
    After a reset it must match the base layer.
    """
    if current_platform.is_cuda_alike():
        torch.accelerator.set_device_index(device)

    max_loras = 8
    torch.set_default_device(device)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        fully_sharded_loras=fully_shard,
        lora_dtype=torch.float16,
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
    assert check_punica_wrapper(punica_wrapper)

    def create_random_linear_parallel_layer(idx: int = 0):
        if orientation == "row":
            linear = RowParallelLinear(
                4096,
                4096,
                bias=False,
                params_dtype=torch.float16,
                prefix=f"layer_{idx}",
            )
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (
                RowParallelLinearWithLoRA(linear)
                if not fully_shard
                else RowParallelLinearWithShardedLoRA(linear)
            )
        else:
            linear = ColumnParallelLinear(
                4096,
                4096,
                bias=False,
                params_dtype=torch.float16,
                prefix=f"layer_{idx}",
            )
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (
                ColumnParallelLinearWithLoRA(linear)
                if not fully_shard
                else ColumnParallelLinearWithShardedLoRA(linear)
            )
        lora_linear.create_lora_weights(max_loras, lora_config)
        assert (
            lora_linear.n_slices
            == len(lora_linear.lora_a_stacked)
            == len(lora_linear.lora_b_stacked)
            == 1
        )

        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_parallel_layer(i)
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_column_parallel_packed(
    default_vllm_config, dist_init, num_loras, repeats, fully_shard, device, stage
) -> None:
    """Check packed column-parallel LoRA layers against a reference.

    ``repeats`` selects the layer:
    1 -> QKVParallelLinear with (Sharded)QKVParallelLinearWithLoRA,
    2 -> MergedColumnParallelLinear with MergedColumnParallelLinear(Sharded)LoRA,
    3 -> QKVParallelLinear with MergedQKVParallelLinear(Sharded)LoRA.
    Each of the ``repeats`` sub-loras must contribute
    ``x @ lora_a.T @ lora_b.T * scaling`` to its own contiguous output slice.
    After a reset the output must match the base layer.
    """
    if current_platform.is_cuda_alike():
        torch.accelerator.set_device_index(device)

    max_loras = 8
    torch.set_default_device(device)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        fully_sharded_loras=fully_shard,
        lora_dtype=torch.float16,
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
    assert check_punica_wrapper(punica_wrapper)

    def create_column_parallel_packed_layer(idx: int = 0):
        if repeats == 2:
            linear = MergedColumnParallelLinear(
                4096,
                [4096] * repeats,
                bias=False,
                params_dtype=torch.float16,
                prefix=f"layer_{idx}",
            )
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (
                MergedColumnParallelLinearWithLoRA(linear)
                if not fully_shard
                else MergedColumnParallelLinearWithShardedLoRA(linear)
            )
        elif repeats == 3:
            linear = QKVParallelLinear(
                4096,
                64,
                32,
                bias=False,
                params_dtype=torch.float16,
                prefix=f"layer_{idx}",
            )
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (
                MergedQKVParallelLinearWithLoRA(linear)
                if not fully_shard
                else MergedQKVParallelLinearWithShardedLoRA(linear)
            )
        else:
            linear = QKVParallelLinear(
                4096,
                64,
                32,
                bias=False,
                params_dtype=torch.float16,
                prefix=f"layer_{idx}",
            )
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (
                QKVParallelLinearWithLoRA(linear)
                if not fully_shard
                else QKVParallelLinearWithShardedLoRA(linear)
            )

        # Minimal stand-in passed as model_config to create_lora_weights;
        # presumably only these attributes are read — confirm if it breaks.
        @dataclass
        class FakeConfig:
            hidden_size = 4096
            num_key_value_heads = 32
            num_attention_heads = 32

        n_slices = repeats
        lora_linear.create_lora_weights(
            max_loras, lora_config, model_config=FakeConfig()
        )
        assert (
            lora_linear.n_slices
            == len(lora_linear.lora_a_stacked)
            == len(lora_linear.lora_b_stacked)
            == n_slices
        )

        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)

        linear, lora_linear = create_column_parallel_packed_layer(i)
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, sublora_dict = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
            repeats=repeats,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            result = linear(input_)[0]
            subloras = sublora_dict[lora_id]
            # Named slice_idx (not i) so it does not shadow the seed index.
            for slice_idx, sublora in enumerate(subloras):
                slice_len = sublora.lora_b.shape[0]
                result[
                    :, slice_len * slice_idx : slice_len * (slice_idx + 1)
                ] += input_ @ sublora.lora_a.T @ sublora.lora_b.T * sublora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("num_slices", [3, 5])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_merged_column_parallel_variable_slice(
    default_vllm_config, dist_init, num_loras, num_slices, device, stage
) -> None:
    if current_platform.is_cuda_alike():
917
        torch.accelerator.set_device_index(device)
918
919
920
921
922
923
924
925
926
927
928
929

    max_loras = 8
    torch.set_default_device(device)
    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)

    # Set number of output slices
    output_sizes = [1024 + i * 256 for i in range(num_slices)]
    total_output = sum(output_sizes)

930
    def create_layer(idx: int = 0):
931
932
        # Create linear layer
        linear = MergedColumnParallelLinear(
933
934
935
936
937
            4096,
            output_sizes,
            bias=False,
            params_dtype=torch.float16,
            prefix=f"layer_{idx}",
938
939
940
941
942
943
944
945
946
947
948
        )
        linear.weight.data = torch.rand_like(linear.weight.data)

        # Create linear layer with LoRA adapter
        lora_linear = MergedColumnParallelLinearVariableSliceWithLoRA(linear)
        lora_linear.create_lora_weights(max_loras, lora_config)
        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)
        id_to_index = get_random_id_to_index(num_loras, max_loras)
949
        linear, lora_linear = create_layer(i)
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
        lora_linear.set_mapping(punica_wrapper)

        # Populate LoRA weights
        lora_dict, sublora_dict = {}, {}
        for slot_idx, lora_id in enumerate(id_to_index):
            if lora_id is not None:
                # Create random LoRA weights
                lora_a = torch.rand(8, 4096, dtype=torch.float16, device=device)
                lora_b = torch.rand(total_output, 8, dtype=torch.float16, device=device)
                lora_linear.set_lora(slot_idx, lora_a, lora_b)
                lora_dict[lora_id] = (lora_a, lora_b)

                # Split lora_b for expected computation
                sublora_dict[lora_id] = torch.split(lora_b, output_sizes, dim=0)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, 512)

        # Compute LoRA result
        lora_result = lora_linear(torch.cat(inputs))[0]

        # Compute expected result
        expected_results = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            result = linear(input_)[0]
            lora_a, _ = lora_dict[lora_id]
            offset = 0
            # Compute expected result for each sublora
            for lora_b_slice in sublora_dict[lora_id]:
                sz = lora_b_slice.shape[0]
                result[:, offset : offset + sz] += input_ @ lora_a.T @ lora_b_slice.T
                offset += sz
            expected_results.append(result)

        # Check that the LoRA result is close to the expected result
        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(
            lora_result, torch.cat(expected_results), rtol=rtol, atol=atol
        )

        # Reset LoRA weights and check results with zero LoRA weights
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, 512)

        # After resetting LoRA weights,
        # lora_linear should behave like the base linear layer
        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)


@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize(
    "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS))
)
def test_vocab_parallel_embedding_indices(tp_size, seed, default_vllm_config):
    """Check VocabParallelEmbedding shard bookkeeping across simulated TP ranks.

    A random vocabulary is split into an original part and an added part. The
    layer is built once per TP rank, with the rank/world-size getters patched,
    and the test checks that:

    * each rank's org/added ranges start where the previous rank's ended,
    * the per-rank element counts sum to the padded/org/added totals,
    * no token id is owned twice, nor by both the org and added ranges,
    * ``get_sharded_to_full_mapping`` reorders the concatenated sharded layout
      back to ``0..vocab_size-1``, followed only by padding entries (-1).
    """
    random.seed(seed)
    vocab_size = random.randint(4000, 64000)
    added_vocab_size = random.randint(0, 1024)
    org_vocab_size = vocab_size - added_vocab_size
    last_org_vocab_end_index = 0
    last_added_vocab_end_index = org_vocab_size
    computed_vocab_size = 0
    computed_org_vocab_size = 0
    computed_added_vocab_size = 0
    vocab_size_padded = -1

    all_org_tokens: list[int] = []
    all_added_tokens: list[int] = []
    # Token id at every position of the rank-major sharded layout; padding
    # positions hold -1.
    token_ids: list[int] = []

    embedding_module = "vllm.model_executor.layers.vocab_parallel_embedding"
    for tp_rank in range(tp_size):
        with (
            patch(
                f"{embedding_module}.get_tensor_model_parallel_rank",
                return_value=tp_rank,
            ),
            patch(
                f"{embedding_module}.get_tensor_model_parallel_world_size",
                return_value=tp_size,
            ),
        ):
            vocab_embedding = VocabParallelEmbedding(
                vocab_size, 1, org_num_embeddings=org_vocab_size
            )
        vocab_size_padded = vocab_embedding.num_embeddings_padded
        shard_indices = vocab_embedding.shard_indices
        # Assert that the ranges are contiguous
        assert shard_indices.org_vocab_start_index == last_org_vocab_end_index
        assert shard_indices.added_vocab_start_index == last_added_vocab_end_index

        # Accumulate per-rank sizes; the totals are checked after the loop.
        computed_vocab_size += shard_indices.num_elements_padded
        computed_org_vocab_size += shard_indices.num_org_elements
        computed_added_vocab_size += shard_indices.num_added_elements

        org_range = range(
            shard_indices.org_vocab_start_index, shard_indices.org_vocab_end_index
        )
        added_range = range(
            shard_indices.added_vocab_start_index, shard_indices.added_vocab_end_index
        )

        # Collected so overlap can be checked after the loop.
        all_org_tokens.extend(org_range)
        all_added_tokens.extend(added_range)

        # This rank's shard layout: org tokens, org padding, added tokens,
        # added padding.
        token_ids.extend(org_range)
        token_ids.extend(
            [-1]
            * (shard_indices.num_org_elements_padded - shard_indices.num_org_elements)
        )
        token_ids.extend(added_range)
        token_ids.extend(
            [-1]
            * (
                shard_indices.num_added_elements_padded
                - shard_indices.num_added_elements
            )
        )

        last_org_vocab_end_index = shard_indices.org_vocab_end_index
        last_added_vocab_end_index = shard_indices.added_vocab_end_index

    assert computed_vocab_size == vocab_size_padded
    assert computed_org_vocab_size == org_vocab_size
    assert computed_added_vocab_size == added_vocab_size

    # Ensure that the ranges are not overlapping
    assert len(all_org_tokens) == len(set(all_org_tokens))
    assert len(all_added_tokens) == len(set(all_added_tokens))
    assert not set(all_org_tokens).intersection(set(all_added_tokens))

    token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
    reindex_mapping = vocab_embedding.get_sharded_to_full_mapping()
    assert reindex_mapping is not None or tp_size == 1
    if reindex_mapping is not None:
        reindexed_token_ids = token_ids_tensor[reindex_mapping]
        expected = torch.arange(vocab_size)
        assert reindexed_token_ids[:vocab_size].equal(expected)
        assert torch.all(reindexed_token_ids[vocab_size:] == -1)


def test_get_masked_input_and_mask():
    """Check get_masked_input_and_mask on a 12-token vocabulary.

    Ids 0-7 form the original vocabulary and ids 8-11 the added vocabulary.
    Each case describes one TP rank's shard as
    ``(org_start, org_end, added_start, added_end, org_padding)`` together with
    the expected remapped ids. Ids owned by the rank are shifted into the
    rank-local index space, and added ids land after the rank's org slots
    plus ``org_padding``. All other ids become 0.
    """
    x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

    def masked(org_start, org_end, added_start, added_end, org_padding):
        """Return only the remapped ids for one rank's shard of ``x``."""
        modified_x, _ = get_masked_input_and_mask(
            x,
            org_vocab_start_index=org_start,
            org_vocab_end_index=org_end,
            added_vocab_start_index=added_start,
            added_vocab_end_index=added_end,
            num_org_vocab_padding=org_padding,
        )
        return modified_x

    # Each entry: (label,
    #              (org_start, org_end, added_start, added_end, org_padding),
    #              expected remapped ids)
    cases = [
        # base tp 1 case, no padding: identity mapping
        ("tp1 r0 pad0", (0, 8, 8, 12, 0), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]),
        # tp 2 case, no padding
        ("tp2 r0 pad0", (0, 4, 8, 10, 0), [0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]),
        ("tp2 r1 pad0", (4, 8, 10, 12, 0), [0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]),
        # tp 4 case, no padding
        ("tp4 r0 pad0", (0, 2, 8, 9, 0), [0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]),
        ("tp4 r1 pad0", (2, 4, 9, 10, 0), [0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]),
        ("tp4 r2 pad0", (4, 6, 10, 11, 0), [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]),
        ("tp4 r3 pad0", (6, 8, 11, 12, 0), [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]),
        # base tp 1 case, with padding
        ("tp1 r0 pad2", (0, 8, 8, 12, 2), [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]),
        # tp 2 case, with padding
        ("tp2 r0 pad2", (0, 4, 8, 10, 2), [0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]),
        ("tp2 r1 pad2", (4, 8, 10, 12, 2), [0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]),
        # tp 4 case, with padding
        ("tp4 r0 pad2", (0, 2, 8, 9, 2), [0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]),
        ("tp4 r1 pad2", (2, 4, 9, 10, 2), [0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]),
        ("tp4 r2 pad2", (4, 6, 10, 11, 2), [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]),
        ("tp4 r3 pad2", (6, 8, 11, 12, 2), [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]),
    ]

    for label, shard, expected in cases:
        modified_x = masked(*shard)
        assert torch.equal(modified_x, torch.tensor(expected)), (
            f"{label}: got {modified_x.tolist()}, expected {expected}"
        )


def test_variable_slice_lora_class_selection(default_vllm_config, dist_init):
    """Check which LoRA wrapper class ``from_layer`` picks for column layers.

    MergedColumnParallelLinearVariableSliceWithLoRA must accept a
    MergedColumnParallelLinear with 3+ output slices in two situations:
    nemotron-h style checkpoints, which have a single packed weight (one
    entry in ``packed_modules_list``), and checkpoints that list 3+ packed
    modules. In the single-entry case, ``from_layer`` must select it ahead of
    ColumnParallelLinearWithLoRA, whose ``slice_lora_b`` assumes exactly 2
    slices.

    The remaining cases pin the selection for:

    * 2-slice merged layers (plain, fully sharded TP, DeepSeek
      fused_qkv_a_proj, and a user subclass),
    * a plain ColumnParallelLinear,
    * an empty ``packed_modules_list``.
    """
    from vllm.lora.utils import from_layer

    lora_config = LoRAConfig(max_loras=8, max_lora_rank=8, lora_dtype=torch.float16)

    # Case 1: MergedColumnParallelLinear with 3+ output sizes and
    # packed_modules_list with 1 item (nemotron-h style)
    # -> MergedColumnParallelLinearVariableSliceWithLoRA should be selected
    layer_3_slices = MergedColumnParallelLinear(
        4096, [1024, 1280, 1536], bias=False, params_dtype=torch.float16
    )
    packed_modules_single = ["mlp"]

    assert MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
        source_layer=layer_3_slices,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    ), "MergedColumnParallelLinearVariableSliceWithLoRA should handle 3+ slices"

    # ColumnParallelLinearWithLoRA should NOT match 3+ slices
    # (its slice_lora_b assumes exactly 2 slices)
    assert not ColumnParallelLinearWithLoRA.can_replace_layer(
        source_layer=layer_3_slices,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    ), (
        "ColumnParallelLinearWithLoRA should NOT handle 3+ slices "
        "(slice_lora_b assumes 2 slices)"
    )

    # Verify from_layer selects the correct class (Variable, not base)
    selected_layer = from_layer(
        layer_3_slices,
        max_loras=8,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    )
    assert isinstance(
        selected_layer, MergedColumnParallelLinearVariableSliceWithLoRA
    ), (
        f"from_layer should select MergedColumnParallelLinearVariableSliceWithLoRA "
        f"for 3+ slices, got {type(selected_layer).__name__}"
    )

    # Case 2: MergedColumnParallelLinear with 2 output sizes and
    # packed_modules_list with 1 item (standard gate_up style)
    # -> ColumnParallelLinearWithLoRA should be selected
    # -> MergedColumnParallelLinearVariableSliceWithLoRA should NOT match
    layer_2_slices = MergedColumnParallelLinear(
        4096, [2048, 2048], bias=False, params_dtype=torch.float16
    )

    assert ColumnParallelLinearWithLoRA.can_replace_layer(
        source_layer=layer_2_slices,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    ), "ColumnParallelLinearWithLoRA should handle 2 slices"

    assert not MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
        source_layer=layer_2_slices,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    ), "MergedColumnParallelLinearVariableSliceWithLoRA should NOT handle 2 slices"

    # Verify from_layer selects ColumnParallelLinearWithLoRA for 2 slices
    selected_layer_2 = from_layer(
        layer_2_slices,
        max_loras=8,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    )
    assert isinstance(selected_layer_2, ColumnParallelLinearWithLoRA), (
        f"from_layer should select ColumnParallelLinearWithLoRA "
        f"for 2 slices, got {type(selected_layer_2).__name__}"
    )
    # But NOT the Variable subclass
    assert not isinstance(
        selected_layer_2, MergedColumnParallelLinearVariableSliceWithLoRA
    ), (
        "from_layer should NOT select "
        "MergedColumnParallelLinearVariableSliceWithLoRA for 2 slices"
    )

    # Case 3: MergedColumnParallelLinear with 3+ items in packed_modules_list
    # -> MergedColumnParallelLinearVariableSliceWithLoRA should be selected
    packed_modules_three = ["gate_proj", "up_proj", "down_proj"]

    assert MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
        source_layer=layer_3_slices,
        lora_config=lora_config,
        packed_modules_list=packed_modules_three,
    ), (
        "MergedColumnParallelLinearVariableSliceWithLoRA "
        "should handle 3+ packed modules"
    )

    # Case 4: MergedColumnParallelLinear with 2 items in packed_modules_list
    # -> MergedColumnParallelLinearWithLoRA should handle this (not Variable)
    packed_modules_two = ["gate_proj", "up_proj"]

    assert not MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
        source_layer=layer_2_slices,
        lora_config=lora_config,
        packed_modules_list=packed_modules_two,
    ), (
        "MergedColumnParallelLinearVariableSliceWithLoRA"
        " should NOT handle 2 packed modules"
    )

    assert MergedColumnParallelLinearWithLoRA.can_replace_layer(
        source_layer=layer_2_slices,
        lora_config=lora_config,
        packed_modules_list=packed_modules_two,
    ), "MergedColumnParallelLinearWithLoRA should handle 2 packed modules"

    # Verify from_layer selects MergedColumnParallelLinearWithLoRA for 2 packed
    # modules
    selected_layer_merged = from_layer(
        layer_2_slices,
        max_loras=8,
        lora_config=lora_config,
        packed_modules_list=packed_modules_two,
    )
    assert isinstance(selected_layer_merged, MergedColumnParallelLinearWithLoRA), (
        f"from_layer should select MergedColumnParallelLinearWithLoRA "
        f"for 2 packed modules, got {type(selected_layer_merged).__name__}"
    )

    # Case 4b: with fully sharded LoRA enabled on a merged layer whose TP size
    # is > 1, the sharded merged wrapper must be chosen instead of the generic
    # one. The same config is reused by Case 5.
    fully_sharded_lora_config = LoRAConfig(
        max_loras=8,
        max_lora_rank=16,
        lora_dtype=torch.float16,
        fully_sharded_loras=True,
    )
    fully_sharded_tp_layer = MergedColumnParallelLinear(
        4096, [2048, 2048], bias=False, params_dtype=torch.float16
    )
    # Simulate a TP size of 2 by overriding the attribute directly.
    fully_sharded_tp_layer.tp_size = 2

    assert not MergedColumnParallelLinearWithLoRA.can_replace_layer(
        source_layer=fully_sharded_tp_layer,
        lora_config=fully_sharded_lora_config,
        packed_modules_list=packed_modules_two,
    ), "Generic merged wrapper should reject fully sharded TP layers"

    assert MergedColumnParallelLinearWithShardedLoRA.can_replace_layer(
        source_layer=fully_sharded_tp_layer,
        lora_config=fully_sharded_lora_config,
        packed_modules_list=packed_modules_two,
    ), "Sharded merged wrapper should remain eligible for fully sharded TP layers"

    selected_fully_sharded_tp_layer = from_layer(
        fully_sharded_tp_layer,
        max_loras=8,
        lora_config=fully_sharded_lora_config,
        packed_modules_list=packed_modules_two,
    )
    assert isinstance(
        selected_fully_sharded_tp_layer,
        MergedColumnParallelLinearWithShardedLoRA,
    ), (
        "from_layer should select MergedColumnParallelLinearWithShardedLoRA "
        "for fully sharded TP merged layers, got "
        f"{type(selected_fully_sharded_tp_layer).__name__}"
    )

    # Case 5: DeepSeek's fused_qkv_a_proj should reuse the generic merged
    # wrapper while preserving its custom base forward path.
    deepseek_fused_layer = DeepSeekV2FusedQkvAProjLinear(
        4096, [2048, 2048], prefix="model.layers.0.self_attn.fused_qkv_a_proj"
    )
    selected_deepseek_layer = from_layer(
        deepseek_fused_layer,
        max_loras=8,
        lora_config=lora_config,
        packed_modules_list=packed_modules_two,
    )
    assert isinstance(selected_deepseek_layer, MergedColumnParallelLinearWithLoRA), (
        "from_layer should select MergedColumnParallelLinearWithLoRA "
        f"for DeepSeek fused_qkv_a_proj, got {type(selected_deepseek_layer).__name__}"
    )

    selected_fully_sharded_deepseek_layer = from_layer(
        deepseek_fused_layer,
        max_loras=8,
        lora_config=fully_sharded_lora_config,
        packed_modules_list=packed_modules_two,
    )
    assert isinstance(
        selected_fully_sharded_deepseek_layer,
        MergedColumnParallelLinearWithLoRA,
    ), (
        "from_layer should keep using MergedColumnParallelLinearWithLoRA "
        "for fused_qkv_a_proj when the base layer is effectively unsharded, got "
        f"{type(selected_fully_sharded_deepseek_layer).__name__}"
    )

    # Case 6: Generic subclass of MergedColumnParallelLinear with 2 packed
    # modules should still use the generic merged wrapper.
    class CustomMergedColumnParallelLinear(MergedColumnParallelLinear):
        pass

    custom_merged_layer = CustomMergedColumnParallelLinear(
        4096, [2048, 2048], bias=False, params_dtype=torch.float16
    )
    assert MergedColumnParallelLinearWithLoRA.can_replace_layer(
        source_layer=custom_merged_layer,
        lora_config=lora_config,
        packed_modules_list=packed_modules_two,
    ), "MergedColumnParallelLinearWithLoRA should handle subclasses"

    selected_custom_layer = from_layer(
        custom_merged_layer,
        max_loras=8,
        lora_config=lora_config,
        packed_modules_list=packed_modules_two,
    )
    assert isinstance(selected_custom_layer, MergedColumnParallelLinearWithLoRA), (
        f"from_layer should select MergedColumnParallelLinearWithLoRA "
        f"for subclassed merged layers, got {type(selected_custom_layer).__name__}"
    )

    # Case 7: Plain ColumnParallelLinear (not merged) - common in many models
    # -> ColumnParallelLinearWithLoRA should be selected
    plain_column_parallel = ColumnParallelLinear(
        4096, 4096, bias=False, params_dtype=torch.float16
    )

    assert ColumnParallelLinearWithLoRA.can_replace_layer(
        source_layer=plain_column_parallel,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    ), "ColumnParallelLinearWithLoRA should handle plain ColumnParallelLinear"

    assert not MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
        source_layer=plain_column_parallel,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    ), (
        "MergedColumnParallelLinearVariableSliceWithLoRA "
        "should NOT handle plain ColumnParallelLinear"
    )

    # Verify from_layer selects ColumnParallelLinearWithLoRA for plain layer
    selected_plain = from_layer(
        plain_column_parallel,
        max_loras=8,
        lora_config=lora_config,
        packed_modules_list=packed_modules_single,
    )
    assert isinstance(selected_plain, ColumnParallelLinearWithLoRA), (
        f"from_layer should select ColumnParallelLinearWithLoRA "
        f"for plain ColumnParallelLinear, got {type(selected_plain).__name__}"
    )

    # Case 8: MergedColumnParallelLinear with exactly 2 output sizes
    # and empty packed_modules_list
    # -> ColumnParallelLinearWithLoRA should NOT match (packed_modules_list != 1)
    # -> MergedColumnParallelLinearVariableSliceWithLoRA should NOT match (< 3 slices)
    assert not ColumnParallelLinearWithLoRA.can_replace_layer(
        source_layer=layer_2_slices,
        lora_config=lora_config,
        packed_modules_list=[],
    ), "ColumnParallelLinearWithLoRA should NOT handle empty packed_modules_list"

    assert not MergedColumnParallelLinearVariableSliceWithLoRA.can_replace_layer(
        source_layer=layer_2_slices,
        lora_config=lora_config,
        packed_modules_list=[],
    ), (
        "MergedColumnParallelLinearVariableSliceWithLoRA "
        "should NOT handle 2 slices even with empty packed_modules_list"
    )


@pytest.mark.parametrize(
    "wrapper_cls",
    [ColumnParallelLinearWithLoRA, ColumnParallelLinearWithShardedLoRA],
)
def test_get_and_maybe_dequant_weights_accepts_lora_wrappers(dist_init, wrapper_cls):
    """get_and_maybe_dequant_weights must accept a LoRA-wrapped linear layer.

    Passing the wrapper instead of the base layer should still yield a weight
    with the base layer's ``[out, in]`` shape.
    """
    from vllm.model_executor.layers.quantization.utils.quant_utils import (
        get_and_maybe_dequant_weights,
    )

    base_layer = ColumnParallelLinear(
        4096, 4096, bias=False, params_dtype=torch.float16
    )
    wrapped_layer = wrapper_cls(base_layer)

    weight = get_and_maybe_dequant_weights(wrapped_layer, out_dtype=torch.float16)
    expected_shape = base_layer.weight.shape
    assert weight.shape == expected_shape


@torch.inference_mode()
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("fully_sharded", [False, True])
def test_deepseek_fused_qkv_a_proj_lora_preserves_base_forward(
    default_vllm_config, dist_init, device, stage, fully_sharded
):
    """LoRA on a fused_qkv_a_proj subclass must run on top of its own forward.

    The base layer is subclassed so that its forward adds a constant 1 to the
    output. The MergedColumnParallelLinearWithLoRA result must match two
    references: a per-slice computation on the subclassed base layer, and a
    copy of the subclassed layer whose weights have the LoRA deltas merged in.
    """
    cuda_like = current_platform.is_cuda_alike()
    if cuda_like:
        torch.accelerator.set_device_index(device)

    torch.set_default_device(device)
    dtype = torch.float16 if cuda_like else torch.float32
    max_loras = 8
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        lora_dtype=dtype,
        fully_sharded_loras=fully_sharded,
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
    assert check_punica_wrapper(punica_wrapper)

    class OffsetDeepSeekFusedQkvAProjLinear(DeepSeekV2FusedQkvAProjLinear):
        # Shift the base output by 1 so a bypassed forward becomes visible.
        def forward(self, input_):
            out, out_bias = super().forward(input_)
            return out + 1, out_bias

    def make_layer():
        """Fresh offset layer: 32 inputs, two output slices of 16."""
        return OffsetDeepSeekFusedQkvAProjLinear(
            32, [16, 16], prefix="model.layers.0.self_attn.fused_qkv_a_proj"
        )

    base_layer = make_layer()
    base_layer.weight.data = torch.rand_like(base_layer.weight.data, dtype=dtype)

    lora_layer = MergedColumnParallelLinearWithLoRA(base_layer)
    lora_layer.create_lora_weights(max_loras, lora_config)
    lora_layer.set_mapping(punica_wrapper)

    id_to_index = get_random_id_to_index(1, max_loras, log=False)
    active_slot = next(
        slot for slot, lora_id in enumerate(id_to_index) if lora_id == 1
    )
    # One (A, B) pair per output slice.
    lora_a = [torch.rand(8, 32, dtype=dtype, device=device) for _ in range(2)]
    lora_b = [torch.rand(16, 8, dtype=dtype, device=device) for _ in range(2)]
    lora_layer.set_lora(active_slot, lora_a=lora_a, lora_b=lora_b)

    inputs, index_mapping, prompt_mapping = create_random_inputs(
        active_lora_ids=[1],
        num_inputs=4,
        input_size=(1, 32),
        input_range=(0, 1),
        input_type=dtype,
        device=device,
    )
    punica_wrapper.update_metadata(
        LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage),
        id_to_index,
        max_loras,
        512,
    )

    lora_result = lora_layer(torch.cat(inputs))[0]

    # Output columns (and weight rows) owned by each slice.
    slice_spans = (slice(0, 16), slice(16, None))

    expected_results = []
    for input_ in inputs:
        result = base_layer(input_)[0]
        for span, a, b in zip(slice_spans, lora_a, lora_b):
            result[:, span] += input_ @ a.T @ b.T
        expected_results.append(result)

    rtol, atol = TOLERANCES[lora_result.dtype]
    torch.testing.assert_close(
        lora_result, torch.cat(expected_results), rtol=rtol, atol=atol
    )

    # Folding B @ A into the weights must give the same output.
    merged_layer = make_layer()
    merged_layer.weight.data = base_layer.weight.data.clone()
    for span, a, b in zip(slice_spans, lora_a, lora_b):
        merged_layer.weight.data[span].add_(b @ a)
    merged_result = merged_layer(torch.cat(inputs))[0]

    torch.testing.assert_close(lora_result, merged_result, rtol=rtol, atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_replicated_lora_preserves_base_forward_for_subclasses(
    default_vllm_config, dist_init, device, stage
):
    """Check that ReplicatedLinearWithLoRA keeps a subclass's custom forward.

    The wrapped base layer is a ReplicatedLinear subclass whose forward adds a
    constant 1 to its output. The LoRA layer's output must equal that offset
    base output plus the LoRA delta. This is checked against both a per-input
    reference computation and an identical subclass instance whose weight has
    the LoRA delta (B @ A) merged in.
    """
    on_cuda_alike = current_platform.is_cuda_alike()
    if on_cuda_alike:
        torch.accelerator.set_device_index(device)
    torch.set_default_device(device)
    dtype = torch.float16 if on_cuda_alike else torch.float32

    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=dtype)
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
    assert check_punica_wrapper(punica_wrapper)

    class OffsetReplicatedLinear(ReplicatedLinear):
        # Custom forward: shift every output element by +1 so the test can
        # detect whether the LoRA wrapper actually goes through this override.
        def forward(self, input_):
            out, out_bias = super().forward(input_)
            return out + 1, out_bias

    def make_offset_layer():
        # Same 32 -> 16, bias-free configuration for the base and merged layers.
        return OffsetReplicatedLinear(32, 16, bias=False, params_dtype=dtype)

    base_layer = make_offset_layer()
    base_layer.weight.data = torch.rand_like(base_layer.weight.data, dtype=dtype)

    lora_layer = ReplicatedLinearWithLoRA(base_layer)
    lora_layer.create_lora_weights(max_loras, lora_config)
    lora_layer.set_mapping(punica_wrapper)

    id_to_index = get_random_id_to_index(1, max_loras, log=False)
    # Slot that LoRA id 1 was randomly assigned to.
    slot = next(idx for idx, lora_id in enumerate(id_to_index) if lora_id == 1)
    lora_a = torch.rand(8, 32, dtype=dtype, device=device)
    lora_b = torch.rand(16, 8, dtype=dtype, device=device)
    lora_layer.set_lora(slot, lora_a=lora_a, lora_b=lora_b)

    inputs, index_mapping, prompt_mapping = create_random_inputs(
        active_lora_ids=[1],
        num_inputs=4,
        input_size=(1, 32),
        input_range=(0, 1),
        input_type=dtype,
        device=device,
    )
    punica_wrapper.update_metadata(
        LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage),
        id_to_index,
        max_loras,
        512,
    )

    lora_result = lora_layer(torch.cat(inputs))[0]

    # Reference: the subclass forward (offset included) plus the unmerged LoRA
    # delta x @ A^T @ B^T, evaluated one input at a time.
    expected = torch.cat(
        [base_layer(x)[0] + x @ lora_a.T @ lora_b.T for x in inputs]
    )

    rtol, atol = TOLERANCES[lora_result.dtype]
    torch.testing.assert_close(lora_result, expected, rtol=rtol, atol=atol)

    # Cross-check: the same subclass with B @ A folded directly into its weight
    # must produce the same output as the LoRA-wrapped layer.
    merged_layer = make_offset_layer()
    merged_layer.weight.data = base_layer.weight.data + lora_b @ lora_a
    torch.testing.assert_close(
        lora_result, merged_layer(torch.cat(inputs))[0], rtol=rtol, atol=atol
    )