"vllm/vscode:/vscode.git/clone" did not exist on "af6e19f50f1d5d0c3801948c3ab17b2af231c259"
test_layers.py 43.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import random
from copy import deepcopy
from dataclasses import dataclass
7
from unittest.mock import patch
8

9
import pytest
10
11
12
import torch
import torch.nn.functional as F

13
from vllm.config.lora import LoRAConfig
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from vllm.lora.layers import (
    BaseLayerWithLoRA,
    ColumnParallelLinearWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    LogitsProcessorWithLoRA,
    LoRAMapping,
    MergedColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    QKVParallelLinearWithLoRA,
    QKVParallelLinearWithShardedLoRA,
    ReplicatedLinearWithLoRA,
    RowParallelLinearWithLoRA,
    RowParallelLinearWithShardedLoRA,
    VocabParallelEmbeddingWithLoRA,
)
31
from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights
32
from vllm.lora.punica_wrapper import get_punica_wrapper
33
34
35
36
37
38
39
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
40
from vllm.model_executor.layers.logits_processor import LogitsProcessor
41
from vllm.model_executor.layers.vocab_parallel_embedding import (
42
43
44
45
    ParallelLMHead,
    VocabParallelEmbedding,
    get_masked_input_and_mask,
)
46
from vllm.model_executor.utils import set_random_seed
47
from vllm.platforms import current_platform
48
49
50
51
52
53
54
55

from .utils import DummyLoRAManager

# (rtol, atol) pairs handed to torch.testing.assert_close, looked up by the
# dtype of the LoRA layer's result. The half-precision formats share one
# setting with float32; bfloat16 gets looser bounds.
TOLERANCES = {
    **dict.fromkeys((torch.float16, torch.float32), (5e-3, 5e-3)),
    torch.bfloat16: (3e-2, 2e-2),
}
56
57
58

# Skip the whole module unless running on a CUDA-like or CPU backend.
pytestmark = pytest.mark.skipif(
    not (current_platform.is_cuda_alike() or current_platform.is_cpu()),
    reason="Backend not supported",
)
61

62
63
64
65
66
# Devices to parametrize over. On CUDA-like platforms this is a single device
# when exactly one is visible and two devices otherwise; elsewhere it is CPU.
if current_platform.is_cuda_alike():
    DEVICES = [
        f"cuda:{idx}"
        for idx in range(1 if torch.cuda.device_count() == 1 else 2)
    ]
else:
    DEVICES = ["cpu"]

STAGES = [True, False]  # True: prefill stage, False: decode stage

NUM_RANDOM_SEEDS = 2

VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 2
74
75
76


@pytest.fixture(autouse=True)
def clean_cache_reset_device(reset_default_device):
    """Clear the Triton LoRA ops' module-level pointer dicts before each test.

    Requests the ``reset_default_device`` fixture (presumably defined in a
    conftest — not visible here) so it runs for every test as well.
    """
    # Release any memory we might be holding on to. CI runs OOMs otherwise.
    from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT

    _LORA_B_PTR_DICT.clear()
    _LORA_A_PTR_DICT.clear()

    yield

86

87
88
89
@pytest.fixture(autouse=True)
def skip_cuda_with_stage_false(request):
    """
    On cuda-like platforms, we use the same kernels for prefill and decode
    stage, and 'stage' is generally ignored, so we only need to test once.

    Skips any test parametrized with ``stage=False`` on such platforms.
    Tests without a ``stage`` parameter, or without parametrization at all,
    run normally.
    """
    if current_platform.is_cuda_alike():
        # Non-parametrized test items have no ``callspec``. Look it up with a
        # default instead of hiding every error behind ``except Exception``.
        callspec = getattr(request.node, "callspec", None)
        params = getattr(callspec, "params", None) or {}
        if "stage" in params and params["stage"] is False:
            pytest.skip("Skip test when stage=False")
    yield


106
107
def get_random_id_to_index(
    num_loras: int, num_slots: int, log: bool = True
108
) -> list[int | None]:
109
110
111
112
113
114
115
116
117
118
119
120
    """Creates a random lora_id_to_index mapping.

    Args:
        num_loras: The number of active loras in the mapping.
        num_slots: The number of slots in the mapping. Must be larger
            than num_loras.
        log: Whether to log the output.
    """

    if num_loras > num_slots:
        raise ValueError(
            f"num_loras is higher than num_slots: {num_loras} > {num_slots}. "
121
122
            "num_loras must be less than or equal to num_slots."
        )
123

124
    slots: list[int | None] = [None] * num_slots
125
126
127
128
129
130
131
132
133
134
135
    random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist()
    for lora_id, slot_idx in enumerate(random_slot_selections, start=1):
        slots[slot_idx] = lora_id

    if log:
        print(f"Created lora_id_to_index mapping: {slots}.")

    return slots


def populate_loras(
    id_to_index: list[int | None],
    layer: BaseLayerWithLoRA,
    layer_weights: torch.Tensor,
    generate_embeddings_tensor: int = 0,
    repeats: int = 1,
) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
    """This method populates the lora layers with random lora weights.

    Args:
        id_to_index: a list of lora ids. The index of the lora id
            represents which memory slot the lora matrices are
            stored in. A None value indicates a free slot.
        layer: the LoRA layer to populate.
        layer_weights: the PyTorch tensor containing the layer's
            weights. Its first dimension is split evenly across the
            ``repeats`` subloras when slicing each sublora's ``lora_b``.
        generate_embeddings_tensor: forwarded unchanged to
            ``DummyLoRAManager.init_random_lora``. 0 requests no
            embeddings tensor; callers in this file pass a non-zero size
            (256 or 1024). NOTE(review): presumably this is the embeddings
            tensor's hidden dimension — confirm in DummyLoRAManager.
        repeats: must only be set for column parallel packed
            layers. Indicates the number of loras to compose
            together to create a single lora layer.

    Returns:
        ``(lora_dict, sublora_dict)``: for every non-None lora id, the
        (possibly packed) LoRA weights loaded into ``layer``, and the list of
        subloras they were built from.
    """

    # Dictionary that maps the lora ID to the corresponding lora weights.
    lora_dict: dict[int, LoRALayerWeights] = {}

    # Dictionary that maps the lora ID to the corresponding subloras.
    sublora_dict: dict[int, list[LoRALayerWeights]] = {}

    for slot_idx, lora_id in enumerate(id_to_index):
        if lora_id is None:
            continue  # free slot: nothing to load

        subloras: list[LoRALayerWeights] = []
        sublora_len = layer_weights.shape[0] // repeats
        for i in range(repeats):
            sublora = DummyLoRAManager(layer_weights.device).init_random_lora(
                module_name=f"fake_{i}",
                weight=layer_weights,
                generate_embeddings_tensor=generate_embeddings_tensor,
            )
            # Keep only this sublora's contiguous slice of lora_b rows so the
            # packed result lines up with the full layer weight.
            sublora.lora_b = sublora.lora_b[
                (sublora_len * i) : (sublora_len * (i + 1)), :
            ]
            sublora.optimize()
            subloras.append(sublora)

        lora = PackedLoRALayerWeights.pack(subloras) if repeats > 1 else subloras[0]

        layer.set_lora(
            slot_idx,
            lora_a=lora.lora_a,
            lora_b=lora.lora_b,
            embeddings_tensor=lora.embeddings_tensor,
        )

        lora_dict[lora_id] = lora
        sublora_dict[lora_id] = subloras

    return lora_dict, sublora_dict


def create_random_inputs(
198
    active_lora_ids: list[int],
199
    num_inputs: int,
200
201
    input_size: tuple[int, ...],
    input_range: tuple[float, float],
202
    input_type: torch.dtype = torch.int,
203
    device: torch.device = "cuda",
204
) -> tuple[list[torch.Tensor], list[int], list[int]]:
205
206
207
208
209
210
211
212
213
214
215
216
217
    """Creates random inputs.

    Args:
        active_lora_ids: lora IDs of active lora weights.
        num_inputs: the number of inputs to create.
        input_size: the size of each individual input.
        input_range: the range of values to include in the input.
            input_range[0] <= possible input values < input_range[1]
        input_type: the type of values in the input.
    """

    low, high = input_range

218
219
220
    inputs: list[torch.Tensor] = []
    index_mapping: list[int] = []
    prompt_mapping: list[int] = []
221

222
223
224
    for _ in range(num_inputs):
        if input_type == torch.int:
            inputs.append(
225
226
227
228
                torch.randint(
                    low=int(low), high=int(high), size=input_size, device=device
                )
            )
229
230
        else:
            inputs.append(
231
232
233
                torch.rand(size=input_size, dtype=input_type, device=device) * high
                + low
            )
234
235
236
237
238
239
240
241

        lora_id = random.choice(active_lora_ids)
        index_mapping += [lora_id] * input_size[0]
        prompt_mapping += [lora_id]

    return inputs, index_mapping, prompt_mapping


242
243
244
245
246
def check_punica_wrapper(punica_wrapper) -> bool:
    """Return whether ``punica_wrapper`` is the platform's expected wrapper.

    On CUDA-like platforms the expected class is ``PunicaWrapperGPU``; on CPU
    it is ``PunicaWrapperCPU``. Any other platform returns False. The check
    uses ``type(...) is``, so subclasses do not match.
    """
    if current_platform.is_cuda_alike():
        from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU

        return type(punica_wrapper) is PunicaWrapperGPU
    elif current_platform.is_cpu():
        from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU

        return type(punica_wrapper) is PunicaWrapperCPU
    else:
        return False


255
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
    """Check VocabParallelEmbeddingWithLoRA against a per-input reference.

    For each seed, LoRAs are loaded into random slots. The layer's output on
    a random token batch must match ``embedding(x) + F.embedding(x,
    lora_a.T) @ lora_b.T`` computed per input with that input's LoRA. After
    every slot is reset, the layer must match the plain base embedding.
    """
    # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
    # device, see: https://github.com/triton-lang/triton/issues/2925
    # Same below.
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )

    def create_random_embedding_layer():
        embedding = VocabParallelEmbedding(vocab_size, 256)
        embedding.weight.data = torch.rand_like(embedding.weight.data)
        # Zero any rows past vocab_size (presumably vocab padding added by
        # the layer, if any).
        embedding.weight.data[vocab_size:, :] = 0
        lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
        lora_embedding.create_lora_weights(max_loras, lora_config)

        return embedding, lora_embedding

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        embedding, lora_embedding = create_random_embedding_layer()
        lora_embedding.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=embedding.weight.T,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=num_loras * 3,
            input_size=(200,),
            input_range=(1, vocab_size),
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_embedding(torch.cat(inputs))

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = embedding(input_)
            after_a = F.embedding(
                input_,
                lora.lora_a.T,
            )
            result += after_a @ lora.lora_b.T
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        # LoRA id 0 is used for every input here, so the result must equal
        # the base embedding.
        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=num_loras * 3,
            input_size=(200,),
            input_range=(1, vocab_size),
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_embedding(torch.cat(inputs))
        expected_result = embedding(torch.cat(inputs))

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
355
356
357


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings_with_new_embeddings(
    dist_init, num_loras, device, vocab_size, stage
) -> None:
    """Check VocabParallelEmbeddingWithLoRA when LoRAs add new vocabulary.

    The reference is an expanded embedding. Rows past ``vocab_size`` hold
    every LoRA's embeddings tensor, concatenated in LoRA-id order and
    zero-padded for unused slots. Each LoRA's extra tokens are fed to the
    LoRA layer as ``[vocab_size, vocab_size + len)``. In the reference they
    are shifted by ``(lora_id - 1) * len`` into that LoRA's region. After all
    slots are reset, the layer must match the expanded base embedding.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )

    def create_random_embedding_layer():
        embedding = VocabParallelEmbedding(vocab_size, 256)
        embedding_data = torch.rand_like(embedding.weight.data)
        embedding.weight.data = embedding_data
        embedding.weight.data[vocab_size:, :] = 0
        expanded_embedding = VocabParallelEmbedding(
            vocab_size + lora_config.lora_extra_vocab_size * max_loras,
            256,
            org_num_embeddings=vocab_size,
        )
        expanded_embedding.weight.data[:vocab_size, :] = embedding_data
        # We need to deepcopy the embedding as it will be modified
        # in place
        lora_embedding = VocabParallelEmbeddingWithLoRA(deepcopy(expanded_embedding))
        lora_embedding.create_lora_weights(max_loras, lora_config)

        return expanded_embedding, lora_embedding

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        expanded_embedding, lora_embedding = create_random_embedding_layer()
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=torch.zeros(
                (256, vocab_size + lora_config.lora_extra_vocab_size)
            ),
            generate_embeddings_tensor=256,
        )

        lora_embedding.set_mapping(punica_wrapper)
        # All embeddings tensors have the same shape.
        embeddings_tensors = [
            lora_dict[lora_id].embeddings_tensor for lora_id in sorted(lora_dict)
        ]
        embeddings_tensor_len = embeddings_tensors[0].shape[0]

        # Add empty embeddings_tensors for unoccupied lora slots.
        for _ in range(max_loras - len(embeddings_tensors)):
            embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape))

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=num_loras * 3,
            input_size=(200,),
            input_range=(1, vocab_size),
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )
        original_inputs = deepcopy(inputs)

        # Force some of the inputs to be in the extended embeddings range
        # to guarantee that their behavior is tested. ``inputs`` index the
        # expanded reference (offset into this LoRA's region), while
        # ``original_inputs`` use the LoRA-relative ids the LoRA layer sees.
        for input_, original_input_, lora_id in zip(
            inputs, original_inputs, prompt_mapping
        ):
            embedding_id = lora_id - 1
            input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
            original_input_[-1] = vocab_size
            input_[-2] = vocab_size + ((embedding_id + 1) * embeddings_tensor_len - 1)
            original_input_[-2] = vocab_size + embeddings_tensor_len - 1

        expanded_embedding.weight[
            vocab_size : vocab_size + (embeddings_tensor_len * max_loras)
        ] = torch.cat(embeddings_tensors)

        lora_result = lora_embedding(torch.cat(original_inputs))

        expected_results: list[torch.Tensor] = []
        for input_, original_input_, lora_id in zip(
            inputs, original_inputs, prompt_mapping
        ):
            lora = lora_dict[lora_id]
            result = expanded_embedding(input_)
            after_a = F.embedding(
                original_input_,
                lora.lora_a.T,
            )
            result += after_a @ lora.lora_b.T
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=num_loras * 3,
            input_size=(200,),
            input_range=(1, vocab_size),
            device=device,
        )
        original_inputs = deepcopy(inputs)
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )
        lora_result = lora_embedding(torch.cat(original_inputs))
        expected_result = expanded_embedding(torch.cat(inputs))

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
498
499
500


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])
@pytest.mark.parametrize("stage", STAGES)
def test_lm_head_logits_processor(
    dist_init, num_loras, device, vocab_size, stage
) -> None:
    """Exercise LogitsProcessorWithLoRA and check it after a full reset.

    First, LoRA logits (including the extra-vocab embeddings tensor) are
    computed for a random batch. A reference is built by writing the
    embeddings tensor into the LM head's extra-vocab rows. After every slot
    is reset, the LoRA processor's logits over the original vocabulary must
    match the base LogitsProcessor.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )

    def _pretest():
        linear = ParallelLMHead(
            vocab_size + lora_config.lora_extra_vocab_size,
            1024,
            vocab_size,
            params_dtype=torch.float16,
        )
        linear.weight.data = torch.rand_like(linear.weight.data)
        # Zero the rows past the original vocabulary (the extra-vocab rows),
        # matching the row-wise zeroing in the embedding tests.
        linear.weight.data[vocab_size:, :] = 0
        logits_processor = LogitsProcessor(
            vocab_size + lora_config.lora_extra_vocab_size, vocab_size
        )
        lora_logits_processor = LogitsProcessorWithLoRA(
            logits_processor, 1024, linear.weight.dtype, linear.weight.device, None
        )
        lora_logits_processor.create_lora_weights(max_loras, lora_config)

        return linear, logits_processor, lora_logits_processor

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, logits_processor, lora_logits_processor = _pretest()
        lora_logits_processor.set_mapping(punica_wrapper)
        # NOTE: all the generated loras share the same embeddings tensor.
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_logits_processor,
            layer_weights=linear.weight,
            generate_embeddings_tensor=1024,
        )
        embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor
        embeddings_tensor_len = embeddings_tensor.shape[0]

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=8 * num_loras,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs), lm_head=linear, embedding_bias=None
        )

        original_lm_head = deepcopy(linear)

        linear.weight[
            logits_processor.org_vocab_size : logits_processor.org_vocab_size
            + embeddings_tensor_len
        ] = embeddings_tensor

        logits_processor.org_vocab_size = vocab_size + lora_config.lora_extra_vocab_size
        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = logits_processor._get_logits(
                hidden_states=input_, lm_head=linear, embedding_bias=None
            )
            result[:, vocab_size + embeddings_tensor_len :] = float("-inf")
            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)
        logits_processor.org_vocab_size = vocab_size
        # NOTE(review): lora_result and expected_result from this phase are
        # never compared. The reference uses only the first LoRA's
        # embeddings tensor. If populate_loras gives each LoRA its own random
        # embeddings tensor, the NOTE above would not hold and a comparison
        # would fail for num_loras > 1. Confirm before adding an assert here.

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_logits_processor.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=8 * num_loras * 3,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )

        # Compare only the original-vocabulary columns: the base processor
        # (org_vocab_size == vocab_size again) does not produce the rest.
        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None,
        )[:, :vocab_size]
        expected_result = logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None,
        )

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
631
632


633
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_replicated(
    dist_init,
    num_loras,
    device,
    stage,
) -> None:
    """Check ReplicatedLinearWithLoRA against a per-input reference.

    The LoRA layer's output must match ``linear(x) + x @ lora_a.T @
    lora_b.T * scaling`` computed per input with that input's LoRA. After
    every slot is reset, it must match the plain base layer.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        lora_dtype=torch.float16,
    )

    def create_random_linear_replicated_layer():
        linear = ReplicatedLinear(4096, 4096, bias=False, params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        lora_linear = ReplicatedLinearWithLoRA(linear)

        lora_linear.create_lora_weights(max_loras, lora_config)
        # A non-packed layer must have exactly one LoRA slice.
        assert (
            lora_linear.n_slices
            == len(lora_linear.lora_a_stacked)
            == len(lora_linear.lora_b_stacked)
            == 1
        )
        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_replicated_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size
        )

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
737
738


739
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_parallel(
    dist_init, num_loras, orientation, fully_shard, device, stage
) -> None:
    """Check row/column-parallel LoRA linear layers against a reference.

    ``orientation`` selects RowParallelLinear or ColumnParallelLinear.
    ``fully_shard`` selects the corresponding ``*WithShardedLoRA`` wrapper
    instead of ``*WithLoRA``. Output must match ``linear(x) + x @ lora_a.T @
    lora_b.T * scaling`` per input. After every slot is reset, it must match
    the plain base layer.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        fully_sharded_loras=fully_shard,
        lora_dtype=torch.float16,
    )

    def create_random_linear_parallel_layer():
        if orientation == "row":
            linear = RowParallelLinear(
                4096, 4096, bias=False, params_dtype=torch.float16
            )
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (
                RowParallelLinearWithLoRA(linear)
                if not fully_shard
                else RowParallelLinearWithShardedLoRA(linear)
            )
        else:
            linear = ColumnParallelLinear(
                4096, 4096, bias=False, params_dtype=torch.float16
            )
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (
                ColumnParallelLinearWithLoRA(linear)
                if not fully_shard
                else ColumnParallelLinearWithShardedLoRA(linear)
            )
        lora_linear.create_lora_weights(max_loras, lora_config)
        # A non-packed layer must have exactly one LoRA slice.
        assert (
            lora_linear.n_slices
            == len(lora_linear.lora_a_stacked)
            == len(lora_linear.lora_b_stacked)
            == 1
        )

        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_parallel_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size
        )

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
860
861
862


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_column_parallel_packed(
    dist_init, num_loras, repeats, fully_shard, device, stage
) -> None:
    """Check packed column-parallel LoRA layers against a reference computation.

    ``repeats`` selects the layer under test:

    * 2 builds a ``MergedColumnParallelLinear`` with two 4096-wide output
      slices.
    * 3 builds a ``QKVParallelLinear`` wrapped as a merged (3-slice) QKV LoRA
      layer.
    * Any other value builds the same ``QKVParallelLinear`` wrapped as a
      single-slice ``QKVParallelLinearWithLoRA``.

    ``fully_shard`` switches each wrapper to its sharded-LoRA variant.

    For every random seed the test:

    1. Compares the LoRA layer's output with the base layer's output plus
       ``x @ A.T @ B.T * scaling`` added to each sub-LoRA's own output slice.
    2. After resetting every LoRA slot, checks that the LoRA layer reproduces
       the plain base layer (within dtype tolerance).
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        fully_sharded_loras=fully_shard,
        lora_dtype=torch.float16,
    )

    def create_column_parallel_packed_layer():
        """Build the base layer selected by ``repeats`` and its LoRA wrapper."""
        if repeats == 2:
            linear = MergedColumnParallelLinear(
                4096, [4096] * repeats, bias=False, params_dtype=torch.float16
            )
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (
                MergedColumnParallelLinearWithLoRA(linear)
                if not fully_shard
                else MergedColumnParallelLinearWithShardedLoRA(linear)
            )
        elif repeats == 3:
            linear = QKVParallelLinear(
                4096, 64, 32, bias=False, params_dtype=torch.float16
            )
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (
                MergedQKVParallelLinearWithLoRA(linear)
                if not fully_shard
                else MergedQKVParallelLinearWithShardedLoRA(linear)
            )
        else:
            linear = QKVParallelLinear(
                4096, 64, 32, bias=False, params_dtype=torch.float16
            )
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (
                QKVParallelLinearWithLoRA(linear)
                if not fully_shard
                else QKVParallelLinearWithShardedLoRA(linear)
            )

        @dataclass
        class FakeConfig:
            hidden_size = 4096
            num_key_value_heads = 32
            num_attention_heads = 32

        n_slices = repeats
        lora_linear.create_lora_weights(
            max_loras, lora_config, model_config=FakeConfig()
        )
        # The wrapper must allocate one stacked A and one stacked B buffer per
        # packed output slice.
        assert (
            lora_linear.n_slices
            == len(lora_linear.lora_a_stacked)
            == len(lora_linear.lora_b_stacked)
            == n_slices
        )

        return linear, lora_linear

    for seed_idx in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed_idx)

        id_to_index = get_random_id_to_index(num_loras, max_loras)

        linear, lora_linear = create_column_parallel_packed_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, sublora_dict = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
            repeats=repeats,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        # Reference: base output plus each sub-LoRA's delta, applied only to
        # that sub-LoRA's contiguous slice of the output columns.
        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            result = linear(input_)[0]
            subloras = sublora_dict[lora_id]
            for slice_idx, sublora in enumerate(subloras):
                slice_width = sublora.lora_b.shape[0]
                result[:, slice_width * slice_idx : slice_width * (slice_idx + 1)] += (
                    input_ @ sublora.lora_a.T @ sublora.lora_b.T * sublora.scaling
                )
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)

        # After clearing every slot the LoRA layer must match the base layer.
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)


@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize(
    "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS))
)
def test_vocab_parallel_embedding_indices(tp_size, seed):
    """Check that VocabParallelEmbedding shards the vocab consistently.

    A random vocab, with a random number of added tokens, is sharded over
    ``tp_size`` simulated ranks by patching the TP rank and world-size
    getters. The test checks four things:

    * Each rank's org/added ranges start where the previous rank's ended.
    * The ranges do not overlap.
    * The per-rank sizes add up to the full sizes.
    * ``get_sharded_to_full_mapping`` reorders the concatenated per-rank token
      ids back to ``0..vocab_size-1``, followed only by padding (``-1``).
    """
    # A local RNG gives the same draws as ``random.seed(seed)`` without
    # reseeding the global ``random`` module for other tests.
    rng = random.Random(seed)
    vocab_size = rng.randint(4000, 64000)
    added_vocab_size = rng.randint(0, 1024)
    org_vocab_size = vocab_size - added_vocab_size
    last_org_vocab_end_index = 0
    last_added_vocab_end_index = org_vocab_size
    computed_vocab_size = 0
    computed_org_vocab_size = 0
    computed_added_vocab_size = 0
    vocab_size_padded = -1

    all_org_tokens: list[int] = []
    all_added_tokens: list[int] = []
    # Token ids in sharded (rank-concatenated) order; -1 marks padding slots.
    token_ids: list[int] = []

    for tp_rank in range(tp_size):
        # Build the layer as rank ``tp_rank`` of a ``tp_size``-way TP group.
        with (
            patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank",
                return_value=tp_rank,
            ),
            patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size",
                return_value=tp_size,
            ),
        ):
            vocab_embedding = VocabParallelEmbedding(
                vocab_size, 1, org_num_embeddings=org_vocab_size
            )
        vocab_size_padded = vocab_embedding.num_embeddings_padded
        shard_indices = vocab_embedding.shard_indices
        # Assert that the ranges are contiguous
        assert shard_indices.org_vocab_start_index == last_org_vocab_end_index
        assert shard_indices.added_vocab_start_index == last_added_vocab_end_index

        # Accumulate per-shard sizes; the totals are checked after the loop.
        computed_vocab_size += shard_indices.num_elements_padded
        computed_org_vocab_size += shard_indices.num_org_elements
        computed_added_vocab_size += shard_indices.num_added_elements

        org_range = range(
            shard_indices.org_vocab_start_index, shard_indices.org_vocab_end_index
        )
        added_range = range(
            shard_indices.added_vocab_start_index,
            shard_indices.added_vocab_end_index,
        )

        # Collected to check that ranges are not overlapping.
        all_org_tokens.extend(org_range)
        all_added_tokens.extend(added_range)

        # Mirror the rank-local layout: org tokens, org padding,
        # added tokens, added padding.
        token_ids.extend(org_range)
        token_ids.extend(
            [-1]
            * (shard_indices.num_org_elements_padded - shard_indices.num_org_elements)
        )
        token_ids.extend(added_range)
        token_ids.extend(
            [-1]
            * (
                shard_indices.num_added_elements_padded
                - shard_indices.num_added_elements
            )
        )

        last_org_vocab_end_index = shard_indices.org_vocab_end_index
        last_added_vocab_end_index = shard_indices.added_vocab_end_index

    assert computed_vocab_size == vocab_size_padded
    assert computed_org_vocab_size == org_vocab_size
    assert computed_added_vocab_size == added_vocab_size

    # Ensure that the ranges are not overlapping
    assert len(all_org_tokens) == len(set(all_org_tokens))
    assert len(all_added_tokens) == len(set(all_added_tokens))
    assert not set(all_org_tokens).intersection(set(all_added_tokens))

    token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
    # Queried from the last rank's layer. NOTE(review): this presumes the
    # mapping is the same on every rank, which the assertions below exercise.
    reindex_mapping = vocab_embedding.get_sharded_to_full_mapping()
    assert reindex_mapping is not None or tp_size == 1
    if reindex_mapping is not None:
        reindexed_token_ids = token_ids_tensor[reindex_mapping]
        expected = torch.arange(vocab_size)
        assert reindexed_token_ids[:vocab_size].equal(expected)
        assert torch.all(reindexed_token_ids[vocab_size:] == -1)


def test_get_masked_input_and_mask():
    """Check ``get_masked_input_and_mask`` on TP 1/2/4 shardings of a 12-token vocab.

    Ids 0-7 are the original vocab and ids 8-11 are added tokens. Each case
    passes one rank's org/added index bounds, with and without 2 slots of org
    vocab padding, and checks the remapped (rank-local) ids. Only the first
    return value is checked here.
    """
    x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

    def _masked(org_start, org_end, added_start, added_end, padding):
        """Return the remapped ids of ``x`` for one shard's index bounds."""
        modified_x, _ = get_masked_input_and_mask(
            x,
            org_vocab_start_index=org_start,
            org_vocab_end_index=org_end,
            added_vocab_start_index=added_start,
            added_vocab_end_index=added_end,
            num_org_vocab_padding=padding,
        )
        return modified_x

    def _check_shards(shards, padding, expected_per_rank):
        """Assert each rank's remapped ids equal the matching expected row."""
        assert len(shards) == len(expected_per_rank)
        for bounds, expected in zip(shards, expected_per_rank):
            assert torch.equal(_masked(*bounds, padding), torch.tensor(expected))

    # Per-rank (org_start, org_end, added_start, added_end) bounds.
    tp2_shards = [(0, 4, 8, 10), (4, 8, 10, 12)]
    tp4_shards = [(0, 2, 8, 9), (2, 4, 9, 10), (4, 6, 10, 11), (6, 8, 11, 12)]

    # base tp 1 case, no padding: the input comes back unchanged
    assert torch.equal(x, _masked(0, 8, 8, 12, 0))

    # tp 2 case, no padding
    _check_shards(
        tp2_shards,
        0,
        [
            [0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0],
            [0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5],
        ],
    )

    # tp 4 case, no padding
    _check_shards(
        tp4_shards,
        0,
        [
            [0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0],
            [0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0],
            [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2],
        ],
    )

    # base tp 1 case, with padding: added-token ids shift by the padding
    assert torch.equal(
        _masked(0, 8, 8, 12, 2),
        torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]),
    )

    # tp 2 case, with padding
    _check_shards(
        tp2_shards,
        2,
        [
            [0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0],
            [0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7],
        ],
    )

    # tp 4 case, with padding
    _check_shards(
        tp4_shards,
        2,
        [
            [0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0],
            [0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0],
            [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4],
        ],
    )