test_layers.py 49.6 KB
Newer Older
1
2
3
import random
from copy import deepcopy
from dataclasses import dataclass
4
from typing import Dict, List, Optional, Tuple
5
from unittest.mock import patch
6

7
import pytest
8
9
10
11
import torch
import torch.nn.functional as F

from vllm.config import LoRAConfig
12
13
14
from vllm.lora.fully_sharded_layers import (
    ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
15
16
    MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
    RowParallelLinearWithShardedLoRA)
17
18
# yapf conflicts with isort for this block
# yapf: disable
19
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
20
                              LinearScalingRotaryEmbeddingWithLora,
21
22
                              LogitsProcessorWithLoRA, LoRAMapping,
                              MergedColumnParallelLinearWithLoRA,
23
                              MergedQKVParallelLinearWithLora,
24
                              QKVParallelLinearWithLora,
25
                              ReplicatedLinearWithLoRA,
26
27
                              RowParallelLinearWithLoRA,
                              VocabParallelEmbeddingWithLoRA)
28
# yapf: enable
29
from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights,
30
31
                              PackedLoRALayerWeights)
from vllm.lora.punica import PunicaWrapper
32
33
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
34
                                               QKVParallelLinear,
35
                                               ReplicatedLinear,
36
37
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
38
from vllm.model_executor.layers.rotary_embedding import get_rope
39
from vllm.model_executor.layers.vocab_parallel_embedding import (
40
    ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
41
from vllm.model_executor.utils import set_random_seed
42
from vllm.platforms import current_platform
43
44
45
46
47
48
49
50

from .utils import DummyLoRAManager

# (rtol, atol) used by torch.testing.assert_close, keyed by the dtype of the
# LoRA layer's output.
TOLERANCES = {
    torch.float16: (5e-3, 5e-3),
    torch.float32: (5e-3, 5e-3),
    torch.bfloat16: (3e-2, 2e-2),
}

# Test on cuda:0 only when exactly one GPU is visible, otherwise on cuda:0 and
# cuda:1.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

# We will launch different triton kernels between the prefill and decode
# stages, so we need to verify this. prefill stage(True) or decode stage(False)
STAGES = [True, False]


def get_random_id_to_index(num_loras: int,
                           num_slots: int,
                           log: bool = True) -> List[Optional[int]]:
    """Build a randomly populated lora_id_to_index slot table.

    LoRA ids ``1..num_loras`` are placed into distinct slots chosen with
    ``torch.randperm``; every remaining slot holds ``None``.

    Args:
        num_loras: The number of active loras in the mapping.
        num_slots: The number of slots in the mapping. Must be at least
            num_loras.
        log: Whether to print the resulting mapping.

    Returns:
        A list of length ``num_slots`` whose entries are a lora id or None.

    Raises:
        ValueError: If ``num_loras`` exceeds ``num_slots``.
    """
    if num_slots < num_loras:
        raise ValueError(
            f"num_loras is higher than num_slots: {num_loras} > {num_slots}. "
            "num_loras must be less than or equal to num_slots.")

    chosen_slots = torch.randperm(num_slots)[:num_loras].tolist()
    # The k-th chosen slot (0-based) receives lora id k + 1.
    id_for_slot = {slot: rank + 1 for rank, slot in enumerate(chosen_slots)}
    table: List[Optional[int]] = [
        id_for_slot.get(slot) for slot in range(num_slots)
    ]

    if log:
        print(f"Created lora_id_to_index mapping: {table}.")

    return table


def populate_loras(
    id_to_index: List[Optional[int]],
    layer: BaseLayerWithLoRA,
    layer_weights: torch.Tensor,
    generate_embeddings_tensor: int = 0,
    repeats: int = 1,
) -> Tuple[Dict[int, LoRALayerWeights], Dict[int, List[LoRALayerWeights]]]:
    """This method populates the lora layers with lora weights.

    Args:
        id_to_index: a list of lora ids. The index of the lora id
            represents which memory slot the lora matrices are
            stored in. A None value indicates a free slot.
        layer: the LoRAlayer to populate.
        layer_weights: the PyTorch tensor containing the layer's
            weights.
        generate_embeddings_tensor: forwarded unchanged to
            ``DummyLoRAManager.init_random_lora``. Callers in this module
            pass 0 (the default) or a hidden size such as 256 / 1024 when
            they need an embeddings tensor. TODO confirm exact semantics.
        repeats: must only be set for column parallel packed
            layers. Indicates the number of loras to compose
            together to create a single lora layer.

    Returns:
        ``(lora_dict, sublora_dict)``. ``lora_dict`` maps each lora id to
        the weights that were loaded into its slot (packed when
        ``repeats > 1``). ``sublora_dict`` maps each lora id to the list of
        ``repeats`` subloras it was built from.
    """
    # Dictionary that maps the lora ID to the
    # corresponding lora weights.
    lora_dict: Dict[int, LoRALayerWeights] = {}

    # Dictionary that maps the lora ID to the
    # corresponding subloras.
    sublora_dict: Dict[int, List[LoRALayerWeights]] = {}

    for slot_idx, lora_id in enumerate(id_to_index):
        if lora_id is None:
            continue  # free slot

        subloras: List[LoRALayerWeights] = []
        sublora_len = layer_weights.shape[0] // repeats
        for i in range(repeats):
            sublora = DummyLoRAManager(
                layer_weights.device).init_random_lora(
                    module_name=f"fake_{i}",
                    weight=layer_weights,
                    generate_embeddings_tensor=generate_embeddings_tensor,
                )
            # Sublora i keeps only columns [i * sublora_len,
            # (i + 1) * sublora_len) of its lora_b.
            sublora.lora_b = sublora.lora_b[:, (sublora_len *
                                                i):(sublora_len * (i + 1))]
            sublora.optimize()
            subloras.append(sublora)

        lora = PackedLoRALayerWeights.pack(
            subloras) if repeats > 1 else subloras[0]

        layer.set_lora(
            slot_idx,
            lora_a=lora.lora_a,
            lora_b=lora.lora_b,
            embeddings_tensor=lora.embeddings_tensor,
        )

        lora_dict[lora_id] = lora
        sublora_dict[lora_id] = subloras

    return lora_dict, sublora_dict


def create_random_inputs(
    active_lora_ids: List[int],
    num_inputs: int,
    input_size: Tuple[int, ...],
    input_range: Tuple[float, float],
    input_type: torch.dtype = torch.int,
    device: torch.device = "cuda"
) -> Tuple[List[torch.Tensor], List[int], List[int]]:
    """Creates random inputs.

    Args:
        active_lora_ids: lora IDs of active lora weights. Each input is
            assigned one of them with ``random.choice``.
        num_inputs: the number of inputs to create.
        input_size: the size of each individual input.
        input_range: the range of values to include in the input.
            input_range[0] <= possible input values < input_range[1]
        input_type: the type of values in the input. ``torch.int`` draws
            integers with ``torch.randint``; any other dtype draws
            uniform floats of that dtype.
        device: the device the input tensors are allocated on.

    Returns:
        ``(inputs, index_mapping, prompt_mapping)``. ``inputs`` is the list
        of ``num_inputs`` tensors. ``prompt_mapping`` holds the lora id
        chosen for each input. ``index_mapping`` repeats that id
        ``input_size[0]`` times per input.
    """

    low, high = input_range

    inputs: List[torch.Tensor] = []
    index_mapping: List[int] = []
    prompt_mapping: List[int] = []

    for _ in range(num_inputs):
        if input_type == torch.int:
            inputs.append(
                torch.randint(low=int(low),
                              high=int(high),
                              size=input_size,
                              device=device))
        else:
            # Scale by the width of the range (not by `high`) so the values
            # land in [low, high) as documented above.
            inputs.append(
                torch.rand(size=input_size, dtype=input_type, device=device) *
                (high - low) + low)

        lora_id = random.choice(active_lora_ids)
        index_mapping += [lora_id] * input_size[0]
        prompt_mapping += [lora_id]

    return inputs, index_mapping, prompt_mapping


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
    """Compare VocabParallelEmbeddingWithLoRA against a reference result.

    For each of 10 seeds, random LoRAs are loaded into random slots. The
    layer's output must match, per prompt, the base embedding plus
    ``F.embedding(input, lora_a) @ lora_b``. Every slot is then reset, and
    the layer must reproduce the base embedding.
    """
    # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
    # device, see: https://github.com/triton-lang/triton/issues/2925
    # Same below.
    torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = PunicaWrapper(8192, 256, device)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def create_random_embedding_layer():
        embedding = VocabParallelEmbedding(vocab_size, 256)
        embedding.weight.data = torch.rand_like(embedding.weight.data)
        # Zero any weight rows beyond vocab_size.
        embedding.weight.data[vocab_size:, :] = 0
        lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
        lora_embedding.create_lora_weights(max_loras, lora_config)

        return embedding, lora_embedding

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        embedding, lora_embedding = create_random_embedding_layer()
        lora_embedding.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=embedding.weight.T,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)

        lora_result = lora_embedding(torch.cat(inputs))

        expected_results: List[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = embedding(input_)
            after_a = F.embedding(
                input_,
                lora.lora_a,
            )
            result += (after_a @ lora.lora_b)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        # Every input now uses lora id 0. After the reset, the output is
        # compared against the base embedding alone.
        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)

        lora_result = lora_embedding(torch.cat(inputs))
        expected_result = embedding(torch.cat(inputs))

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
                                        vocab_size, stage) -> None:
    """Compare VocabParallelEmbeddingWithLoRA when LoRAs add new tokens.

    Each LoRA carries an embeddings tensor for extra-vocab tokens. The
    reference is an "expanded" base embedding whose rows past
    ``vocab_size`` hold every LoRA's embeddings tensor back to back. The
    LoRA layer receives token ids relative to a single LoRA's extra range.
    The expanded reference receives the matching absolute ids.
    """
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = PunicaWrapper(8192, 256, device)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def create_random_embedding_layer():
        embedding = VocabParallelEmbedding(vocab_size, 256)
        embedding_data = torch.rand_like(embedding.weight.data)
        embedding.weight.data = embedding_data
        embedding.weight.data[vocab_size:, :] = 0
        expanded_embedding = VocabParallelEmbedding(
            vocab_size + lora_config.lora_extra_vocab_size * max_loras,
            256,
            org_num_embeddings=vocab_size)
        expanded_embedding.weight.data[:vocab_size, :] = embedding_data
        # We need to deepcopy the embedding as it will be modified
        # in place
        lora_embedding = VocabParallelEmbeddingWithLoRA(
            deepcopy(expanded_embedding))
        lora_embedding.create_lora_weights(max_loras, lora_config)

        return expanded_embedding, lora_embedding

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        expanded_embedding, lora_embedding = create_random_embedding_layer()
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=torch.zeros(
                (256, vocab_size + lora_config.lora_extra_vocab_size)),
            generate_embeddings_tensor=256,
        )

        lora_embedding.set_mapping(punica_wrapper)
        # All embeddings tensors have the same shape.
        embeddings_tensors = [
            lora_dict[lora_id].embeddings_tensor
            for lora_id in sorted(lora_dict.keys())
        ]
        embeddings_tensor_len = embeddings_tensors[0].shape[0]

        # Add empty embeddings_tensors for unoccupied lora slots.
        for _ in range(max_loras - len(embeddings_tensors)):
            embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape))

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)
        original_inputs = deepcopy(inputs)

        # Force some of the inputs to be in the extended embeddings range
        # to guarantee that their behavior is tested.
        # `inputs` get absolute ids into the expanded reference table.
        # `original_inputs` get ids relative to one LoRA's extra range.
        for input_, original_input_, lora_id in zip(inputs, original_inputs,
                                                    prompt_mapping):
            embedding_id = lora_id - 1
            input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
            original_input_[-1] = vocab_size
            input_[-2] = vocab_size + (
                (embedding_id + 1) * embeddings_tensor_len - 1)
            original_input_[-2] = vocab_size + embeddings_tensor_len - 1

        expanded_embedding.weight[vocab_size:vocab_size +
                                  (embeddings_tensor_len *
                                   max_loras)] = torch.cat(embeddings_tensors)

        lora_result = lora_embedding(torch.cat(original_inputs))

        expected_results: List[torch.Tensor] = []
        for input_, original_input_, lora_id in zip(inputs, original_inputs,
                                                    prompt_mapping):
            lora = lora_dict[lora_id]
            result = expanded_embedding(input_)
            after_a = F.embedding(
                original_input_,
                lora.lora_a,
            )
            result += (after_a @ lora.lora_b)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        original_inputs = deepcopy(inputs)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)
        lora_result = lora_embedding(torch.cat(original_inputs))
        expected_result = expanded_embedding(torch.cat(inputs))

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])
@pytest.mark.parametrize("stage", STAGES)
def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
                                  stage) -> None:
    """Compare LogitsProcessorWithLoRA._get_logits against a reference.

    The reference writes the LoRA embeddings tensor into the extra-vocab
    rows of the LM head. It runs the base LogitsProcessor per prompt,
    masks columns past the new tokens with -inf, and adds the low-rank
    update. After every slot is reset, the LoRA logits (truncated to
    vocab_size) must match the base processor on the original LM head.
    """
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = PunicaWrapper(8192, 256, device)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def _pretest():
        linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
                                1024,
                                vocab_size,
                                params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        # NOTE(review): this zeroes *columns* >= vocab_size of the weight,
        # which is a no-op whenever vocab_size >= 1024. The embedding tests
        # zero rows instead. Confirm which is intended.
        linear.weight.data[:, vocab_size:] = 0
        logits_processor = LogitsProcessor(
            vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
        lora_logits_processor = LogitsProcessorWithLoRA(
            logits_processor, 1024, linear.weight.dtype, linear.weight.device,
            None)
        lora_logits_processor.create_lora_weights(max_loras, lora_config)

        return linear, logits_processor, lora_logits_processor

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, logits_processor, lora_logits_processor = _pretest()
        lora_logits_processor.set_mapping(punica_wrapper)
        # NOTE: all the generated loras share the same embeddings tensor.
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_logits_processor,
            layer_weights=linear.weight,
            generate_embeddings_tensor=1024,
        )
        embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor
        embeddings_tensor_len = embeddings_tensor.shape[0]

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=8 * num_loras,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=linear,
            embedding_bias=None)

        # Keep an unmodified copy of the LM head for the reset check below.
        original_lm_head = deepcopy(linear)

        linear.weight[logits_processor.
                      org_vocab_size:logits_processor.org_vocab_size +
                      embeddings_tensor_len] = embeddings_tensor

        # Temporarily widen the base processor so its logits include the
        # extra-vocab columns. This is restored right after the reference
        # is built.
        logits_processor.org_vocab_size = (vocab_size +
                                           lora_config.lora_extra_vocab_size)
        expected_results: List[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = logits_processor._get_logits(hidden_states=input_,
                                                  lm_head=linear,
                                                  embedding_bias=None)
            result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)
        logits_processor.org_vocab_size = vocab_size

        # Previously expected_result was built but never compared.
        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_logits_processor.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=8 * num_loras * 3,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None)[:, :vocab_size]
        expected_result = logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
    """Compare ReplicatedLinearWithLoRA against a reference result.

    Per prompt, the LoRA layer must equal
    ``linear(x) + x @ lora_a @ lora_b * scaling``. After every slot is
    reset, it must equal the base ``linear(x)``.
    """
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    punica_wrapper = PunicaWrapper(8192, 256, device)
    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def create_random_linear_replicated_layer():

        linear = ReplicatedLinear(4096,
                                  4096,
                                  bias=False,
                                  params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        lora_linear = ReplicatedLinearWithLoRA(linear)

        lora_linear.create_lora_weights(max_loras, lora_config)

        return linear, lora_linear

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_replicated_layer()
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        # 512 is passed as the vocab size. Presumably it only matters for
        # embedding/logits layers. TODO confirm.
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: List[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)

        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       512, lora_config.lora_extra_vocab_size)

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
                         device, stage) -> None:
    """Compare row/column parallel linear LoRA layers against a reference.

    Covers both the regular and the fully-sharded LoRA variants. Per
    prompt, the LoRA layer must equal
    ``linear(x) + x @ lora_a @ lora_b * scaling``. After every slot is
    reset, it must equal the base ``linear(x)``.
    """
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    punica_wrapper = PunicaWrapper(8192, 256, device)
    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             fully_sharded_loras=fully_shard,
                             lora_dtype=torch.float16)

    def create_random_linear_parallel_layer():
        if orientation == "row":
            linear = RowParallelLinear(4096,
                                       4096,
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard
                           else RowParallelLinearWithShardedLoRA(linear))
        else:
            linear = ColumnParallelLinear(4096,
                                          4096,
                                          bias=False,
                                          params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (ColumnParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           ColumnParallelLinearWithShardedLoRA(linear))
        lora_linear.create_lora_weights(max_loras, lora_config)

        return linear, lora_linear

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_parallel_layer()
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        # 512 is passed as the vocab size. Presumably it only matters for
        # embedding/logits layers. TODO confirm.
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: List[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)

        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       512, lora_config.lora_extra_vocab_size)

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
                                device, stage) -> None:
    """Compare packed column-parallel LoRA layers with a reference result.

    ``repeats`` picks the base layer and LoRA wrapper under test:

    * ``2``: ``MergedColumnParallelLinear`` (two 4096-wide outputs) wrapped in
      ``MergedColumnParallelLinearWith[Sharded]LoRA``;
    * ``3``: ``QKVParallelLinear`` wrapped in
      ``MergedQKVParallelLinearWith[Sharded]Lora``;
    * ``1``: ``QKVParallelLinear`` wrapped in
      ``QKVParallelLinearWith[Sharded]Lora``.

    ``fully_shard`` selects the ``Sharded`` variant of each wrapper. For each
    of ten seeds the test populates random LoRAs and checks the wrapped layer
    against "base output + x @ A @ B * scaling" computed per sub-LoRA. It then
    resets every LoRA slot and checks the wrapped layer reproduces the base
    layer's output.
    """
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    punica_wrapper = PunicaWrapper(8192, 256, device)
    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             fully_sharded_loras=fully_shard,
                             lora_dtype=torch.float16)

    def _make_qkv_linear():
        # Base layer shared by the repeats == 3 and repeats == 1 cases; its
        # weights are replaced with uniform random values.
        linear = QKVParallelLinear(4096,
                                   64,
                                   32,
                                   bias=False,
                                   params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        return linear

    def create_column_parallel_packed_layer():
        # Build the base layer and the LoRA wrapper selected by ``repeats`` /
        # ``fully_shard``, then allocate LoRA weights for ``max_loras`` slots.
        if repeats == 2:
            linear = MergedColumnParallelLinear(4096, [4096] * repeats,
                                                bias=False,
                                                params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (MergedColumnParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           MergedColumnParallelLinearWithShardedLoRA(linear))
        elif repeats == 3:
            linear = _make_qkv_linear()
            lora_linear = (MergedQKVParallelLinearWithLora(linear)
                           if not fully_shard else
                           MergedQKVParallelLinearWithShardedLora(linear))
        else:
            linear = _make_qkv_linear()
            lora_linear = (QKVParallelLinearWithLora(linear)
                           if not fully_shard else
                           QKVParallelLinearWithShardedLora(linear))

        @dataclass
        class FakeConfig:
            # Minimal stand-in for the model config passed to
            # ``create_lora_weights``; only these three attributes are set.
            hidden_size = 4096
            num_key_value_heads = 32
            num_attention_heads = 32

        lora_linear.create_lora_weights(max_loras,
                                        lora_config,
                                        model_config=FakeConfig())

        return linear, lora_linear

    def _prepare_batch(active_lora_ids, id_to_index):
        # Draw random fp16 inputs routed to ``active_lora_ids``, push the
        # resulting LoRA mapping into the shared punica wrapper, and return
        # the per-request inputs plus their LoRA ids.
        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=active_lora_ids,
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )
        return inputs, prompt_mapping

    for seed in range(10):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)

        linear, lora_linear = create_column_parallel_packed_layer()
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, sublora_dict = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
            repeats=repeats,
        )

        inputs, prompt_mapping = _prepare_batch(list(lora_dict.keys()),
                                                id_to_index)
        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: List[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            result = linear(input_)[0]
            # Each sub-LoRA adds its delta to its own contiguous slice of the
            # packed output, slices laid out in sub-LoRA order.
            col_start = 0
            for sublora in sublora_dict[lora_id]:
                col_end = col_start + sublora.lora_b.shape[1]
                result[:, col_start:col_end] += (
                    input_ @ sublora.lora_a @ sublora.lora_b * sublora.scaling)
                col_start = col_end
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # After resetting every slot the wrapped layer must reproduce the
        # base layer's output.
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, _ = _prepare_batch([0], id_to_index)
        batch = torch.cat(inputs)
        lora_result = lora_linear(batch)[0]
        expected_result = linear(batch)[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 8])
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("scaling_factors", [(1.0, ), (4.0, ), (4.0, 8.0),
                                             (6.0, 1.0)])
@pytest.mark.parametrize("max_position", [11, 4096, 32768])
@pytest.mark.parametrize("is_neox_style", [True, False])
@pytest.mark.parametrize("rotary_dim", [None, 32])
@pytest.mark.parametrize("head_size", [32, 108])
@pytest.mark.parametrize("seq_len", [11, 1024])
def test_rotary_embedding_long_context(dist_init, num_loras, device,
                                       scaling_factors, max_position,
                                       is_neox_style, rotary_dim, head_size,
                                       seq_len) -> None:
    """Compare the long-context LoRA rotary embedding with linear scaling.

    A ``LinearScalingRotaryEmbeddingWithLora`` configured with
    ``scaling_factors`` is run on random fp16 query/key tensors. Its output
    is compared with a plain ``get_rope`` built with
    ``{"rope_type": "linear", "factor": scaling_factors}``.

    The test also checks ``scaling_factor_to_offset``: the entries must start
    at 0, and each must advance by ``scaling_factor * max_position`` over the
    previous one.
    """
    dtype = torch.float16
    seed = 0
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    punica_wrapper = PunicaWrapper(8192, 256, device)
    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             long_lora_scaling_factors=scaling_factors,
                             lora_dtype=dtype)

    if rotary_dim is None:
        rotary_dim = head_size
    base = 10000
    batch_size = 5 * num_loras
    num_heads = 7

    # Verify lora is equivalent to linear scaling rotary embedding.
    rope = get_rope(
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
    )
    lora_rope = LinearScalingRotaryEmbeddingWithLora(rope)
    lora_rope.set_mapping(punica_wrapper)
    lora_rope.create_lora_weights(max_loras, lora_config)
    linear_rope = get_rope(head_size, rotary_dim, max_position, base,
                           is_neox_style, {
                               "rope_type": "linear",
                               "factor": scaling_factors
                           })
    linear_rope = linear_rope.to(dtype=dtype)
    id_to_index = get_random_id_to_index(num_loras, max_loras)
    _, index_mapping, prompt_mapping = create_random_inputs(
        active_lora_ids=[0],
        num_inputs=batch_size,
        input_size=(1, max_position),
        input_range=(0, lora_config.lora_extra_vocab_size),
        input_type=torch.float16,
        device=device)

    lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
    long_lora_context = LongContextLoRAContext(list(scaling_factors),
                                               rotary_dim)

    next_expected_offset = 0
    # Make sure the offset is correct.
    scaling_factor_to_offset = lora_rope.scaling_factor_to_offset
    for scaling_factor, offset in scaling_factor_to_offset.items():
        assert offset == next_expected_offset
        next_expected_offset += scaling_factor * max_position

    # LoRA id k uses the offset registered for scaling_factors[k]
    # (0 if that factor has no entry).
    for lora_id, factor in enumerate(scaling_factors):
        long_lora_context.offsets_by_lora_id[lora_id] = (
            scaling_factor_to_offset.get(factor, 0))
    punica_wrapper.update_metadata(
        lora_mapping,
        id_to_index,
        max_loras,
        512,
        lora_config.lora_extra_vocab_size,
        long_lora_context=long_lora_context,
    )

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)
    # Give each rope its own copies of query/key. If either implementation
    # rotates its inputs in place (NOTE(review): presumably the CUDA kernel
    # path does; confirm), sharing the tensors would make both results alias
    # the same storage and the comparison below would be vacuous.
    ref_q, ref_k = linear_rope(positions, query.clone(), key.clone())
    actual_q, actual_k = lora_rope(positions, query.clone(), key.clone())

    # torch.allclose only returns a bool; assert_close actually fails the
    # test on a mismatch.
    rtol, atol = TOLERANCES[dtype]
    torch.testing.assert_close(actual_q, ref_q, rtol=rtol, atol=atol)
    torch.testing.assert_close(actual_k, ref_k, rtol=rtol, atol=atol)


@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize("seed", list(range(256)))
def test_vocab_parallel_embedding_indices(tp_size, seed):
    """Check VocabParallelEmbedding's per-rank shard indices for consistency.

    A random vocabulary is sharded across ``tp_size`` simulated ranks: 4000 to
    64000 ids, of which 0 to 1024 trailing ids are "added" vocab. Each rank is
    simulated by patching the tensor-parallel rank / world-size getters. The
    test asserts that:

    * each rank's org and added ranges start where the previous rank's ended;
    * per-rank element counts sum to the padded, org and added sizes;
    * no token id is owned twice, nor by both the org and added ranges;
    * ``get_sharded_to_full_mapping`` reorders the concatenated shard layout
      (``-1`` in padding slots) into ``0..vocab_size-1`` followed only by
      ``-1`` padding.
    """
    random.seed(seed)
    vocab_size = random.randint(4000, 64000)
    added_vocab_size = random.randint(0, 1024)
    org_vocab_size = vocab_size - added_vocab_size
    last_org_vocab_end_index = 0
    last_added_vocab_end_index = org_vocab_size
    computed_vocab_size = 0
    computed_org_vocab_size = 0
    computed_added_vocab_size = 0
    vocab_size_padded = -1

    all_org_tokens: List[int] = []
    all_added_tokens: List[int] = []
    # Concatenation of every rank's shard layout: org ids, org padding (-1),
    # added ids, added padding (-1).
    token_ids: List[int] = []

    for tp_rank in range(tp_size):
        with patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank",
                return_value=tp_rank
        ), patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size",
                return_value=tp_size):
            vocab_embedding = VocabParallelEmbedding(
                vocab_size, 1, org_num_embeddings=org_vocab_size)
        vocab_size_padded = vocab_embedding.num_embeddings_padded
        shard_indices = vocab_embedding.shard_indices
        # Assert that the ranges are contiguous
        assert shard_indices.org_vocab_start_index == last_org_vocab_end_index
        assert (shard_indices.added_vocab_start_index ==
                last_added_vocab_end_index)

        # Ensure that we are not exceeding the vocab size
        computed_vocab_size += shard_indices.num_elements_padded
        computed_org_vocab_size += shard_indices.num_org_elements
        computed_added_vocab_size += shard_indices.num_added_elements

        org_range = range(shard_indices.org_vocab_start_index,
                          shard_indices.org_vocab_end_index)
        added_range = range(shard_indices.added_vocab_start_index,
                            shard_indices.added_vocab_end_index)

        # Ensure that the ranges are not overlapping
        all_org_tokens.extend(org_range)
        all_added_tokens.extend(added_range)

        token_ids.extend(org_range)
        token_ids.extend([-1] * (shard_indices.num_org_elements_padded -
                                 shard_indices.num_org_elements))
        token_ids.extend(added_range)
        token_ids.extend([-1] * (shard_indices.num_added_elements_padded -
                                 shard_indices.num_added_elements))

        last_org_vocab_end_index = shard_indices.org_vocab_end_index
        last_added_vocab_end_index = shard_indices.added_vocab_end_index

    assert computed_vocab_size == vocab_size_padded
    assert computed_org_vocab_size == org_vocab_size
    assert computed_added_vocab_size == added_vocab_size

    # Ensure that the ranges are not overlapping
    assert len(all_org_tokens) == len(set(all_org_tokens))
    assert len(all_added_tokens) == len(set(all_added_tokens))
    assert not set(all_org_tokens).intersection(set(all_added_tokens))

    token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
    # ``vocab_embedding`` here is the layer built for the last simulated rank.
    reindex_mapping = vocab_embedding.get_sharded_to_full_mapping()
    assert reindex_mapping is not None or tp_size == 1
    if reindex_mapping is not None:
        reindexed_token_ids = token_ids_tensor[reindex_mapping]
        expected = torch.tensor(list(range(0, vocab_size)))
        assert reindexed_token_ids[:vocab_size].equal(expected)
        assert torch.all(reindexed_token_ids[vocab_size:] == -1)


def test_get_masked_input_and_mask():
    """Check ``get_masked_input_and_mask`` on hand-worked vocab shardings.

    The input holds ids 0..11. Every sharding below treats ids 0-7 as the
    original vocab and ids 8-11 as the added vocab. It is exercised without
    padding and with ``num_org_vocab_padding=2``, for 1, 2 and 4 ranks. Each
    rank's call must return exactly the listed ids.
    """
    x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

    # Each scenario: (num_org_vocab_padding, per-rank shards). A shard is
    # (org_vocab_start, org_vocab_end, added_vocab_start, added_vocab_end,
    #  expected output ids for ``x``).
    scenarios = [
        # base tp 1 case, no padding
        (0, [(0, 8, 8, 12, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])]),
        # tp 2 case, no padding
        (0, [(0, 4, 8, 10, [0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]),
             (4, 8, 10, 12, [0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5])]),
        # tp 4 case, no padding
        (0, [(0, 2, 8, 9, [0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]),
             (2, 4, 9, 10, [0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]),
             (4, 6, 10, 11, [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]),
             (6, 8, 11, 12, [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2])]),
        # base tp 1 case, with padding
        (2, [(0, 8, 8, 12, [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13])]),
        # tp 2 case, with padding
        (2, [(0, 4, 8, 10, [0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]),
             (4, 8, 10, 12, [0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7])]),
        # tp 4 case, with padding
        (2, [(0, 2, 8, 9, [0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]),
             (2, 4, 9, 10, [0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]),
             (4, 6, 10, 11, [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]),
             (6, 8, 11, 12, [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4])]),
    ]

    for padding, shards in scenarios:
        for org_lo, org_hi, added_lo, added_hi, want in shards:
            masked, _ = get_masked_input_and_mask(
                x,
                org_vocab_start_index=org_lo,
                org_vocab_end_index=org_hi,
                added_vocab_start_index=added_lo,
                added_vocab_end_index=added_hi,
                num_org_vocab_padding=padding)
            assert torch.equal(masked, torch.tensor(want))