test_layers.py 52.2 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
import random
from copy import deepcopy
from dataclasses import dataclass
6
from typing import Dict, List, Optional, Tuple
7
from unittest.mock import patch
8

9
import pytest
10
11
12
13
import torch
import torch.nn.functional as F

from vllm.config import LoRAConfig
14
15
16
from vllm.lora.fully_sharded_layers import (
    ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
17
18
    MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
    RowParallelLinearWithShardedLoRA)
19
20
# yapf conflicts with isort for this block
# yapf: disable
21
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
22
                              LinearScalingRotaryEmbeddingWithLora,
23
24
                              LogitsProcessorWithLoRA, LoRAMapping,
                              MergedColumnParallelLinearWithLoRA,
25
                              MergedQKVParallelLinearWithLora,
26
                              QKVParallelLinearWithLora,
27
                              ReplicatedLinearWithLoRA,
28
29
                              RowParallelLinearWithLoRA,
                              VocabParallelEmbeddingWithLoRA)
30
# yapf: enable
31
from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights,
32
                              PackedLoRALayerWeights)
33
from vllm.lora.punica_wrapper import get_punica_wrapper
34
35
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
36
                                               QKVParallelLinear,
37
                                               ReplicatedLinear,
38
39
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
40
from vllm.model_executor.layers.rotary_embedding import get_rope
41
from vllm.model_executor.layers.vocab_parallel_embedding import (
42
    ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
43
from vllm.model_executor.utils import set_random_seed
44
from vllm.platforms import current_platform
45
46
47
48
49
50
51
52

from .utils import DummyLoRAManager

# Per-dtype (rtol, atol) used with torch.testing.assert_close below.
TOLERANCES = {
    torch.float16: (5e-3, 5e-3),
    torch.float32: (5e-3, 5e-3),
    torch.bfloat16: (3e-2, 2e-2),
}

pytestmark = pytest.mark.skipif(
    not (current_platform.is_cuda_alike() or current_platform.is_cpu()),
    reason="Backend not supported")

# One CUDA device when only one is visible, otherwise the first two;
# "cpu" on non-CUDA-alike platforms.
DEVICES = ([
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] if current_platform.is_cuda_alike() else ["cpu"])

# For GPU, different triton kernels are launched for the prefill and decode
# stages, so both need to be verified: prefill stage (True) or decode
# stage (False).
STAGES = [True, False]


def get_random_id_to_index(num_loras: int,
                           num_slots: int,
                           log: bool = True) -> List[Optional[int]]:
    """Create a random lora_id_to_index mapping.

    LoRA ids ``1..num_loras`` are each placed in a distinct, randomly chosen
    slot; all remaining slots are left as ``None``.

    Args:
        num_loras: The number of active loras in the mapping. Must be
            non-negative.
        num_slots: The number of slots in the mapping. Must be greater than
            or equal to num_loras.
        log: Whether to print the resulting mapping.

    Returns:
        A list of length ``num_slots`` where entry ``i`` is the lora id stored
        in slot ``i``, or ``None`` if that slot is free.

    Raises:
        ValueError: If ``num_loras`` is negative or exceeds ``num_slots``.
    """

    if num_loras < 0:
        # A negative count would otherwise slice randperm from the end and
        # silently fill num_slots - |num_loras| slots.
        raise ValueError(f"num_loras must be non-negative, got {num_loras}.")
    if num_loras > num_slots:
        raise ValueError(
            f"num_loras is higher than num_slots: {num_loras} > {num_slots}. "
            "num_loras must be less than or equal to num_slots.")

    slots: List[Optional[int]] = [None] * num_slots
    random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist()
    for lora_id, slot_idx in enumerate(random_slot_selections, start=1):
        slots[slot_idx] = lora_id

    if log:
        print(f"Created lora_id_to_index mapping: {slots}.")

    return slots


def populate_loras(
    id_to_index: List[Optional[int]],
    layer: BaseLayerWithLoRA,
    layer_weights: torch.Tensor,
    generate_embeddings_tensor: int = 0,
    repeats: int = 1,
) -> Tuple[Dict[int, LoRALayerWeights], Dict[int, List[LoRALayerWeights]]]:
    """Populate the lora layer with random lora weights.

    Args:
        id_to_index: a list of lora ids. The index of the lora id
            represents which memory slot the lora matrices are
            stored in. A None value indicates a free slot.
        layer: the LoRA layer to populate.
        layer_weights: the PyTorch tensor containing the layer's
            weights.
        generate_embeddings_tensor: size forwarded as-is to
            ``DummyLoRAManager.init_random_lora`` (callers in this file pass
            0, 256 or 1024); presumably 0 means "no embeddings tensor" --
            confirm against DummyLoRAManager.
        repeats: must only be set for column parallel packed
            layers. Indicates the number of loras to compose
            together to create a single lora layer.

    Returns:
        ``(lora_dict, sublora_dict)``: ``lora_dict`` maps each lora id to the
        weights loaded into ``layer`` (packed when ``repeats > 1``), and
        ``sublora_dict`` maps each lora id to the subloras it was built from.
    """

    # lora ID -> weights that were set on the layer.
    lora_dict: Dict[int, LoRALayerWeights] = dict()

    # lora ID -> the individual subloras making up that lora.
    sublora_dict: Dict[int, List[LoRALayerWeights]] = dict()

    for slot_idx, lora_id in enumerate(id_to_index):
        if lora_id is None:
            continue

        subloras: List[LoRALayerWeights] = []
        sublora_len = layer_weights.shape[0] // repeats
        for i in range(repeats):
            sublora = DummyLoRAManager(layer_weights.device).init_random_lora(
                module_name=f"fake_{i}",
                weight=layer_weights,
                generate_embeddings_tensor=generate_embeddings_tensor,
            )
            # Keep only this sublora's slice of lora_b's columns so that
            # packing the subloras back together spans the full output.
            sublora.lora_b = sublora.lora_b[:, (sublora_len *
                                                i):(sublora_len * (i + 1))]
            sublora.optimize()
            subloras.append(sublora)

        lora = PackedLoRALayerWeights.pack(
            subloras) if repeats > 1 else subloras[0]

        layer.set_lora(
            slot_idx,
            lora_a=lora.lora_a,
            lora_b=lora.lora_b,
            embeddings_tensor=lora.embeddings_tensor,
        )

        lora_dict[lora_id] = lora
        sublora_dict[lora_id] = subloras

    return lora_dict, sublora_dict


def create_random_inputs(
    active_lora_ids: List[int],
    num_inputs: int,
    input_size: Tuple[int, ...],
    input_range: Tuple[float, float],
    input_type: torch.dtype = torch.int,
    device: torch.device = "cuda"
) -> Tuple[List[torch.Tensor], List[int], List[int]]:
    """Create random inputs, each assigned to a random active lora.

    Args:
        active_lora_ids: lora IDs of active lora weights.
        num_inputs: the number of inputs to create.
        input_size: the size of each individual input.
        input_range: the range of values to include in the input.
            input_range[0] <= possible input values < input_range[1]
        input_type: the type of values in the input. ``torch.int`` selects
            integer inputs drawn with ``torch.randint`` (which uses its
            default integer dtype); any other dtype is passed to
            ``torch.rand``.
        device: device the inputs are created on.

    Returns:
        ``(inputs, index_mapping, prompt_mapping)`` where ``prompt_mapping``
        holds one lora id per input and ``index_mapping`` repeats that id
        ``input_size[0]`` times per input.
    """

    low, high = input_range

    inputs: List[torch.Tensor] = []
    index_mapping: List[int] = []
    prompt_mapping: List[int] = []

    for _ in range(num_inputs):
        if input_type == torch.int:
            inputs.append(
                torch.randint(low=int(low),
                              high=int(high),
                              size=input_size,
                              device=device))
        else:
            # Scale by the width of the range so values land in [low, high).
            # The previous ``* high + low`` produced [low, low + high).
            inputs.append(
                torch.rand(size=input_size, dtype=input_type, device=device) *
                (high - low) + low)

        lora_id = random.choice(active_lora_ids)
        index_mapping += [lora_id] * input_size[0]
        prompt_mapping += [lora_id]

    return inputs, index_mapping, prompt_mapping


202
203
204
205
206
def check_punica_wrapper(punica_wrapper) -> bool:
    """Return True if ``punica_wrapper`` has exactly the expected class.

    On CUDA-alike platforms the expected class is ``PunicaWrapperGPU``; on
    CPU it is ``PunicaWrapperCPU``; on any other platform this returns False.
    The comparison is an exact ``type(...) is`` check, so subclasses do not
    pass.
    """
    if current_platform.is_cuda_alike():
        from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
        expected_cls = PunicaWrapperGPU
    elif current_platform.is_cpu():
        from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU
        expected_cls = PunicaWrapperCPU
    else:
        return False

    return type(punica_wrapper) is expected_cls


215
216
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
    """Compare VocabParallelEmbeddingWithLoRA with a reference computation.

    For each of 10 seeds, random loras are loaded into random slots and the
    LoRA layer's output is checked against
    ``embedding(x) + F.embedding(x, lora_a) @ lora_b`` computed per input.
    After every slot is reset, the output must match the base embedding.
    """
    # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
    # device, see: https://github.com/triton-lang/triton/issues/2925
    # Same below.
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def create_random_embedding_layer():
        embedding = VocabParallelEmbedding(vocab_size, 256)
        embedding.weight.data = torch.rand_like(embedding.weight.data)
        # Zero any rows past vocab_size (assumed to be padding rows).
        embedding.weight.data[vocab_size:, :] = 0
        lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
        lora_embedding.create_lora_weights(max_loras, lora_config)

        return embedding, lora_embedding

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        embedding, lora_embedding = create_random_embedding_layer()
        lora_embedding.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=embedding.weight.T,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)

        lora_result = lora_embedding(torch.cat(inputs))

        expected_results: List[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = embedding(input_)
            after_a = F.embedding(
                input_,
                lora.lora_a,
            )
            result += (after_a @ lora.lora_b)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)

        lora_result = lora_embedding(torch.cat(inputs))
        expected_result = embedding(torch.cat(inputs))

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)
315
316
317


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
                                        vocab_size, stage) -> None:
    """Compare VocabParallelEmbeddingWithLoRA when loras add extra tokens.

    Each lora carries an embeddings tensor. The reference is an expanded
    embedding whose rows past ``vocab_size`` hold every lora's embeddings
    tensor back to back. Some inputs are forced into the extra-token range
    so that path is always exercised. After every slot is reset, the LoRA
    layer must match the expanded embedding.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def create_random_embedding_layer():
        embedding = VocabParallelEmbedding(vocab_size, 256)
        embedding_data = torch.rand_like(embedding.weight.data)
        embedding.weight.data = embedding_data
        embedding.weight.data[vocab_size:, :] = 0
        expanded_embedding = VocabParallelEmbedding(
            vocab_size + lora_config.lora_extra_vocab_size * max_loras,
            256,
            org_num_embeddings=vocab_size)
        expanded_embedding.weight.data[:vocab_size, :] = embedding_data
        # We need to deepcopy the embedding as it will be modified
        # in place
        lora_embedding = VocabParallelEmbeddingWithLoRA(
            deepcopy(expanded_embedding))
        lora_embedding.create_lora_weights(max_loras, lora_config)

        return expanded_embedding, lora_embedding

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        expanded_embedding, lora_embedding = create_random_embedding_layer()
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=torch.zeros(
                (256, vocab_size + lora_config.lora_extra_vocab_size)),
            generate_embeddings_tensor=256,
        )

        lora_embedding.set_mapping(punica_wrapper)
        # All embeddings tensors have the same shape.
        embeddings_tensors = [
            lora_dict[lora_id].embeddings_tensor
            for lora_id in sorted(lora_dict.keys())
        ]
        embeddings_tensor_len = embeddings_tensors[0].shape[0]

        # Add empty embeddings_tensors for unoccupied lora slots.
        for _ in range(max_loras - len(embeddings_tensors)):
            embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape))

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)
        original_inputs = deepcopy(inputs)

        # Force some of the inputs to be in the extended embeddings range
        # to guarantee that their behavior is tested. ``inputs`` index the
        # expanded (reference) embedding, where lora ``k``'s extra tokens
        # start at vocab_size + (k - 1) * embeddings_tensor_len;
        # ``original_inputs`` use the per-lora token ids fed to the LoRA layer.
        for input_, original_input_, lora_id in zip(inputs, original_inputs,
                                                    prompt_mapping):
            embedding_id = lora_id - 1
            input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
            original_input_[-1] = vocab_size
            input_[-2] = vocab_size + (
                (embedding_id + 1) * embeddings_tensor_len - 1)
            original_input_[-2] = vocab_size + embeddings_tensor_len - 1

        expanded_embedding.weight[vocab_size:vocab_size +
                                  (embeddings_tensor_len *
                                   max_loras)] = torch.cat(embeddings_tensors)

        lora_result = lora_embedding(torch.cat(original_inputs))

        expected_results: List[torch.Tensor] = []
        for input_, original_input_, lora_id in zip(inputs, original_inputs,
                                                    prompt_mapping):
            lora = lora_dict[lora_id]
            result = expanded_embedding(input_)
            after_a = F.embedding(
                original_input_,
                lora.lora_a,
            )
            result += (after_a @ lora.lora_b)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        original_inputs = deepcopy(inputs)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)
        lora_result = lora_embedding(torch.cat(original_inputs))
        expected_result = expanded_embedding(torch.cat(inputs))

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)
456
457
458
459


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])
@pytest.mark.parametrize("stage", STAGES)
def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
                                  stage) -> None:
    """Compare LogitsProcessorWithLoRA with a reference logits computation.

    The reference writes the loras' shared embeddings tensor into the LM
    head's extra-vocab rows, masks logits past the added tokens with -inf,
    and adds ``x @ lora_a @ lora_b * scaling``. After every slot is reset,
    the LoRA processor's first ``vocab_size`` logits must match the base
    processor.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def _pretest():
        linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
                                1024,
                                vocab_size,
                                params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        linear.weight.data[:, vocab_size:] = 0
        logits_processor = LogitsProcessor(
            vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
        lora_logits_processor = LogitsProcessorWithLoRA(
            logits_processor, 1024, linear.weight.dtype, linear.weight.device,
            None)
        lora_logits_processor.create_lora_weights(max_loras, lora_config)

        return linear, logits_processor, lora_logits_processor

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, logits_processor, lora_logits_processor = _pretest()
        lora_logits_processor.set_mapping(punica_wrapper)
        # NOTE: all the generated loras share the same embeddings tensor.
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_logits_processor,
            layer_weights=linear.weight,
            generate_embeddings_tensor=1024,
        )
        embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor
        embeddings_tensor_len = embeddings_tensor.shape[0]

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=8 * num_loras,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=linear,
            embedding_bias=None)

        original_lm_head = deepcopy(linear)

        # Build the reference: copy the embeddings tensor into the LM head
        # rows right after org_vocab_size, then widen org_vocab_size
        # (presumably so the base processor keeps those extra logits --
        # confirm in LogitsProcessor) and restore it afterwards.
        linear.weight[logits_processor.
                      org_vocab_size:logits_processor.org_vocab_size +
                      embeddings_tensor_len] = embeddings_tensor

        logits_processor.org_vocab_size = (vocab_size +
                                           lora_config.lora_extra_vocab_size)
        expected_results: List[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = logits_processor._get_logits(hidden_states=input_,
                                                  lm_head=linear,
                                                  embedding_bias=None)
            result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)
        logits_processor.org_vocab_size = vocab_size

        # Previously expected_result was computed but never compared, so the
        # LoRA logits path went unchecked.
        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_logits_processor.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=8 * num_loras * 3,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None)[:, :vocab_size]
        expected_result = logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)
590
591


592
593
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_linear_replicated(dist_init, num_loras, device, stage,
                           bias_enabled) -> None:
    """Compare ReplicatedLinearWithLoRA with a reference computation.

    The LoRA layer's output must equal ``linear(x) + x @ lora_a @ lora_b *
    scaling`` per input. After every slot is reset, it must equal the base
    layer's output.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device)
    assert check_punica_wrapper(punica_wrapper)
    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16,
                             bias_enabled=bias_enabled)
    # Vocab size passed to update_metadata; a linear layer has no vocabulary,
    # so this is assumed to be an arbitrary placeholder.
    dummy_vocab_size = 512

    def create_random_linear_replicated_layer():

        linear = ReplicatedLinear(4096,
                                  4096,
                                  bias=False,
                                  params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        lora_linear = ReplicatedLinearWithLoRA(linear)

        lora_linear.create_lora_weights(max_loras, lora_config)
        assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
            lora_linear.lora_b_stacked) == 1)
        if bias_enabled:
            assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
        else:
            assert lora_linear.lora_bias_stacked is None
        return linear, lora_linear

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_replicated_layer()
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            dummy_vocab_size,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: List[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)

        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       dummy_vocab_size,
                                       lora_config.lora_extra_vocab_size)

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)
703
704


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
                         device, stage, bias_enabled) -> None:
    """Check row/column-parallel linear layers wrapped with LoRA.

    For ten random seeds this builds a base parallel linear layer and its
    LoRA wrapper (sharded or not, per ``fully_shard``), populates random
    LoRA weights, and compares the wrapper's output against the reference
    ``base(x) + x @ lora_a @ lora_b * scaling`` computed prompt by prompt.
    It then resets every LoRA slot and expects the wrapper to reproduce
    the base layer's output.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device)
    assert check_punica_wrapper(punica_wrapper)
    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             fully_sharded_loras=fully_shard,
                             lora_dtype=torch.float16,
                             bias_enabled=bias_enabled)

    def create_random_linear_parallel_layer():
        """Return ``(base_layer, lora_layer)`` for the tested orientation."""
        if orientation == "row":
            linear = RowParallelLinear(4096,
                                       4096,
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard
                           else RowParallelLinearWithShardedLoRA(linear))
        else:
            linear = ColumnParallelLinear(4096,
                                          4096,
                                          bias=False,
                                          params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (ColumnParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           ColumnParallelLinearWithShardedLoRA(linear))
        lora_linear.create_lora_weights(max_loras, lora_config)
        # A plain (non-packed) linear layer is expected to carry one slice.
        assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
            lora_linear.lora_b_stacked) == 1)
        if bias_enabled:
            assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
        else:
            assert lora_linear.lora_bias_stacked is None
        return linear, lora_linear

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_parallel_layer()
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        # Reference: apply each prompt's own LoRA delta on top of the base
        # layer output.
        expected_results: List[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds: with every slot
        # cleared, the wrapper is expected to match the base layer.
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)

        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       512, lora_config.lora_extra_vocab_size)

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
                                device, stage, bias_enabled) -> None:
    """Check packed column-parallel layers (merged / QKV) wrapped with LoRA.

    ``repeats`` selects the layer under test:

    * 2 -> ``MergedColumnParallelLinear`` with a merged-column LoRA wrapper;
    * 3 -> ``QKVParallelLinear`` with a merged-QKV LoRA wrapper;
    * 1 -> ``QKVParallelLinear`` with a single QKV LoRA wrapper.

    The wrapper must expose ``repeats`` LoRA slices. Each sub-LoRA's delta
    is expected on its own contiguous block of output columns. After every
    slot is reset, the wrapper must match the base layer.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device)
    assert check_punica_wrapper(punica_wrapper)
    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             fully_sharded_loras=fully_shard,
                             lora_dtype=torch.float16,
                             bias_enabled=bias_enabled)

    def create_column_parallel_packed_layer():
        """Return ``(base_layer, lora_layer)`` for the tested ``repeats``."""
        if repeats == 2:
            linear = MergedColumnParallelLinear(4096, [4096] * repeats,
                                                bias=False,
                                                params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (MergedColumnParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           MergedColumnParallelLinearWithShardedLoRA(linear))
        else:
            # repeats 1 and 3 share the same QKV base layer; only the LoRA
            # wrapper differs (single vs. merged QKV).
            linear = QKVParallelLinear(4096,
                                       64,
                                       32,
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            if repeats == 3:
                lora_linear = (MergedQKVParallelLinearWithLora(linear)
                               if not fully_shard else
                               MergedQKVParallelLinearWithShardedLora(linear))
            else:
                lora_linear = (QKVParallelLinearWithLora(linear)
                               if not fully_shard else
                               QKVParallelLinearWithShardedLora(linear))

        @dataclass
        class FakeConfig:
            # Minimal stand-in for a model config passed to
            # create_lora_weights.
            hidden_size = 4096
            num_key_value_heads = 32
            num_attention_heads = 32

        n_slices = repeats
        lora_linear.create_lora_weights(max_loras,
                                        lora_config,
                                        model_config=FakeConfig())
        assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
            lora_linear.lora_b_stacked) == n_slices)
        if bias_enabled:
            assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
        else:
            assert lora_linear.lora_bias_stacked is None
        return linear, lora_linear

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)

        linear, lora_linear = create_column_parallel_packed_layer()
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, sublora_dict = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
            repeats=repeats,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: List[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            result = linear(input_)[0]
            subloras = sublora_dict[lora_id]
            # Sub-LoRA k adds its delta to output columns
            # [k * width, (k + 1) * width), where width is its lora_b width.
            for slice_idx, sublora in enumerate(subloras):
                width = sublora.lora_b.shape[1]
                start = width * slice_idx
                result[:, start:start + width] += (
                    input_ @ sublora.lora_a @ sublora.lora_b *
                    sublora.scaling)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # With every slot reset, the wrapper is expected to match the base
        # layer.
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 8])
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("scaling_factors", [(1.0, ), (4.0, ), (4.0, 8.0),
                                             (6.0, 1.0)])
@pytest.mark.parametrize("max_position", [11, 4096, 32768])
@pytest.mark.parametrize("is_neox_style", [True, False])
@pytest.mark.parametrize("rotary_dim", [None, 32])
@pytest.mark.parametrize("head_size", [32, 108])
@pytest.mark.parametrize("seq_len", [11, 1024])
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
                    reason="Only CUDA backends are supported")
def test_rotary_embedding_long_context(dist_init, num_loras, device,
                                       scaling_factors, max_position,
                                       is_neox_style, rotary_dim, head_size,
                                       seq_len) -> None:
    """Check the long-context LoRA rotary embedding against linear scaling.

    A ``LinearScalingRotaryEmbeddingWithLora`` built on a plain rope is
    compared with a rope created with ``rope_type="linear"`` and the same
    scaling factors. The test also checks that the per-factor offsets are
    laid out back to back: consecutive offsets differ by
    ``scaling_factor * max_position``.
    """
    dtype = torch.float16
    seed = 0
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device)
    assert check_punica_wrapper(punica_wrapper)
    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             long_lora_scaling_factors=scaling_factors,
                             lora_dtype=dtype)

    if rotary_dim is None:
        rotary_dim = head_size
    base = 10000
    batch_size = 5 * num_loras
    num_heads = 7

    # Verify lora is equivalent to linear scaling rotary embedding.
    rope = get_rope(
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
    )
    lora_rope = LinearScalingRotaryEmbeddingWithLora(rope)
    lora_rope.set_mapping(punica_wrapper)
    lora_rope.create_lora_weights(max_loras, lora_config)
    linear_rope = get_rope(head_size, rotary_dim, max_position, base,
                           is_neox_style, {
                               "rope_type": "linear",
                               "factor": scaling_factors
                           })
    linear_rope = linear_rope.to(dtype=dtype)
    id_to_index = get_random_id_to_index(num_loras, max_loras)
    _, index_mapping, prompt_mapping = create_random_inputs(
        active_lora_ids=[0],
        num_inputs=batch_size,
        input_size=(1, max_position),
        input_range=(0, lora_config.lora_extra_vocab_size),
        input_type=torch.float16,
        device=device)

    lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
    long_lora_context = LongContextLoRAContext(list(scaling_factors),
                                               rotary_dim)

    next_expected_offset = 0
    # Make sure the offset is correct.
    scaling_factor_to_offset = lora_rope.scaling_factor_to_offset
    for scaling_factor, offset in scaling_factor_to_offset.items():
        assert offset == next_expected_offset
        next_expected_offset += scaling_factor * max_position

    # LoRA id k uses the offset of the k-th requested scaling factor.
    for lora_id, factor in enumerate(scaling_factors):
        long_lora_context.offsets_by_lora_id[lora_id] = (
            scaling_factor_to_offset.get(factor, 0))
    punica_wrapper.update_metadata(
        lora_mapping,
        id_to_index,
        max_loras,
        512,
        lora_config.lora_extra_vocab_size,
        long_lora_context=long_lora_context,
    )

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)
    # Give each implementation its own copies. The rotary kernels may rotate
    # query/key in place, so sharing tensors could hand the second call
    # already-rotated inputs and alias the reference outputs.
    ref_q, ref_k = linear_rope(positions, query.clone(), key.clone())
    actual_q, actual_k = lora_rope(positions, query.clone(), key.clone())

    # The comparison must be asserted: a bare torch.allclose(...) result
    # is discarded and checks nothing.
    rtol, atol = TOLERANCES[dtype]
    torch.testing.assert_close(actual_q, ref_q, rtol=rtol, atol=atol)
    torch.testing.assert_close(actual_k, ref_k, rtol=rtol, atol=atol)


@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize("seed", list(range(256)))
def test_vocab_parallel_embedding_indices(tp_size, seed):
    """Check VocabParallelEmbedding shard indices for a random vocab split.

    For every simulated tensor-parallel rank, the test patches the rank and
    world-size lookups and builds the embedding. It then checks four things:

    * the original-vocab and added-vocab ranges are contiguous across ranks;
    * the per-rank element counts sum to the padded, original and added
      vocab sizes;
    * no token id is owned twice, and original and added ids never overlap;
    * ``get_sharded_to_full_mapping`` restores the ids ``0..vocab_size-1``
      in order, with the padding slots (``-1``) moved to the tail.
    """
    random.seed(seed)
    vocab_size = random.randint(4000, 64000)
    added_vocab_size = random.randint(0, 1024)
    org_vocab_size = vocab_size - added_vocab_size
    last_org_vocab_end_index = 0
    last_added_vocab_end_index = org_vocab_size
    computed_vocab_size = 0
    computed_org_vocab_size = 0
    computed_added_vocab_size = 0
    vocab_size_padded = -1

    all_org_tokens: List[int] = []
    all_added_tokens: List[int] = []
    # Token ids in sharded (rank-major) layout, with -1 marking padding.
    token_ids: List[int] = []

    for tp_rank in range(tp_size):
        with patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank",
                return_value=tp_rank
        ), patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size",
                return_value=tp_size):
            vocab_embedding = VocabParallelEmbedding(
                vocab_size, 1, org_num_embeddings=org_vocab_size)
        vocab_size_padded = vocab_embedding.num_embeddings_padded
        shard_indices = vocab_embedding.shard_indices
        # Assert that the ranges are contiguous
        assert shard_indices.org_vocab_start_index == last_org_vocab_end_index
        assert (shard_indices.added_vocab_start_index ==
                last_added_vocab_end_index)

        # Accumulate per-rank sizes; checked against the totals after the loop.
        computed_vocab_size += shard_indices.num_elements_padded
        computed_org_vocab_size += shard_indices.num_org_elements
        computed_added_vocab_size += shard_indices.num_added_elements

        # Collect owned ids so overlaps can be detected after the loop.
        all_org_tokens.extend(
            range(shard_indices.org_vocab_start_index,
                  shard_indices.org_vocab_end_index))
        all_added_tokens.extend(
            range(shard_indices.added_vocab_start_index,
                  shard_indices.added_vocab_end_index))

        # This rank's slice: original ids, their padding, added ids, and
        # their padding.
        token_ids.extend(
            range(shard_indices.org_vocab_start_index,
                  shard_indices.org_vocab_end_index))
        token_ids.extend([-1] * (shard_indices.num_org_elements_padded -
                                 shard_indices.num_org_elements))
        token_ids.extend(
            range(shard_indices.added_vocab_start_index,
                  shard_indices.added_vocab_end_index))
        token_ids.extend([-1] * (shard_indices.num_added_elements_padded -
                                 shard_indices.num_added_elements))

        last_org_vocab_end_index = shard_indices.org_vocab_end_index
        last_added_vocab_end_index = shard_indices.added_vocab_end_index

    assert computed_vocab_size == vocab_size_padded
    assert computed_org_vocab_size == org_vocab_size
    assert computed_added_vocab_size == added_vocab_size

    # Ensure that the ranges are not overlapping
    assert len(all_org_tokens) == len(set(all_org_tokens))
    assert len(all_added_tokens) == len(set(all_added_tokens))
    assert not set(all_org_tokens).intersection(set(all_added_tokens))

    token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
    reindex_mapping = vocab_embedding.get_sharded_to_full_mapping()
    assert reindex_mapping is not None or tp_size == 1
    if reindex_mapping is not None:
        reindexed_token_ids = token_ids_tensor[reindex_mapping]
        expected = torch.arange(vocab_size, dtype=torch.long)
        assert reindexed_token_ids[:vocab_size].equal(expected)
        assert torch.all(reindexed_token_ids[vocab_size:] == -1)


def test_get_masked_input_and_mask():
    """Check get_masked_input_and_mask against hand-computed shard layouts.

    The input holds ids 0-11: 0-7 are the original vocab and 8-11 the added
    vocab (see the tp=1 bounds below). Sharding is exercised for tp sizes
    1, 2 and 4, first with ``num_org_vocab_padding=0`` and then with
    ``num_org_vocab_padding=2``. Each rank's remapped ids are compared to
    the expected tensor.
    """
    x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

    def remap(org_start, org_end, added_start, added_end, padding):
        # Only the remapped ids are checked; the returned mask is ignored.
        remapped, _ = get_masked_input_and_mask(
            x,
            org_vocab_start_index=org_start,
            org_vocab_end_index=org_end,
            added_vocab_start_index=added_start,
            added_vocab_end_index=added_end,
            num_org_vocab_padding=padding)
        return remapped

    # Per-rank (org_start, org_end, added_start, added_end) bounds.
    tp1_bounds = [(0, 8, 8, 12)]
    tp2_bounds = [(0, 4, 8, 10), (4, 8, 10, 12)]
    tp4_bounds = [(0, 2, 8, 9), (2, 4, 9, 10), (4, 6, 10, 11),
                  (6, 8, 11, 12)]

    # (padding, per-rank bounds, per-rank expected remapped ids)
    cases = [
        # base tp 1 case, no padding
        (0, tp1_bounds, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]),
        # tp 2 case, no padding
        (0, tp2_bounds, [[0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0],
                         [0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]]),
        # tp 4 case, no padding
        (0, tp4_bounds, [[0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0],
                         [0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0],
                         [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0],
                         [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]]),
        # base tp 1 case, with padding
        (2, tp1_bounds, [[0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]]),
        # tp 2 case, with padding
        (2, tp2_bounds, [[0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0],
                         [0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]]),
        # tp 4 case, with padding
        (2, tp4_bounds, [[0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0],
                         [0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0],
                         [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0],
                         [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]]),
    ]

    for padding, bounds_per_rank, expected_per_rank in cases:
        # Compute every rank first, then compare.
        outputs = [remap(*bounds, padding) for bounds in bounds_per_rank]
        for output, expected in zip(outputs, expected_per_rank):
            assert torch.equal(output, torch.tensor(expected))