# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import random
from copy import deepcopy
from dataclasses import dataclass
from unittest.mock import patch

import pytest
import torch
import torch.nn.functional as F

from vllm.config.lora import LoRAConfig
from vllm.lora.layers import (
    BaseLayerWithLoRA,
    ColumnParallelLinearWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    LogitsProcessorWithLoRA,
    LoRAMapping,
    MergedColumnParallelLinearVariableSliceWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    QKVParallelLinearWithLoRA,
    QKVParallelLinearWithShardedLoRA,
    ReplicatedLinearWithLoRA,
    RowParallelLinearWithLoRA,
    RowParallelLinearWithShardedLoRA,
    VocabParallelEmbeddingWithLoRA,
)
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
    get_masked_input_and_mask,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed

from .utils import DummyLoRAManager

# Per-dtype (rtol, atol) pairs; the tests unpack them and pass them to
# torch.testing.assert_close.
TOLERANCES = {
    torch.float16: (5e-3, 5e-3),
    torch.float32: (5e-3, 5e-3),
    torch.bfloat16: (3e-2, 2e-2),
}

# Skip the whole module unless running on a CUDA-like or a CPU platform.
pytestmark = pytest.mark.skipif(
    not (current_platform.is_cuda_alike() or current_platform.is_cpu()),
    reason="Backend not supported",
)

# On CUDA-like platforms: "cuda:0" when exactly one device is reported,
# otherwise "cuda:0" and "cuda:1". On any other platform: "cpu".
DEVICES = (
    [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
    if current_platform.is_cuda_alike()
    else ["cpu"]
)

# prefill stage(True) or decode stage(False)
STAGES = [True, False]

NUM_RANDOM_SEEDS = 2

VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 2
75
76
77


@pytest.fixture(autouse=True)
def clean_cache_reset_device(reset_default_device):
    """Empty the Triton LoRA pointer caches before every test.

    Runs automatically for each test, depends on the
    ``reset_default_device`` fixture, and clears ``_LORA_A_PTR_DICT`` and
    ``_LORA_B_PTR_DICT`` before handing control to the test.
    """
    # Release any memory we might be holding on to. CI runs OOMs otherwise.
    from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT

    _LORA_B_PTR_DICT.clear()
    _LORA_A_PTR_DICT.clear()

    yield

87

88
89
90
@pytest.fixture(autouse=True)
def skip_cuda_with_stage_false(request):
    """Skip ``stage=False`` parametrizations on CUDA-like platforms.

    On cuda-like platforms, we use the same kernels for prefill and decode
    stage, and 'stage' is generally ignored, so we only need to test once.

    Tests that are not parametrized (no ``callspec``) or that have no
    ``stage`` parameter are left untouched.
    """
    if current_platform.is_cuda_alike():
        # Non-parametrized items have no ``callspec``; fall back to None
        # instead of relying on a blanket ``except Exception``.
        callspec = getattr(request.node, "callspec", None)
        params = getattr(callspec, "params", None)
        if isinstance(params, dict) and params.get("stage") is False:
            pytest.skip("Skip test when stage=False")
    yield


107
108
def get_random_id_to_index(
    num_loras: int, num_slots: int, log: bool = True
109
) -> list[int | None]:
110
111
112
113
114
115
116
117
118
119
120
121
    """Creates a random lora_id_to_index mapping.

    Args:
        num_loras: The number of active loras in the mapping.
        num_slots: The number of slots in the mapping. Must be larger
            than num_loras.
        log: Whether to log the output.
    """

    if num_loras > num_slots:
        raise ValueError(
            f"num_loras is higher than num_slots: {num_loras} > {num_slots}. "
122
123
            "num_loras must be less than or equal to num_slots."
        )
124

125
    slots: list[int | None] = [None] * num_slots
126
127
128
129
130
131
132
133
134
135
136
    random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist()
    for lora_id, slot_idx in enumerate(random_slot_selections, start=1):
        slots[slot_idx] = lora_id

    if log:
        print(f"Created lora_id_to_index mapping: {slots}.")

    return slots


def populate_loras(
    id_to_index: list[int | None],
    layer: BaseLayerWithLoRA,
    layer_weights: torch.Tensor,
    repeats: int = 1,
) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
    """This method populates the lora layers with lora weights.

    Args:
        id_to_index: a list of lora ids. The index of the lora id
            represents which memory slot the lora matrices are
            stored in. A None value indicates a free slot.
        layer: the LoRAlayer to populate.
        layer_weights: the PyTorch tensor containing the layer's
            weights.
        repeats: must only be set for column parallel packed
            layers. Indicates the number of loras to compose
            together to create a single lora layer. Must be >= 1.

    Returns:
        A pair ``(lora_dict, sublora_dict)``. ``lora_dict`` maps each lora
        ID to the weights that were set on the layer; ``sublora_dict`` maps
        each lora ID to the list of sub-loras those weights were built from.

    Raises:
        ValueError: If ``repeats`` is smaller than 1.
    """
    if repeats < 1:
        # Otherwise the slicing below divides by zero or leaves no sub-lora.
        raise ValueError(f"repeats must be >= 1, got {repeats}.")

    # Dictionary that maps the lora ID to the
    # corresponding lora weights.
    lora_dict: dict[int, LoRALayerWeights] = {}

    # Dictionary that maps the lora ID to the
    # corresponding subloras.
    sublora_dict: dict[int, list[LoRALayerWeights]] = {}

    for slot_idx, lora_id in enumerate(id_to_index):
        if lora_id is None:
            # Free slot: nothing to populate.
            continue

        subloras: list[LoRALayerWeights] = []
        sublora_len = layer_weights.shape[0] // repeats
        for i in range(repeats):
            sublora = DummyLoRAManager(layer_weights.device).init_random_lora(
                module_name=f"fake_{i}",
                weight=layer_weights,
            )
            # Keep only the i-th block of lora_b rows for this sub-lora.
            sublora.lora_b = sublora.lora_b[
                (sublora_len * i) : (sublora_len * (i + 1)), :
            ]
            sublora.optimize()
            subloras.append(sublora)

        lora = PackedLoRALayerWeights.pack(subloras) if repeats > 1 else subloras[0]

        layer.set_lora(
            slot_idx,
            lora_a=lora.lora_a,
            lora_b=lora.lora_b,
        )

        lora_dict[lora_id] = lora
        sublora_dict[lora_id] = subloras

    return lora_dict, sublora_dict


def create_random_inputs(
194
    active_lora_ids: list[int],
195
    num_inputs: int,
196
197
    input_size: tuple[int, ...],
    input_range: tuple[float, float],
198
    input_type: torch.dtype = torch.int,
199
    device: torch.device = "cuda",
200
) -> tuple[list[torch.Tensor], list[int], list[int]]:
201
202
203
204
205
206
207
208
209
210
211
212
213
    """Creates random inputs.

    Args:
        active_lora_ids: lora IDs of active lora weights.
        num_inputs: the number of inputs to create.
        input_size: the size of each individual input.
        input_range: the range of values to include in the input.
            input_range[0] <= possible input values < input_range[1]
        input_type: the type of values in the input.
    """

    low, high = input_range

214
215
216
    inputs: list[torch.Tensor] = []
    index_mapping: list[int] = []
    prompt_mapping: list[int] = []
217

218
219
220
    for _ in range(num_inputs):
        if input_type == torch.int:
            inputs.append(
221
222
223
224
                torch.randint(
                    low=int(low), high=int(high), size=input_size, device=device
                )
            )
225
226
        else:
            inputs.append(
227
228
229
                torch.rand(size=input_size, dtype=input_type, device=device) * high
                + low
            )
230
231
232
233
234
235
236
237

        lora_id = random.choice(active_lora_ids)
        index_mapping += [lora_id] * input_size[0]
        prompt_mapping += [lora_id]

    return inputs, index_mapping, prompt_mapping


238
239
240
241
242
def check_punica_wrapper(punica_wrapper) -> bool:
    """Return whether ``punica_wrapper`` is the wrapper class for this platform.

    The comparison is on the exact type (``type(...) is ...``), so subclasses
    do not match. Platforms that are neither CUDA-like nor CPU yield False.
    """
    if current_platform.is_cuda_alike():
        from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU

        return type(punica_wrapper) is PunicaWrapperGPU

    if current_platform.is_cpu():
        from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU

        return type(punica_wrapper) is PunicaWrapperCPU

    return False


251
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings(
    default_vllm_config, dist_init, num_loras, device, vocab_size, stage
) -> None:
    """Compare VocabParallelEmbeddingWithLoRA against a hand-built reference.

    For each seed, the wrapped layer's output must match the base embedding
    plus ``F.embedding(input, lora_a.T) @ lora_b.T``. After ``reset_lora``
    has been called on every slot, it must match the base embedding alone.
    """
    # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
    # device, see: https://github.com/triton-lang/triton/issues/2925
    # Same below.
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
    assert check_punica_wrapper(punica_wrapper)

    def create_random_embedding_layer():
        """Build a randomly initialised embedding and its LoRA wrapper."""
        embedding = VocabParallelEmbedding(vocab_size, 256)
        embedding.weight.data = torch.rand_like(embedding.weight.data)
        # Zero any rows beyond the requested vocabulary size.
        embedding.weight.data[vocab_size:, :] = 0
        lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
        lora_embedding.create_lora_weights(max_loras, lora_config)

        return embedding, lora_embedding

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        embedding, lora_embedding = create_random_embedding_layer()
        lora_embedding.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=embedding.weight.T,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=num_loras * 3,
            input_size=(200,),
            input_range=(1, vocab_size),
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
        )

        lora_result = lora_embedding(torch.cat(inputs))

        # Reference: base embedding plus the LoRA delta, one input at a time.
        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = embedding(input_)
            after_a = F.embedding(
                input_,
                lora.lora_a.T,
            )
            result += after_a @ lora.lora_b.T
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=num_loras * 3,
            input_size=(200,),
            input_range=(1, vocab_size),
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
        )

        lora_result = lora_embedding(torch.cat(inputs))
        expected_result = embedding(torch.cat(inputs))

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
351
352
353


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [64000, 256512, 258048])
@pytest.mark.parametrize("stage", STAGES)
def test_lm_head_logits_processor(
    default_vllm_config, dist_init, num_loras, device, vocab_size, stage
) -> None:
    """Exercise LogitsProcessorWithLoRA with and without active LoRAs.

    For each seed, logits are computed with LoRA weights populated and a
    reference is built from the base logits processor plus the LoRA delta.
    After ``reset_lora`` on every slot, the LoRA processor's logits
    (truncated to ``vocab_size`` columns) must match the base processor's.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
    assert check_punica_wrapper(punica_wrapper)

    def _pretest():
        """Build a random LM head, a base logits processor and its LoRA wrapper."""
        linear = ParallelLMHead(
            num_embeddings=vocab_size,
            embedding_dim=1024,
            params_dtype=torch.float16,
        )
        linear.weight.data = torch.rand_like(linear.weight.data)
        # Zero any rows beyond the requested vocabulary size. The previous
        # slice ``[:, vocab_size:]`` indexed the 1024-wide embedding axis and
        # was therefore always empty; test_embeddings slices rows the same way.
        linear.weight.data[vocab_size:, :] = 0
        logits_processor = LogitsProcessor(vocab_size)
        lora_logits_processor = LogitsProcessorWithLoRA(
            logits_processor, 1024, linear.weight.dtype, linear.weight.device, None
        )
        lora_logits_processor.create_lora_weights(max_loras, lora_config)

        return linear, logits_processor, lora_logits_processor

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, logits_processor, lora_logits_processor = _pretest()
        lora_logits_processor.set_mapping(punica_wrapper)

        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_logits_processor,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=8 * num_loras,  # * 3,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
        )

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs), lm_head=linear, embedding_bias=None
        )

        # Snapshot of the LM head used for the post-reset comparison below.
        original_lm_head = deepcopy(linear)

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = logits_processor._get_logits(
                hidden_states=input_, lm_head=linear, embedding_bias=None
            )

            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        # NOTE(review): ``lora_result`` and ``expected_result`` computed above
        # are never compared; only the post-reset results are asserted.
        # TODO confirm whether an assert_close belongs here (the output
        # shapes of the two processors may need aligning first).

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_logits_processor.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=8 * num_loras * 3,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
        )

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None,
        )[:, :vocab_size]
        expected_result = logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None,
        )

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
469
470


471
@torch.inference_mode()
@pytest.mark.parametrize("vocab_size", [258049, 300000])
@pytest.mark.parametrize("device", DEVICES)
def test_lm_head_logits_processor_invalid_vocab_size(
    default_vllm_config, dist_init, vocab_size, device
) -> None:
    """Test that LogitsProcessorWithLoRA raises ValueError for invalid vocab sizes.

    ``create_lora_weights`` is expected to raise with a message matching
    "vocab size must be <= 258048" for both parametrized sizes.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )

    logits_processor = LogitsProcessor(vocab_size)
    lora_logits_processor = LogitsProcessorWithLoRA(
        logits_processor, 1024, torch.float16, device, None
    )

    with pytest.raises(ValueError, match="vocab size must be <= 258048"):
        lora_logits_processor.create_lora_weights(max_loras, lora_config)


496
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_replicated(
    default_vllm_config,
    dist_init,
    num_loras,
    device,
    stage,
) -> None:
    """Compare ReplicatedLinearWithLoRA against a hand-built reference.

    For each seed, the wrapped layer's output must match the base linear
    output plus ``input @ lora_a.T @ lora_b.T * scaling``. After
    ``reset_lora`` on every slot, it must match the base layer alone.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        lora_dtype=torch.float16,
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
    assert check_punica_wrapper(punica_wrapper)

    def create_random_linear_replicated_layer():
        """Build a random 4096x4096 replicated linear layer and its LoRA wrapper."""
        linear = ReplicatedLinear(4096, 4096, bias=False, params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        lora_linear = ReplicatedLinearWithLoRA(linear)

        lora_linear.create_lora_weights(max_loras, lora_config)
        # A replicated layer is expected to have exactly one slice.
        assert (
            lora_linear.n_slices
            == len(lora_linear.lora_a_stacked)
            == len(lora_linear.lora_b_stacked)
            == 1
        )
        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_replicated_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        # Reference: base output plus the LoRA delta, one input at a time.
        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
603
604


605
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_parallel(
    default_vllm_config, dist_init, num_loras, orientation, fully_shard, device, stage
) -> None:
    """Compare row/column parallel LoRA linear layers against a reference.

    ``orientation`` selects RowParallelLinear or ColumnParallelLinear and
    ``fully_shard`` selects the sharded or non-sharded LoRA wrapper. For each
    seed, the wrapped layer's output must match the base output plus
    ``input @ lora_a.T @ lora_b.T * scaling``. After ``reset_lora`` on every
    slot, it must match the base layer alone.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        fully_sharded_loras=fully_shard,
        lora_dtype=torch.float16,
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
    assert check_punica_wrapper(punica_wrapper)

    def create_random_linear_parallel_layer():
        """Build a random 4096x4096 parallel linear layer and its LoRA wrapper."""
        if orientation == "row":
            linear = RowParallelLinear(
                4096, 4096, bias=False, params_dtype=torch.float16
            )
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (
                RowParallelLinearWithLoRA(linear)
                if not fully_shard
                else RowParallelLinearWithShardedLoRA(linear)
            )
        else:
            linear = ColumnParallelLinear(
                4096, 4096, bias=False, params_dtype=torch.float16
            )
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (
                ColumnParallelLinearWithLoRA(linear)
                if not fully_shard
                else ColumnParallelLinearWithShardedLoRA(linear)
            )
        lora_linear.create_lora_weights(max_loras, lora_config)
        # Both orientations are expected to have exactly one slice.
        assert (
            lora_linear.n_slices
            == len(lora_linear.lora_a_stacked)
            == len(lora_linear.lora_b_stacked)
            == 1
        )

        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_parallel_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        # Reference: base output plus the LoRA delta, one input at a time.
        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
728
729
730


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_column_parallel_packed(
    default_vllm_config, dist_init, num_loras, repeats, fully_shard, device, stage
) -> None:
    """Compare packed column-parallel LoRA layers against a reference.

    ``repeats == 2`` builds a MergedColumnParallelLinear, ``repeats == 3`` a
    QKVParallelLinear with the merged-QKV LoRA wrapper, and any other value a
    QKVParallelLinear with the plain QKV LoRA wrapper; ``fully_shard`` picks
    the sharded variant of each wrapper. For each seed, the wrapped layer's
    output must match the base output with every sub-lora's delta added to
    its own column range. After ``reset_lora`` on every slot, it must match
    the base layer alone.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        fully_sharded_loras=fully_shard,
        lora_dtype=torch.float16,
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
    assert check_punica_wrapper(punica_wrapper)

    def create_column_parallel_packed_layer():
        """Build the base layer and LoRA wrapper selected by ``repeats``."""
        if repeats == 2:
            linear = MergedColumnParallelLinear(
                4096, [4096] * repeats, bias=False, params_dtype=torch.float16
            )
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (
                MergedColumnParallelLinearWithLoRA(linear)
                if not fully_shard
                else MergedColumnParallelLinearWithShardedLoRA(linear)
            )
        elif repeats == 3:
            linear = QKVParallelLinear(
                4096, 64, 32, bias=False, params_dtype=torch.float16
            )
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (
                MergedQKVParallelLinearWithLoRA(linear)
                if not fully_shard
                else MergedQKVParallelLinearWithShardedLoRA(linear)
            )
        else:
            linear = QKVParallelLinear(
                4096, 64, 32, bias=False, params_dtype=torch.float16
            )
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (
                QKVParallelLinearWithLoRA(linear)
                if not fully_shard
                else QKVParallelLinearWithShardedLoRA(linear)
            )

        # Minimal stand-in passed as ``model_config`` to create_lora_weights.
        @dataclass
        class FakeConfig:
            hidden_size = 4096
            num_key_value_heads = 32
            num_attention_heads = 32

        n_slices = repeats
        lora_linear.create_lora_weights(
            max_loras, lora_config, model_config=FakeConfig()
        )
        assert (
            lora_linear.n_slices
            == len(lora_linear.lora_a_stacked)
            == len(lora_linear.lora_b_stacked)
            == n_slices
        )

        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)

        linear, lora_linear = create_column_parallel_packed_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, sublora_dict = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
            repeats=repeats,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        # Reference: add each sub-lora's delta to its own block of columns.
        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            result = linear(input_)[0]
            subloras = sublora_dict[lora_id]
            # ``slice_idx`` rather than ``i`` so the seed-loop variable is
            # not shadowed.
            for slice_idx, sublora in enumerate(subloras):
                width = sublora.lora_b.shape[0]
                result[:, width * slice_idx : width * (slice_idx + 1)] += (
                    input_ @ sublora.lora_a.T @ sublora.lora_b.T * sublora.scaling
                )
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
877
878


879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("num_slices", [3, 5])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_merged_column_parallel_variable_slice(
    default_vllm_config, dist_init, num_loras, num_slices, device, stage
) -> None:
    """Compare MergedColumnParallelLinearVariableSliceWithLoRA with a reference.

    A MergedColumnParallelLinear with ``num_slices`` output slices of
    different widths is wrapped with the variable-slice LoRA layer. For every
    random seed the wrapped layer's output must match the base layer output
    plus a per-slice ``input @ lora_a.T @ lora_b_slice.T`` term, and after all
    slots are reset it must match the base layer output alone.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )
    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)

    # Each slice is 256 columns wider than the previous one, starting at 1024.
    output_sizes = [1024 + 256 * slice_idx for slice_idx in range(num_slices)]
    total_output = sum(output_sizes)

    def build_layers():
        # Base layer with random weights, plus its LoRA-wrapped counterpart.
        base = MergedColumnParallelLinear(
            4096, output_sizes, bias=False, params_dtype=torch.float16
        )
        base.weight.data = torch.rand_like(base.weight.data)

        wrapped = MergedColumnParallelLinearVariableSliceWithLoRA(base)
        wrapped.create_lora_weights(max_loras, lora_config)
        return base, wrapped

    def activate(active_ids, id_to_index):
        # Draw a fresh batch for the given adapter ids and publish the
        # matching mapping to the punica wrapper.
        batch, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=active_ids,
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(mapping, id_to_index, max_loras, 512)
        return batch, prompt_mapping

    for seed in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed)
        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = build_layers()
        lora_linear.set_mapping(punica_wrapper)

        # Fill every occupied slot with random adapter weights.
        weights_by_id = {}
        b_slices_by_id = {}
        for slot_idx, lora_id in enumerate(id_to_index):
            if lora_id is None:
                continue
            lora_a = torch.rand(8, 4096, dtype=torch.float16, device=device)
            lora_b = torch.rand(total_output, 8, dtype=torch.float16, device=device)
            lora_linear.set_lora(slot_idx, lora_a, lora_b)
            weights_by_id[lora_id] = (lora_a, lora_b)

            # Keep lora_b split per output slice for the reference computation.
            b_slices_by_id[lora_id] = torch.split(lora_b, output_sizes, dim=0)

        inputs, prompt_mapping = activate(list(weights_by_id), id_to_index)

        # Output of the layer under test.
        lora_result = lora_linear(torch.cat(inputs))[0]

        # Reference: base output plus the LoRA delta, one column range per slice.
        expected_rows = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            row = linear(input_)[0]
            lora_a = weights_by_id[lora_id][0]
            col_start = 0
            for b_slice in b_slices_by_id[lora_id]:
                col_stop = col_start + b_slice.shape[0]
                row[:, col_start:col_stop] += input_ @ lora_a.T @ b_slice.T
                col_start = col_stop
            expected_rows.append(row)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(
            lora_result, torch.cat(expected_rows), rtol=rtol, atol=atol
        )

        # Clear every slot; the wrapped layer must then match the base layer.
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, _ = activate([0], id_to_index)

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)


989
@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize(
    "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS))
)
def test_vocab_parallel_embedding_indices(tp_size, seed, default_vllm_config):
    """Check the shard index ranges VocabParallelEmbedding reports per rank.

    A random vocab size and added-vocab size are drawn from ``seed``. One
    VocabParallelEmbedding is built per tensor-parallel rank (rank and world
    size are patched), and the test asserts that the original and added
    ranges are contiguous across ranks, do not overlap, and sum to the
    expected totals. Finally the sharded-to-full mapping must reorder the
    concatenated per-rank token ids (with -1 for padding) into
    ``0..vocab_size-1`` followed only by padding.
    """
    random.seed(seed)
    vocab_size = random.randint(4000, 64000)
    added_vocab_size = random.randint(0, 1024)
    org_vocab_size = vocab_size - added_vocab_size
    # Original ranges start at 0; added ranges start right after them.
    last_org_vocab_end_index = 0
    last_added_vocab_end_index = org_vocab_size
    computed_vocab_size = 0
    computed_org_vocab_size = 0
    computed_added_vocab_size = 0
    vocab_size_padded = -1

    all_org_tokens: list[int] = []
    all_added_tokens: list[int] = []
    token_ids: list[int] = []

    for tp_rank in range(tp_size):
        # Pretend to be rank `tp_rank` of `tp_size` while the layer is built.
        with (
            patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank",
                return_value=tp_rank,
            ),
            patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size",
                return_value=tp_size,
            ),
        ):
            vocab_embedding = VocabParallelEmbedding(
                vocab_size, 1, org_num_embeddings=org_vocab_size
            )
        vocab_size_padded = vocab_embedding.num_embeddings_padded
        shard_indices = vocab_embedding.shard_indices
        # Assert that the ranges are contiguous
        assert shard_indices.org_vocab_start_index == last_org_vocab_end_index
        assert shard_indices.added_vocab_start_index == last_added_vocab_end_index

        # Ensure that we are not exceeding the vocab size
        computed_vocab_size += shard_indices.num_elements_padded
        computed_org_vocab_size += shard_indices.num_org_elements
        computed_added_vocab_size += shard_indices.num_added_elements

        # Ensure that the ranges are not overlapping
        all_org_tokens.extend(
            range(
                shard_indices.org_vocab_start_index, shard_indices.org_vocab_end_index
            )
        )
        all_added_tokens.extend(
            range(
                shard_indices.added_vocab_start_index,
                shard_indices.added_vocab_end_index,
            )
        )

        # Per-rank layout: original ids, original padding (-1), added ids,
        # added padding (-1).
        token_ids.extend(
            range(
                shard_indices.org_vocab_start_index, shard_indices.org_vocab_end_index
            )
        )
        token_ids.extend(
            [-1]
            * (shard_indices.num_org_elements_padded - shard_indices.num_org_elements)
        )
        token_ids.extend(
            range(
                shard_indices.added_vocab_start_index,
                shard_indices.added_vocab_end_index,
            )
        )
        token_ids.extend(
            [-1]
            * (
                shard_indices.num_added_elements_padded
                - shard_indices.num_added_elements
            )
        )

        last_org_vocab_end_index = shard_indices.org_vocab_end_index
        last_added_vocab_end_index = shard_indices.added_vocab_end_index

    assert computed_vocab_size == vocab_size_padded
    assert computed_org_vocab_size == org_vocab_size
    assert computed_added_vocab_size == added_vocab_size

    # Ensure that the ranges are not overlapping
    assert len(all_org_tokens) == len(set(all_org_tokens))
    assert len(all_added_tokens) == len(set(all_added_tokens))
    assert not set(all_org_tokens).intersection(set(all_added_tokens))

    token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
    # Uses the embedding built for the last rank in the loop above.
    reindex_mapping = vocab_embedding.get_sharded_to_full_mapping()
    assert reindex_mapping is not None or tp_size == 1
    if reindex_mapping is not None:
        reindexed_token_ids = token_ids_tensor[reindex_mapping]
        expected = torch.tensor(list(range(0, vocab_size)))
        assert reindexed_token_ids[:vocab_size].equal(expected)
        assert torch.all(reindexed_token_ids[vocab_size:] == -1)


def test_get_masked_input_and_mask():
    """Check get_masked_input_and_mask on a 12-token input for tp 1, 2 and 4.

    The input holds ids 0..11. In every case ids 0..7 are covered by the
    original-vocab ranges and ids 8..11 by the added-vocab ranges, split
    evenly across the simulated ranks. Each sharding is checked with
    ``num_org_vocab_padding`` of 0 and of 2. Only the first return value is
    checked; the second is discarded.
    """
    x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

    # base tp 1 case, no padding
    modified_x, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=0,
        org_vocab_end_index=8,
        added_vocab_start_index=8,
        added_vocab_end_index=12,
        num_org_vocab_padding=0,
    )
    assert torch.equal(x, modified_x)

    # tp 2 case, no padding
    modified_x_rank_0, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=0,
        org_vocab_end_index=4,
        added_vocab_start_index=8,
        added_vocab_end_index=10,
        num_org_vocab_padding=0,
    )
    modified_x_rank_1, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=4,
        org_vocab_end_index=8,
        added_vocab_start_index=10,
        added_vocab_end_index=12,
        num_org_vocab_padding=0,
    )
    # Ids outside a rank's ranges are expected to come back as 0.
    assert torch.equal(
        modified_x_rank_0, torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0])
    )
    assert torch.equal(
        modified_x_rank_1, torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5])
    )

    # tp 4 case, no padding
    modified_x_rank_0, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=0,
        org_vocab_end_index=2,
        added_vocab_start_index=8,
        added_vocab_end_index=9,
        num_org_vocab_padding=0,
    )
    modified_x_rank_1, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=2,
        org_vocab_end_index=4,
        added_vocab_start_index=9,
        added_vocab_end_index=10,
        num_org_vocab_padding=0,
    )
    modified_x_rank_2, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=4,
        org_vocab_end_index=6,
        added_vocab_start_index=10,
        added_vocab_end_index=11,
        num_org_vocab_padding=0,
    )
    modified_x_rank_3, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=6,
        org_vocab_end_index=8,
        added_vocab_start_index=11,
        added_vocab_end_index=12,
        num_org_vocab_padding=0,
    )
    assert torch.equal(
        modified_x_rank_0, torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0])
    )
    assert torch.equal(
        modified_x_rank_1, torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0])
    )
    assert torch.equal(
        modified_x_rank_2, torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0])
    )
    assert torch.equal(
        modified_x_rank_3, torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2])
    )

    # base tp 1 case, with padding
    modified_x, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=0,
        org_vocab_end_index=8,
        added_vocab_start_index=8,
        added_vocab_end_index=12,
        num_org_vocab_padding=2,
    )
    # With padding of 2 the added ids are expected two positions later.
    assert torch.equal(
        modified_x, torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13])
    )

    # tp 2 case, with padding
    modified_x_rank_0, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=0,
        org_vocab_end_index=4,
        added_vocab_start_index=8,
        added_vocab_end_index=10,
        num_org_vocab_padding=2,
    )
    modified_x_rank_1, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=4,
        org_vocab_end_index=8,
        added_vocab_start_index=10,
        added_vocab_end_index=12,
        num_org_vocab_padding=2,
    )
    assert torch.equal(
        modified_x_rank_0, torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0])
    )
    assert torch.equal(
        modified_x_rank_1, torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7])
    )

    # tp 4 case, with padding
    modified_x_rank_0, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=0,
        org_vocab_end_index=2,
        added_vocab_start_index=8,
        added_vocab_end_index=9,
        num_org_vocab_padding=2,
    )
    modified_x_rank_1, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=2,
        org_vocab_end_index=4,
        added_vocab_start_index=9,
        added_vocab_end_index=10,
        num_org_vocab_padding=2,
    )
    modified_x_rank_2, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=4,
        org_vocab_end_index=6,
        added_vocab_start_index=10,
        added_vocab_end_index=11,
        num_org_vocab_padding=2,
    )
    modified_x_rank_3, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=6,
        org_vocab_end_index=8,
        added_vocab_start_index=11,
        added_vocab_end_index=12,
        num_org_vocab_padding=2,
    )
    assert torch.equal(
        modified_x_rank_0, torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0])
    )
    assert torch.equal(
        modified_x_rank_1, torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0])
    )
    assert torch.equal(
        modified_x_rank_2, torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0])
    )
    assert torch.equal(
        modified_x_rank_3, torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4])
    )
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443


def test_variable_slice_lora_class_selection(default_vllm_config, dist_init):
    """Verify which LoRA wrapper class is chosen for column-parallel layers.

    MergedColumnParallelLinearVariableSliceWithLoRA is meant for
    nemotron-h style layers (a single checkpoint weight but 3+ output
    slices). For such layers it has to win over ColumnParallelLinearWithLoRA,
    whose slice_lora_b assumes exactly 2 slices. The test checks both the
    individual ``can_replace_layer`` answers and the class that
    ``from_layer`` finally returns, for six layer / packed-module
    combinations.
    """
    from vllm.lora.utils import from_layer

    lora_config = LoRAConfig(max_loras=8, max_lora_rank=8, lora_dtype=torch.float16)

    def accepts(lora_cls, layer, packed_modules):
        # Ask one wrapper class whether it would take over `layer`.
        return lora_cls.can_replace_layer(
            source_layer=layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules,
        )

    def wrap(layer, packed_modules):
        # Let from_layer pick the wrapper class for `layer`.
        return from_layer(
            layer,
            max_loras=8,
            lora_config=lora_config,
            packed_modules_list=packed_modules,
        )

    variable_cls = MergedColumnParallelLinearVariableSliceWithLoRA

    # ------------------------------------------------------------------
    # Case 1: 3+ output sizes with a single packed module (nemotron-h
    # style) -> the variable-slice class must be selected.
    # ------------------------------------------------------------------
    layer_3_slices = MergedColumnParallelLinear(
        4096, [1024, 1280, 1536], bias=False, params_dtype=torch.float16
    )
    packed_modules_single = ["mlp"]

    assert accepts(variable_cls, layer_3_slices, packed_modules_single), (
        "MergedColumnParallelLinearVariableSliceWithLoRA should handle 3+ slices"
    )

    # The plain column-parallel wrapper must step aside for 3+ slices
    # because its slice_lora_b assumes exactly 2 slices.
    assert not accepts(
        ColumnParallelLinearWithLoRA, layer_3_slices, packed_modules_single
    ), (
        "ColumnParallelLinearWithLoRA should NOT handle 3+ slices "
        "(slice_lora_b assumes 2 slices)"
    )

    # from_layer must end up with the variable-slice class, not the base one.
    selected_layer = wrap(layer_3_slices, packed_modules_single)
    assert isinstance(selected_layer, variable_cls), (
        f"from_layer should select MergedColumnParallelLinearVariableSliceWithLoRA "
        f"for 3+ slices, got {type(selected_layer).__name__}"
    )

    # ------------------------------------------------------------------
    # Case 2: 2 output sizes with a single packed module (standard gate_up
    # style) -> ColumnParallelLinearWithLoRA, never the variable-slice class.
    # ------------------------------------------------------------------
    layer_2_slices = MergedColumnParallelLinear(
        4096, [2048, 2048], bias=False, params_dtype=torch.float16
    )

    assert accepts(
        ColumnParallelLinearWithLoRA, layer_2_slices, packed_modules_single
    ), "ColumnParallelLinearWithLoRA should handle 2 slices"

    assert not accepts(variable_cls, layer_2_slices, packed_modules_single), (
        "MergedColumnParallelLinearVariableSliceWithLoRA should NOT handle 2 slices"
    )

    selected_layer_2 = wrap(layer_2_slices, packed_modules_single)
    assert isinstance(selected_layer_2, ColumnParallelLinearWithLoRA), (
        f"from_layer should select ColumnParallelLinearWithLoRA "
        f"for 2 slices, got {type(selected_layer_2).__name__}"
    )
    # ... and specifically not the variable-slice subclass.
    assert not isinstance(selected_layer_2, variable_cls), (
        "from_layer should NOT select "
        "MergedColumnParallelLinearVariableSliceWithLoRA for 2 slices"
    )

    # ------------------------------------------------------------------
    # Case 3: 3+ entries in packed_modules_list -> variable-slice class.
    # ------------------------------------------------------------------
    packed_modules_three = ["gate_proj", "up_proj", "down_proj"]

    assert accepts(variable_cls, layer_3_slices, packed_modules_three), (
        "MergedColumnParallelLinearVariableSliceWithLoRA should handle 3+ packed modules"
    )

    # ------------------------------------------------------------------
    # Case 4: 2 entries in packed_modules_list -> handled by
    # MergedColumnParallelLinearWithLoRA, not the variable-slice class.
    # ------------------------------------------------------------------
    packed_modules_two = ["gate_proj", "up_proj"]

    assert not accepts(variable_cls, layer_2_slices, packed_modules_two), (
        "MergedColumnParallelLinearVariableSliceWithLoRA"
        " should NOT handle 2 packed modules"
    )

    assert accepts(
        MergedColumnParallelLinearWithLoRA, layer_2_slices, packed_modules_two
    ), "MergedColumnParallelLinearWithLoRA should handle 2 packed modules"

    selected_layer_merged = wrap(layer_2_slices, packed_modules_two)
    assert isinstance(selected_layer_merged, MergedColumnParallelLinearWithLoRA), (
        f"from_layer should select MergedColumnParallelLinearWithLoRA "
        f"for 2 packed modules, got {type(selected_layer_merged).__name__}"
    )

    # ------------------------------------------------------------------
    # Case 5: plain (non-merged) ColumnParallelLinear, common in many
    # models -> ColumnParallelLinearWithLoRA.
    # ------------------------------------------------------------------
    plain_column_parallel = ColumnParallelLinear(
        4096, 4096, bias=False, params_dtype=torch.float16
    )

    assert accepts(
        ColumnParallelLinearWithLoRA, plain_column_parallel, packed_modules_single
    ), "ColumnParallelLinearWithLoRA should handle plain ColumnParallelLinear"

    assert not accepts(variable_cls, plain_column_parallel, packed_modules_single), (
        "MergedColumnParallelLinearVariableSliceWithLoRA "
        "should NOT handle plain ColumnParallelLinear"
    )

    selected_plain = wrap(plain_column_parallel, packed_modules_single)
    assert isinstance(selected_plain, ColumnParallelLinearWithLoRA), (
        f"from_layer should select ColumnParallelLinearWithLoRA "
        f"for plain ColumnParallelLinear, got {type(selected_plain).__name__}"
    )

    # ------------------------------------------------------------------
    # Case 6: exactly 2 output sizes and an empty packed_modules_list.
    # ColumnParallelLinearWithLoRA must not match (list length != 1) and
    # the variable-slice class must not match either (< 3 slices).
    # ------------------------------------------------------------------
    assert not accepts(ColumnParallelLinearWithLoRA, layer_2_slices, []), (
        "ColumnParallelLinearWithLoRA should NOT handle empty packed_modules_list"
    )

    assert not accepts(variable_cls, layer_2_slices, []), (
        "MergedColumnParallelLinearVariableSliceWithLoRA "
        "should NOT handle 2 slices even with empty packed_modules_list"
    )