# SPDX-License-Identifier: Apache-2.0

import random
from copy import deepcopy
from dataclasses import dataclass
6
from typing import Optional
7
from unittest.mock import patch
8

9
import pytest
10
11
12
13
import torch
import torch.nn.functional as F

from vllm.config import LoRAConfig
14
15
16
from vllm.lora.fully_sharded_layers import (
    ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
17
    MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA,
18
    RowParallelLinearWithShardedLoRA)
19
20
# yapf conflicts with isort for this block
# yapf: disable
21
22
23
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
                              LogitsProcessorWithLoRA, LoRAMapping,
                              MergedColumnParallelLinearWithLoRA,
24
25
                              MergedQKVParallelLinearWithLoRA,
                              QKVParallelLinearWithLoRA,
26
                              ReplicatedLinearWithLoRA,
27
28
                              RowParallelLinearWithLoRA,
                              VocabParallelEmbeddingWithLoRA)
29
# yapf: enable
30
from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights
31
from vllm.lora.punica_wrapper import get_punica_wrapper
32
33
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
34
                                               QKVParallelLinear,
35
                                               ReplicatedLinear,
36
37
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
38
from vllm.model_executor.layers.vocab_parallel_embedding import (
39
    ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
40
from vllm.model_executor.utils import set_random_seed
41
from vllm.platforms import current_platform
42
43
44
45
46
47
48
49

from .utils import DummyLoRAManager

# (rtol, atol) passed to torch.testing.assert_close, keyed by result dtype.
TOLERANCES = {
    torch.float16: (5e-3, 5e-3),
    torch.float32: (5e-3, 5e-3),
    torch.bfloat16: (3e-2, 2e-2),
}

# Every test in this module is skipped unless the platform is CUDA-like or CPU.
pytestmark = pytest.mark.skipif(
    not (current_platform.is_cuda_alike() or current_platform.is_cpu()),
    reason="Backend not supported")

# On CUDA-like platforms: ["cuda:0"] when exactly one device is visible,
# otherwise ["cuda:0", "cuda:1"]. Elsewhere: ["cpu"].
DEVICES = ([
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] if current_platform.is_cuda_alike() else ["cpu"])

# prefill stage(True) or decode stage(False)
STAGES = [True, False]

# Each test loops over seeds 0 .. NUM_RANDOM_SEEDS - 1.
NUM_RANDOM_SEEDS = 6

VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128


@pytest.fixture(autouse=True)
def clean_cache():
    """Clear the module-level LoRA pointer dicts before every test.

    Empties ``_LORA_A_PTR_DICT`` and ``_LORA_B_PTR_DICT`` from
    ``vllm.lora.ops.triton_ops.utils`` so that nothing they reference is kept
    alive across tests. Release any memory we might be holding on to. CI runs
    OOMs otherwise.
    """
    from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT,
                                                _LORA_B_PTR_DICT)
    _LORA_B_PTR_DICT.clear()
    _LORA_A_PTR_DICT.clear()

    yield


@pytest.fixture(autouse=True)
def skip_cuda_with_stage_false(request):
    """Skip the ``stage=False`` parametrization on CUDA-like platforms.

    On cuda-like platforms, we use the same kernels for prefill and decode
    stage, and 'stage' is generally ignored, so we only need to test once.
    Tests that are not parametrized (no ``callspec``) or have no ``stage``
    parameter are left alone.
    """
    if current_platform.is_cuda_alike():
        # getattr with defaults replaces the previous hasattr checks wrapped
        # in a blanket ``except Exception: pass``; nothing here can raise.
        callspec = getattr(request.node, "callspec", None)
        params = getattr(callspec, "params", None) or {}
        if params.get("stage") is False:
            pytest.skip("Skip test when stage=False")
    yield


def get_random_id_to_index(num_loras: int,
                           num_slots: int,
                           log: bool = True) -> list[Optional[int]]:
    """Creates a random lora_id_to_index mapping.

    Args:
        num_loras: The number of active loras in the mapping. Must be
            non-negative.
        num_slots: The number of slots in the mapping. Must be greater than
            or equal to num_loras.
        log: Whether to log the output.

    Returns:
        A list of length ``num_slots`` in which LoRA ids ``1..num_loras``
        each occupy one distinct, randomly chosen slot; every other slot is
        ``None``.

    Raises:
        ValueError: If ``num_loras`` is negative or exceeds ``num_slots``.
    """
    if num_loras < 0:
        # A negative count would otherwise slice from the end of the
        # permutation below and silently place the wrong number of LoRAs.
        raise ValueError(f"num_loras must be non-negative, got {num_loras}.")
    if num_loras > num_slots:
        raise ValueError(
            f"num_loras is higher than num_slots: {num_loras} > {num_slots}. "
            "num_loras must be less than or equal to num_slots.")

    slots: list[Optional[int]] = [None] * num_slots
    # The first num_loras entries of a random permutation are distinct slots.
    random_slot_selections = torch.randperm(num_slots)[:num_loras].tolist()
    # LoRA ids start at 1.
    for lora_id, slot_idx in enumerate(random_slot_selections, start=1):
        slots[slot_idx] = lora_id

    if log:
        print(f"Created lora_id_to_index mapping: {slots}.")

    return slots


def populate_loras(
    id_to_index: list[Optional[int]],
    layer: BaseLayerWithLoRA,
    layer_weights: torch.Tensor,
    generate_embeddings_tensor: int = 0,
    repeats: int = 1,
) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
    """This method populates the lora layers with lora weights.

    Args:
        id_to_index: a list of lora ids. The index of the lora id
            represents which memory slot the lora matrices are
            stored in. A None value indicates a free slot.
        layer: the LoRAlayer to populate.
        layer_weights: the PyTorch tensor containing the layer's
            weights.
        generate_embeddings_tensor: forwarded to
            ``DummyLoRAManager.init_random_lora``. 0 (the default) requests
            no embeddings tensor. Callers in this file pass the hidden size
            (e.g. 256, 1024), presumably used as the embeddings tensor's
            last dimension. TODO confirm against DummyLoRAManager.
        repeats: must only be set for column parallel packed
            layers. Indicates the number of loras to compose
            together to create a single lora layer.

    Returns:
        ``(lora_dict, sublora_dict)``. ``lora_dict`` maps each lora id to the
        weights set on ``layer``, packed when ``repeats > 1``.
        ``sublora_dict`` maps each lora id to the list of sub-LoRAs it was
        built from.
    """
    # Dictionary that maps the lora ID to the corresponding lora weights.
    lora_dict: dict[int, LoRALayerWeights] = {}
    # Dictionary that maps the lora ID to the corresponding subloras.
    sublora_dict: dict[int, list[LoRALayerWeights]] = {}

    # Width of lora_b's column slice owned by each sub-LoRA (loop-invariant).
    sublora_len = layer_weights.shape[0] // repeats

    for slot_idx, lora_id in enumerate(id_to_index):
        if lora_id is None:
            continue  # free slot

        subloras: list[LoRALayerWeights] = []
        for i in range(repeats):
            sublora = DummyLoRAManager(
                layer_weights.device).init_random_lora(
                    module_name=f"fake_{i}",
                    weight=layer_weights,
                    generate_embeddings_tensor=generate_embeddings_tensor,
                )
            # Keep columns [i * sublora_len, (i + 1) * sublora_len) of lora_b.
            sublora.lora_b = sublora.lora_b[:, (sublora_len *
                                                i):(sublora_len * (i + 1))]
            sublora.optimize()
            subloras.append(sublora)

        lora = PackedLoRALayerWeights.pack(
            subloras) if repeats > 1 else subloras[0]

        layer.set_lora(
            slot_idx,
            lora_a=lora.lora_a,
            lora_b=lora.lora_b,
            embeddings_tensor=lora.embeddings_tensor,
        )

        lora_dict[lora_id] = lora
        sublora_dict[lora_id] = subloras

    return lora_dict, sublora_dict


def create_random_inputs(
    active_lora_ids: list[int],
    num_inputs: int,
    input_size: tuple[int, ...],
    input_range: tuple[float, float],
    input_type: torch.dtype = torch.int,
    device: torch.device = "cuda"
) -> tuple[list[torch.Tensor], list[int], list[int]]:
    """Creates random inputs.

    Args:
        active_lora_ids: lora IDs of active lora weights.
        num_inputs: the number of inputs to create.
        input_size: the size of each individual input.
        input_range: the range of values to include in the input.
            input_range[0] <= possible input values < input_range[1]
        input_type: the type of values in the input. ``torch.int`` selects
            ``torch.randint``; any other dtype selects ``torch.rand``.
        device: device the inputs are allocated on. A device string such as
            ``"cuda"`` or ``"cpu"`` is also accepted.

    Returns:
        ``(inputs, index_mapping, prompt_mapping)``:
        - ``inputs``: the ``num_inputs`` random tensors.
        - ``index_mapping``: each input's lora id, repeated ``input_size[0]``
          times (one entry per row).
        - ``prompt_mapping``: one lora id per input.
    """
    low, high = input_range

    inputs: list[torch.Tensor] = []
    index_mapping: list[int] = []
    prompt_mapping: list[int] = []

    for _ in range(num_inputs):
        if input_type == torch.int:
            inputs.append(
                torch.randint(low=int(low),
                              high=int(high),
                              size=input_size,
                              device=device))
        else:
            # Scale by the width of the range, not by `high`, so that values
            # fall in [low, high) as documented. (Identical for low == 0.)
            inputs.append(
                torch.rand(size=input_size, dtype=input_type, device=device) *
                (high - low) + low)

        lora_id = random.choice(active_lora_ids)
        index_mapping += [lora_id] * input_size[0]
        prompt_mapping += [lora_id]

    return inputs, index_mapping, prompt_mapping


def check_punica_wrapper(punica_wrapper) -> bool:
    """Return whether ``punica_wrapper`` has the wrapper class for this platform.

    Returns ``True`` only if the object's type is exactly
    ``PunicaWrapperGPU`` on CUDA-like platforms, or exactly
    ``PunicaWrapperCPU`` on CPU platforms. Subclasses do not count. On any
    other platform it returns ``False``.
    """
    if current_platform.is_cuda_alike():
        from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU

        # Exact type match on purpose: a subclass would not pass.
        return type(punica_wrapper) is PunicaWrapperGPU

    if current_platform.is_cpu():
        from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU

        return type(punica_wrapper) is PunicaWrapperCPU

    return False


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
    """Check VocabParallelEmbeddingWithLoRA against a per-input reference.

    For each seed, LoRAs are placed in random slots. The LoRA layer's output
    must match ``embedding(x) + F.embedding(x, lora_a) @ lora_b`` computed
    per input. After ``reset_lora`` on every slot, the LoRA layer must match
    the base embedding.
    """
    # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
    # device, see: https://github.com/triton-lang/triton/issues/2925
    # Same below.
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def create_random_embedding_layer():
        embedding = VocabParallelEmbedding(vocab_size, 256)
        embedding.weight.data = torch.rand_like(embedding.weight.data)
        # Zero any rows past vocab_size (padding rows, if the layer pads).
        embedding.weight.data[vocab_size:, :] = 0
        lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
        lora_embedding.create_lora_weights(max_loras, lora_config)

        return embedding, lora_embedding

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        embedding, lora_embedding = create_random_embedding_layer()
        lora_embedding.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=embedding.weight.T,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)

        lora_result = lora_embedding(torch.cat(inputs))

        # Reference: base embedding plus each input's own LoRA delta.
        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = embedding(input_)
            after_a = F.embedding(
                input_,
                lora.lora_a,
            )
            result += (after_a @ lora.lora_b)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds. LoRA id 0 never
        # appears in id_to_index (ids start at 1), so the LoRA layer should
        # now match the base embedding.
        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)

        lora_result = lora_embedding(torch.cat(inputs))
        expected_result = embedding(torch.cat(inputs))

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
                                        vocab_size, stage) -> None:
    """Check VocabParallelEmbeddingWithLoRA when LoRAs add extra vocabulary.

    Each LoRA carries an embeddings tensor for extra tokens. The reference is
    an expanded base embedding whose rows past ``vocab_size`` hold the LoRAs'
    embeddings tensors back to back, ordered by lora id. The LoRA layer is fed
    ``original_inputs``, where extra tokens use per-LoRA local indices
    starting at ``vocab_size``. The reference is fed ``inputs``, where the
    same tokens use global offsets into the expanded table. After
    ``reset_lora`` on every slot, the LoRA layer must match the expanded base
    embedding.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def create_random_embedding_layer():
        embedding = VocabParallelEmbedding(vocab_size, 256)
        embedding_data = torch.rand_like(embedding.weight.data)
        embedding.weight.data = embedding_data
        embedding.weight.data[vocab_size:, :] = 0
        expanded_embedding = VocabParallelEmbedding(
            vocab_size + lora_config.lora_extra_vocab_size * max_loras,
            256,
            org_num_embeddings=vocab_size)
        expanded_embedding.weight.data[:vocab_size, :] = embedding_data
        # We need to deepcopy the embedding as it will be modified
        # in place
        lora_embedding = VocabParallelEmbeddingWithLoRA(
            deepcopy(expanded_embedding))
        lora_embedding.create_lora_weights(max_loras, lora_config)

        return expanded_embedding, lora_embedding

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        expanded_embedding, lora_embedding = create_random_embedding_layer()
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=torch.zeros(
                (256, vocab_size + lora_config.lora_extra_vocab_size)),
            generate_embeddings_tensor=256,
        )

        lora_embedding.set_mapping(punica_wrapper)
        # All embeddings tensors have the same shape.
        embeddings_tensors = [
            lora_dict[lora_id].embeddings_tensor
            for lora_id in sorted(lora_dict.keys())
        ]
        embeddings_tensor_len = embeddings_tensors[0].shape[0]

        # Add empty embeddings_tensors for unoccupied lora slots.
        for _ in range(max_loras - len(embeddings_tensors)):
            embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape))

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)
        original_inputs = deepcopy(inputs)

        # Force some of the inputs to be in the extended embeddings range
        # to guarantee that their behavior is tested. `inputs` gets global
        # offsets into the expanded table; `original_inputs` gets the
        # per-LoRA local index (relative to vocab_size).
        for input_, original_input_, lora_id in zip(inputs, original_inputs,
                                                    prompt_mapping):
            embedding_id = lora_id - 1
            input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
            original_input_[-1] = vocab_size
            input_[-2] = vocab_size + (
                (embedding_id + 1) * embeddings_tensor_len - 1)
            original_input_[-2] = vocab_size + embeddings_tensor_len - 1

        expanded_embedding.weight[vocab_size:vocab_size +
                                  (embeddings_tensor_len *
                                   max_loras)] = torch.cat(embeddings_tensors)

        lora_result = lora_embedding(torch.cat(original_inputs))

        expected_results: list[torch.Tensor] = []
        for input_, original_input_, lora_id in zip(inputs, original_inputs,
                                                    prompt_mapping):
            lora = lora_dict[lora_id]
            result = expanded_embedding(input_)
            after_a = F.embedding(
                original_input_,
                lora.lora_a,
            )
            result += (after_a @ lora.lora_b)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        original_inputs = deepcopy(inputs)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)
        lora_result = lora_embedding(torch.cat(original_inputs))
        expected_result = expanded_embedding(torch.cat(inputs))

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])
@pytest.mark.parametrize("stage", STAGES)
def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
                                  stage) -> None:
    """Check LogitsProcessorWithLoRA._get_logits against a base reference.

    The reference runs ``LogitsProcessor._get_logits`` on an lm_head whose
    extra-vocab rows are overwritten with a LoRA embeddings tensor. It then
    marks the remaining extra-vocab columns as -inf and adds
    ``x @ lora_a @ lora_b * scaling`` per input. After ``reset_lora`` on every
    slot, the LoRA processor must match the base processor on the base
    vocabulary.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def _pretest():
        linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
                                1024,
                                vocab_size,
                                params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        # The weight is (num_embeddings, hidden). The embeddings tensor is
        # written into its rows below. Zero the rows past the base vocabulary,
        # as the embedding tests do. The previous `[:, vocab_size:]` indexed
        # the hidden dimension instead.
        linear.weight.data[vocab_size:, :] = 0
        logits_processor = LogitsProcessor(
            vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
        lora_logits_processor = LogitsProcessorWithLoRA(
            logits_processor, 1024, linear.weight.dtype, linear.weight.device,
            None)
        lora_logits_processor.create_lora_weights(max_loras, lora_config)

        return linear, logits_processor, lora_logits_processor

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, logits_processor, lora_logits_processor = _pretest()
        lora_logits_processor.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_logits_processor,
            layer_weights=linear.weight,
            generate_embeddings_tensor=1024,
        )
        # NOTE(review): each LoRA comes from a separate init_random_lora call,
        # so their embeddings tensors presumably differ. Only the first one is
        # used to build the reference below.
        embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor
        embeddings_tensor_len = embeddings_tensor.shape[0]

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=8 * num_loras,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=linear,
            embedding_bias=None)

        # Keep an untouched copy for the reset check below.
        original_lm_head = deepcopy(linear)

        linear.weight[logits_processor.
                      org_vocab_size:logits_processor.org_vocab_size +
                      embeddings_tensor_len] = embeddings_tensor

        # Temporarily widen the base processor so it returns the extra-vocab
        # columns as well.
        logits_processor.org_vocab_size = (vocab_size +
                                           lora_config.lora_extra_vocab_size)
        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = logits_processor._get_logits(hidden_states=input_,
                                                  lm_head=linear,
                                                  embedding_bias=None)
            result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)
        logits_processor.org_vocab_size = vocab_size

        # Previously expected_result was computed but never compared. Compare
        # the base-vocabulary columns, where the reference is exact for every
        # input. The extra-vocab columns of the reference use only the first
        # LoRA's embeddings tensor (see NOTE above).
        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result[:, :vocab_size],
                                   expected_result[:, :vocab_size],
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_logits_processor.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=8 * num_loras * 3,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None)[:, :vocab_size]
        expected_result = logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_linear_replicated(dist_init, num_loras, device, stage,
                           bias_enabled) -> None:
    """Check ReplicatedLinearWithLoRA against a per-input reference.

    The LoRA layer's output must equal
    ``linear(x) + x @ lora_a @ lora_b * scaling`` for each input's LoRA.
    After ``reset_lora`` on every slot, it must equal ``linear(x)``.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16,
                             bias_enabled=bias_enabled)

    def create_random_linear_replicated_layer():

        linear = ReplicatedLinear(4096,
                                  4096,
                                  bias=False,
                                  params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        lora_linear = ReplicatedLinearWithLoRA(linear)

        lora_linear.create_lora_weights(max_loras, lora_config)
        # A replicated layer has exactly one slice; LoRA bias buffers exist
        # only when bias_enabled is set.
        assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
            lora_linear.lora_b_stacked) == 1)
        if bias_enabled:
            assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
        else:
            assert lora_linear.lora_bias_stacked is None
        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_replicated_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)

        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       512, lora_config.lora_extra_vocab_size)

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"])
738
@pytest.mark.parametrize("fully_shard", [True, False])
739
@pytest.mark.parametrize("device", DEVICES)
740
@pytest.mark.parametrize("stage", STAGES)
741
@pytest.mark.parametrize("bias_enabled", [True, False])
742
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
743
                         device, stage, bias_enabled) -> None:
744

745
746
747
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

748
    max_loras = 8
749
    torch.set_default_device(device)
750
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
751
    assert check_punica_wrapper(punica_wrapper)
752
753
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
754
                             fully_sharded_loras=fully_shard,
755
756
                             lora_dtype=torch.float16,
                             bias_enabled=bias_enabled)
757
758
759

    def create_random_linear_parallel_layer():
        if orientation == "row":
760
761
762
763
            linear = RowParallelLinear(4096,
                                       4096,
                                       bias=False,
                                       params_dtype=torch.float16)
764
            linear.weight.data = torch.rand_like(linear.weight.data)
765
766
            lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard
                           else RowParallelLinearWithShardedLoRA(linear))
767
        else:
768
769
770
771
            linear = ColumnParallelLinear(4096,
                                          4096,
                                          bias=False,
                                          params_dtype=torch.float16)
772
            linear.weight.data = torch.rand_like(linear.weight.data)
773
774
775
            lora_linear = (ColumnParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           ColumnParallelLinearWithShardedLoRA(linear))
776
        lora_linear.create_lora_weights(max_loras, lora_config)
777
778
779
780
781
782
        assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
            lora_linear.lora_b_stacked) == 1)
        if bias_enabled:
            assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
        else:
            assert lora_linear.lora_bias_stacked is None
783
784
        return linear, lora_linear

785
    for i in range(NUM_RANDOM_SEEDS):
786
787
788
789
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_parallel_layer()
790
        assert torch.equal(linear.weight, lora_linear.weight)
791
        lora_linear.set_mapping(punica_wrapper)
792
793
794
795
796
797
798
799
800
801
802
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
803
            input_type=torch.float16,
804
            device=device)
805
806
807
808
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
809
810
811
812
813
814
815
816
817
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

818
        expected_results: list[torch.Tensor] = []
819
820
821
822
823
824
825
826
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
827
828
829
830
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)
831
832
833
834
835
836
837
838
839
840
841

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
842
            input_type=torch.float16,
843
            device=device)
844
845
846
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
847

848
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
849
850
851
852
853
854
                                       512, lora_config.lora_extra_vocab_size)

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
855
856
857
858
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)
859
860
861
862


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
                                device, stage, bias_enabled) -> None:
    """Check packed column-parallel LoRA wrappers against a reference.

    ``repeats`` selects the layer under test:

    * 2 -> MergedColumnParallelLinear with 2 slices
    * 3 -> QKVParallelLinear wrapped as a merged 3-slice QKV LoRA
    * 1 -> QKVParallelLinear wrapped as a single-slice QKV LoRA

    The wrapper's output is compared to the base output plus each sub-LoRA's
    dense delta added to its own output slice. After ``reset_lora`` on every
    slot, the wrapper must reproduce the base layer.
    """

    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             fully_sharded_loras=fully_shard,
                             lora_dtype=torch.float16,
                             bias_enabled=bias_enabled)

    def create_column_parallel_packed_layer():
        """Build a randomly initialised packed base layer and LoRA wrapper."""
        if repeats == 2:
            linear = MergedColumnParallelLinear(4096, [4096] * repeats,
                                                bias=False,
                                                params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (MergedColumnParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           MergedColumnParallelLinearWithShardedLoRA(linear))
        elif repeats == 3:
            linear = QKVParallelLinear(4096,
                                       64,
                                       32,
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (MergedQKVParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           MergedQKVParallelLinearWithShardedLoRA(linear))
        else:
            linear = QKVParallelLinear(4096,
                                       64,
                                       32,
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = QKVParallelLinearWithLoRA(
                linear
            ) if not fully_shard else QKVParallelLinearWithShardedLoRA(linear)

        # Minimal stand-in for a model config, passed as ``model_config``
        # to create_lora_weights.
        @dataclass
        class FakeConfig:
            hidden_size = 4096
            num_key_value_heads = 32
            num_attention_heads = 32

        n_slices = repeats
        lora_linear.create_lora_weights(max_loras,
                                        lora_config,
                                        model_config=FakeConfig())
        assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
            lora_linear.lora_b_stacked) == n_slices)
        if bias_enabled:
            assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
        else:
            assert lora_linear.lora_bias_stacked is None

        return linear, lora_linear

    for seed_idx in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed_idx)

        id_to_index = get_random_id_to_index(num_loras, max_loras)

        linear, lora_linear = create_column_parallel_packed_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, sublora_dict = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
            repeats=repeats,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            result = linear(input_)[0]
            subloras = sublora_dict[lora_id]
            # Each sub-LoRA updates its own contiguous output slice. The
            # indexing assumes every slice is as wide as that sub-LoRA's
            # lora_b output dimension, with all slices equally wide.
            for slice_idx, sublora in enumerate(subloras):
                width = sublora.lora_b.shape[1]
                result[:, width * slice_idx:width * (slice_idx + 1)] += (
                    input_ @ sublora.lora_a @ sublora.lora_b *
                    sublora.scaling)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        # With every slot reset, the wrapper must reproduce the base layer.
        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize(
    "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS)))
def test_vocab_parallel_embedding_indices(tp_size, seed):
    """Check the per-rank shard indices of VocabParallelEmbedding.

    A random vocab is split into original and added tokens. The embedding is
    built once per simulated TP rank, with the TP rank and world-size lookups
    patched. The test then verifies that:

    * the org and added ranges of consecutive ranks are contiguous;
    * the per-rank sizes sum to the padded, org and added totals;
    * no token id is assigned twice or to both the org and added ranges;
    * ``get_sharded_to_full_mapping`` reorders the concatenated per-rank
      layout (padding slots marked -1) into ``0..vocab_size-1`` followed
      only by padding.
    """
    random.seed(seed)
    vocab_size = random.randint(4000, 64000)
    added_vocab_size = random.randint(0, 1024)
    org_vocab_size = vocab_size - added_vocab_size
    last_org_vocab_end_index = 0
    last_added_vocab_end_index = org_vocab_size
    computed_vocab_size = 0
    computed_org_vocab_size = 0
    computed_added_vocab_size = 0
    vocab_size_padded = -1

    all_org_tokens: list[int] = []
    all_added_tokens: list[int] = []
    # Token id held by every slot of the sharded layout, rank by rank;
    # padding slots are recorded as -1.
    token_ids: list[int] = []

    for tp_rank in range(tp_size):
        with patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank",
                return_value=tp_rank
        ), patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size",
                return_value=tp_size):
            vocab_embedding = VocabParallelEmbedding(
                vocab_size, 1, org_num_embeddings=org_vocab_size)
        vocab_size_padded = vocab_embedding.num_embeddings_padded
        shard_indices = vocab_embedding.shard_indices
        # Assert that the ranges are contiguous
        assert shard_indices.org_vocab_start_index == last_org_vocab_end_index
        assert (shard_indices.added_vocab_start_index ==
                last_added_vocab_end_index)

        # Accumulate per-shard sizes; compared against the totals below.
        computed_vocab_size += shard_indices.num_elements_padded
        computed_org_vocab_size += shard_indices.num_org_elements
        computed_added_vocab_size += shard_indices.num_added_elements

        # Collect token ids to check below that the ranges do not overlap
        all_org_tokens.extend(
            range(shard_indices.org_vocab_start_index,
                  shard_indices.org_vocab_end_index))
        all_added_tokens.extend(
            range(shard_indices.added_vocab_start_index,
                  shard_indices.added_vocab_end_index))

        # This rank's layout: org tokens, org padding, added tokens,
        # added padding.
        token_ids.extend(
            range(shard_indices.org_vocab_start_index,
                  shard_indices.org_vocab_end_index))
        token_ids.extend([-1] * (shard_indices.num_org_elements_padded -
                                 shard_indices.num_org_elements))
        token_ids.extend(
            range(shard_indices.added_vocab_start_index,
                  shard_indices.added_vocab_end_index))
        token_ids.extend([-1] * (shard_indices.num_added_elements_padded -
                                 shard_indices.num_added_elements))

        last_org_vocab_end_index = shard_indices.org_vocab_end_index
        last_added_vocab_end_index = shard_indices.added_vocab_end_index

    assert computed_vocab_size == vocab_size_padded
    assert computed_org_vocab_size == org_vocab_size
    assert computed_added_vocab_size == added_vocab_size

    # Ensure that the ranges are not overlapping
    assert len(all_org_tokens) == len(set(all_org_tokens))
    assert len(all_added_tokens) == len(set(all_added_tokens))
    assert not set(all_org_tokens).intersection(set(all_added_tokens))

    token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
    # `vocab_embedding` is the instance built for the last simulated rank.
    reindex_mapping = vocab_embedding.get_sharded_to_full_mapping()
    assert reindex_mapping is not None or tp_size == 1
    if reindex_mapping is not None:
        reindexed_token_ids = token_ids_tensor[reindex_mapping]
        expected = torch.arange(vocab_size)
        assert reindexed_token_ids[:vocab_size].equal(expected)
        assert torch.all(reindexed_token_ids[vocab_size:] == -1)


def test_get_masked_input_and_mask():
    """Table-driven check of get_masked_input_and_mask.

    ``x`` holds token ids 0..11. In the tp=1 cases, ids 0-7 lie in the
    original-vocab range and ids 8-11 in the added-vocab range. Each case
    lists one rank's ranges for TP sizes 1, 2 and 4, with org-vocab padding
    of 0 and then 2, together with the expected remapped ids. Only the first
    returned value is compared.
    """
    x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

    # (org_start, org_end, added_start, added_end, padding) -> expected ids
    cases = [
        # base tp 1 case, no padding
        ((0, 8, 8, 12, 0), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]),
        # tp 2 case, no padding
        ((0, 4, 8, 10, 0), [0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]),
        ((4, 8, 10, 12, 0), [0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]),
        # tp 4 case, no padding
        ((0, 2, 8, 9, 0), [0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]),
        ((2, 4, 9, 10, 0), [0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]),
        ((4, 6, 10, 11, 0), [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]),
        ((6, 8, 11, 12, 0), [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]),
        # base tp 1 case, with padding
        ((0, 8, 8, 12, 2), [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]),
        # tp 2 case, with padding
        ((0, 4, 8, 10, 2), [0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]),
        ((4, 8, 10, 12, 2), [0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]),
        # tp 4 case, with padding
        ((0, 2, 8, 9, 2), [0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]),
        ((2, 4, 9, 10, 2), [0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]),
        ((4, 6, 10, 11, 2), [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]),
        ((6, 8, 11, 12, 2), [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]),
    ]

    for (org_start, org_end, added_start, added_end,
         padding), expected_ids in cases:
        remapped, _ = get_masked_input_and_mask(
            x,
            org_vocab_start_index=org_start,
            org_vocab_end_index=org_end,
            added_vocab_start_index=added_start,
            added_vocab_end_index=added_end,
            num_org_vocab_padding=padding)
        assert torch.equal(remapped, torch.tensor(expected_ids))