test_layers.py 43.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import random
from copy import deepcopy
from dataclasses import dataclass
7
from typing import Optional
8
from unittest.mock import patch
9

10
import pytest
11
12
13
import torch
import torch.nn.functional as F

14
from vllm.config.lora import LoRAConfig
15

16
17
# yapf conflicts with isort for this block
# yapf: disable
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from vllm.lora.layers import (
    BaseLayerWithLoRA,
    ColumnParallelLinearWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    LogitsProcessorWithLoRA,
    LoRAMapping,
    MergedColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    QKVParallelLinearWithLoRA,
    QKVParallelLinearWithShardedLoRA,
    ReplicatedLinearWithLoRA,
    RowParallelLinearWithLoRA,
    RowParallelLinearWithShardedLoRA,
    VocabParallelEmbeddingWithLoRA,
)

36
# yapf: enable
37
from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights
38
from vllm.lora.punica_wrapper import get_punica_wrapper
39
40
41
42
43
44
45
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
46
from vllm.model_executor.layers.logits_processor import LogitsProcessor
47
from vllm.model_executor.layers.vocab_parallel_embedding import (
48
49
50
51
    ParallelLMHead,
    VocabParallelEmbedding,
    get_masked_input_and_mask,
)
52
from vllm.model_executor.utils import set_random_seed
53
from vllm.platforms import current_platform
54
55
56
57
58
59
60
61

from .utils import DummyLoRAManager

# Per-dtype (rtol, atol) pairs; tests in this module unpack them as
# `rtol, atol = TOLERANCES[dtype]` before calling torch.testing.assert_close.
TOLERANCES = {
    torch.float16: (5e-3, 5e-3),
    torch.float32: (5e-3, 5e-3),
    torch.bfloat16: (3e-2, 2e-2),
}

# Skip every test in this module unless the platform is CUDA-like or CPU.
pytestmark = pytest.mark.skipif(
    not current_platform.is_cuda_alike() and not current_platform.is_cpu(),
    reason="Backend not supported",
)

# Devices the tests are parametrized over: "cuda:0" when exactly one CUDA
# device is reported, "cuda:0" and "cuda:1" otherwise; "cpu" on non-CUDA-like
# platforms.
if current_platform.is_cuda_alike():
    DEVICES = [
        f"cuda:{idx}" for idx in range(1 if torch.cuda.device_count() == 1 else 2)
    ]
else:
    DEVICES = ["cpu"]

# prefill stage(True) or decode stage(False)
STAGES = [True, False]

# Number of seeds each test loops over.
NUM_RANDOM_SEEDS = 2

VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 2
80
81
82


@pytest.fixture(autouse=True)
def clean_cache_reset_device(reset_default_device):
    """Empty the Triton LoRA pointer caches before every test.

    Releases any memory the caches might still be holding on to; CI runs
    hit OOM otherwise. The `reset_default_device` fixture is requested so it
    is active for the test as well (it is defined outside this file).
    """
    from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT

    for ptr_cache in (_LORA_B_PTR_DICT, _LORA_A_PTR_DICT):
        ptr_cache.clear()

    yield

92

93
94
95
@pytest.fixture(autouse=True)
def skip_cuda_with_stage_false(request):
    """Skip the `stage=False` variant of parametrized tests on CUDA-like platforms.

    On cuda-like platforms, we use the same kernels for prefill and decode
    stage, and 'stage' is generally ignored, so we only need to test once.

    Args:
        request: the pytest fixture request; its node's `callspec.params`
            (when present) holds the parametrized arguments of the test.
    """
    if current_platform.is_cuda_alike():
        # Tests that are not parametrized have no `callspec`; look the
        # attributes up defensively instead of swallowing every exception.
        callspec = getattr(request.node, "callspec", None)
        params = getattr(callspec, "params", None)
        if isinstance(params, dict) and params.get("stage") is False:
            pytest.skip("Skip test when stage=False")
    yield


112
113
114
def get_random_id_to_index(
    num_loras: int, num_slots: int, log: bool = True
) -> list[Optional[int]]:
    """Creates a random lora_id_to_index mapping.

    LoRA ids 1..num_loras are each assigned to a distinct, randomly chosen
    slot; every other slot is left as None.

    Args:
        num_loras: The number of active loras in the mapping. Must be
            between 0 and num_slots (inclusive).
        num_slots: The number of slots in the mapping. Must be larger
            than or equal to num_loras.
        log: Whether to print the generated mapping.

    Returns:
        A list of length num_slots whose entry at index i is the lora id
        stored in slot i, or None when the slot is free.

    Raises:
        ValueError: if num_loras is negative or exceeds num_slots.
    """

    if num_loras < 0:
        # A negative bound would turn the slice below into "all but the
        # last k", silently activating the wrong number of loras.
        raise ValueError(f"num_loras must be non-negative, got {num_loras}.")

    if num_loras > num_slots:
        raise ValueError(
            f"num_loras is higher than num_slots: {num_loras} > {num_slots}. "
            "num_loras must be less than or equal to num_slots."
        )

    slots: list[Optional[int]] = [None] * num_slots
    # The first num_loras entries of a random permutation are distinct slots.
    random_slot_selections = torch.randperm(num_slots)[:num_loras].tolist()
    for lora_id, slot_idx in enumerate(random_slot_selections, start=1):
        slots[slot_idx] = lora_id

    if log:
        print(f"Created lora_id_to_index mapping: {slots}.")

    return slots


def populate_loras(
    id_to_index: list[Optional[int]],
    layer: BaseLayerWithLoRA,
    layer_weights: torch.Tensor,
    generate_embeddings_tensor: int = 0,
    repeats: int = 1,
) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
    """Fill the occupied slots of a LoRA layer with random LoRA weights.

    Args:
        id_to_index: a list of lora ids. The position of a lora id is the
            memory slot its matrices are stored in; None marks a free slot.
        layer: the LoRA layer to populate (via `layer.set_lora`).
        layer_weights: the PyTorch tensor containing the layer's weights;
            also supplies the device the random loras are created on.
        generate_embeddings_tensor: forwarded unchanged to
            `DummyLoRAManager.init_random_lora`; 0 by default.
            NOTE(review): callers in this file pass the hidden size here,
            so it presumably is the embeddings tensor width - confirm.
        repeats: must only be set for column parallel packed layers.
            Number of sub-loras composed into a single lora layer.

    Returns:
        A pair `(lora_dict, sublora_dict)`: the first maps each lora id to
        the weights installed in the layer, the second maps each lora id to
        the list of sub-loras those weights were built from.
    """
    loras_by_id: dict[int, LoRALayerWeights] = {}
    subloras_by_id: dict[int, list[LoRALayerWeights]] = {}

    for slot_idx, lora_id in enumerate(id_to_index):
        if lora_id is None:
            # Free slot, nothing to install.
            continue

        rows_per_sublora = layer_weights.shape[0] // repeats
        subloras: list[LoRALayerWeights] = []
        for i in range(repeats):
            sublora = DummyLoRAManager(layer_weights.device).init_random_lora(
                module_name=f"fake_{i}",
                weight=layer_weights,
                generate_embeddings_tensor=generate_embeddings_tensor,
            )
            # Keep only the i-th row block of lora_b for this sub-lora.
            row_start = rows_per_sublora * i
            sublora.lora_b = sublora.lora_b[
                row_start : row_start + rows_per_sublora, :
            ]
            sublora.optimize()
            subloras.append(sublora)

        if repeats > 1:
            lora = PackedLoRALayerWeights.pack(subloras)
        else:
            lora = subloras[0]

        layer.set_lora(
            slot_idx,
            lora_a=lora.lora_a,
            lora_b=lora.lora_b,
            embeddings_tensor=lora.embeddings_tensor,
        )

        loras_by_id[lora_id] = lora
        subloras_by_id[lora_id] = subloras

    return loras_by_id, subloras_by_id


def create_random_inputs(
    active_lora_ids: list[int],
    num_inputs: int,
    input_size: tuple[int, ...],
    input_range: tuple[float, float],
    input_type: torch.dtype = torch.int,
    device: torch.device = "cuda",
) -> tuple[list[torch.Tensor], list[int], list[int]]:
    """Creates random inputs, each tagged with a randomly chosen lora id.

    Args:
        active_lora_ids: lora IDs of active lora weights.
        num_inputs: the number of inputs to create.
        input_size: the size of each individual input.
        input_range: the range of values to include in the input.
            input_range[0] <= possible input values < input_range[1]
        input_type: the type of values in the input. `torch.int` draws
            integers with `torch.randint`; any other dtype draws uniform
            values of that dtype with `torch.rand`.
        device: the device the inputs are created on (a device string is
            accepted as well); defaults to "cuda".

    Returns:
        A tuple `(inputs, index_mapping, prompt_mapping)`: the list of
        generated tensors, the chosen lora id repeated `input_size[0]`
        times per input, and the chosen lora id once per input.
    """

    low, high = input_range

    inputs: list[torch.Tensor] = []
    index_mapping: list[int] = []
    prompt_mapping: list[int] = []

    for _ in range(num_inputs):
        if input_type == torch.int:
            inputs.append(
                torch.randint(
                    low=int(low), high=int(high), size=input_size, device=device
                )
            )
        else:
            # Scale by the width of the range (not by `high`) so the values
            # stay inside [low, high) as documented.
            inputs.append(
                torch.rand(size=input_size, dtype=input_type, device=device)
                * (high - low)
                + low
            )

        lora_id = random.choice(active_lora_ids)
        index_mapping += [lora_id] * input_size[0]
        prompt_mapping += [lora_id]

    return inputs, index_mapping, prompt_mapping


248
249
250
251
252
def check_punica_wrapper(punica_wrapper) -> bool:
    """Return True when `punica_wrapper` is exactly the wrapper class of the platform.

    CUDA-like platforms expect `PunicaWrapperGPU`, CPU expects
    `PunicaWrapperCPU`; any other platform yields False. The comparison is
    on the exact type, so subclasses do not match.
    """
    # The wrapper classes are imported lazily, only for the active platform.
    if current_platform.is_cuda_alike():
        from vllm.lora.punica_wrapper.punica_gpu import (
            PunicaWrapperGPU as expected_cls,
        )
    elif current_platform.is_cpu():
        from vllm.lora.punica_wrapper.punica_cpu import (
            PunicaWrapperCPU as expected_cls,
        )
    else:
        return False

    return type(punica_wrapper) is expected_cls


261
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
    """Check VocabParallelEmbeddingWithLoRA against a hand-computed reference.

    For every seed a random embedding layer is wrapped and filled with random
    loras. Its output must match the base embedding plus
    `F.embedding(x, lora_a.T) @ lora_b.T` for the lora of each input. All
    slots are then reset and the output must match the base embedding alone.
    """
    # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
    # device, see: https://github.com/triton-lang/triton/issues/2925
    # Same below.
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)
    torch.set_default_device(device)

    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )

    def run_lora_layer(lora_layer, id_to_index, active_lora_ids):
        """Draw random token ids, refresh the punica metadata, run the layer."""
        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=active_lora_ids,
            num_inputs=num_loras * 3,
            input_size=(200,),
            input_range=(1, vocab_size),
            device=device,
        )
        punica_wrapper.update_metadata(
            LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage),
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )
        return inputs, prompt_mapping, lora_layer(torch.cat(inputs))

    def assert_matches(actual, expected):
        """Compare with the tolerances registered for the dtype of `actual`."""
        rtol, atol = TOLERANCES[actual.dtype]
        torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)

    for seed in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)

        # Base layer with random weights; rows past vocab_size are zeroed.
        embedding = VocabParallelEmbedding(vocab_size, 256)
        embedding.weight.data = torch.rand_like(embedding.weight.data)
        embedding.weight.data[vocab_size:, :] = 0
        lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
        lora_embedding.create_lora_weights(max_loras, lora_config)
        lora_embedding.set_mapping(punica_wrapper)

        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=embedding.weight.T,
        )

        inputs, prompt_mapping, lora_result = run_lora_layer(
            lora_embedding, id_to_index, list(lora_dict.keys())
        )

        reference_chunks: list[torch.Tensor] = []
        for token_ids, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            reference = embedding(token_ids)
            reference += F.embedding(token_ids, lora.lora_a.T) @ lora.lora_b.T
            reference_chunks.append(reference)

        assert_matches(lora_result, torch.cat(reference_chunks))

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        inputs, _, lora_result = run_lora_layer(lora_embedding, id_to_index, [0])
        assert_matches(lora_result, embedding(torch.cat(inputs)))
361
362
363


@torch.inference_mode()
# NOTE: a skip marker with the reason
# "Fails when loras are in any slot other than the first." is disabled here.
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings_with_new_embeddings(
    dist_init, num_loras, device, vocab_size, stage
) -> None:
    """Check VocabParallelEmbeddingWithLoRA when loras carry extra embeddings.

    The reference is an expanded embedding whose rows past `vocab_size` hold
    the embeddings tensors of the loras (sorted by lora id, zero-padded for
    free slots). The last two tokens of every input are forced into the
    extended range. After resetting all slots the lora layer must match the
    expanded embedding on fresh in-vocabulary inputs.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)
    torch.set_default_device(device)

    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )

    def draw_inputs(id_to_index, active_lora_ids):
        """Draw random token ids and refresh the punica metadata for them."""
        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=active_lora_ids,
            num_inputs=num_loras * 3,
            input_size=(200,),
            input_range=(1, vocab_size),
            device=device,
        )
        punica_wrapper.update_metadata(
            LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage),
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )
        return inputs, prompt_mapping

    def assert_matches(actual, expected):
        """Compare with the tolerances registered for the dtype of `actual`."""
        rtol, atol = TOLERANCES[actual.dtype]
        torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)

    for seed in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)

        # Random base weights, copied into the first vocab_size rows of an
        # embedding that also has room for the extra vocab of every slot.
        base_embedding = VocabParallelEmbedding(vocab_size, 256)
        base_weights = torch.rand_like(base_embedding.weight.data)
        base_embedding.weight.data = base_weights
        base_embedding.weight.data[vocab_size:, :] = 0
        expanded_embedding = VocabParallelEmbedding(
            vocab_size + lora_config.lora_extra_vocab_size * max_loras,
            256,
            org_num_embeddings=vocab_size,
        )
        expanded_embedding.weight.data[:vocab_size, :] = base_weights
        # We need to deepcopy the embedding as it will be modified
        # in place
        lora_embedding = VocabParallelEmbeddingWithLoRA(deepcopy(expanded_embedding))
        lora_embedding.create_lora_weights(max_loras, lora_config)

        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=torch.zeros(
                (256, vocab_size + lora_config.lora_extra_vocab_size)
            ),
            generate_embeddings_tensor=256,
        )
        lora_embedding.set_mapping(punica_wrapper)

        # All embeddings tensors have the same shape.
        embeddings_tensors = [
            lora_dict[lora_id].embeddings_tensor for lora_id in sorted(lora_dict)
        ]
        extra_shape = embeddings_tensors[0].shape
        embeddings_tensor_len = extra_shape[0]

        # Add empty embeddings_tensors for unoccupied lora slots.
        num_free_slots = max_loras - len(embeddings_tensors)
        embeddings_tensors.extend(
            torch.zeros(extra_shape) for _ in range(num_free_slots)
        )

        reference_inputs, prompt_mapping = draw_inputs(
            id_to_index, list(lora_dict.keys())
        )
        lora_inputs = deepcopy(reference_inputs)

        # Force some of the inputs to be in the extended embeddings range
        # to guarantee that their behavior is tested. The reference inputs
        # point at the rows of this lora inside the expanded embedding; the
        # inputs fed to the lora layer use the first extended block.
        for reference_input, lora_input, lora_id in zip(
            reference_inputs, lora_inputs, prompt_mapping
        ):
            block_offset = (lora_id - 1) * embeddings_tensor_len
            reference_input[-1] = vocab_size + block_offset
            lora_input[-1] = vocab_size
            reference_input[-2] = vocab_size + block_offset + embeddings_tensor_len - 1
            lora_input[-2] = vocab_size + embeddings_tensor_len - 1

        extended_end = vocab_size + (embeddings_tensor_len * max_loras)
        expanded_embedding.weight[vocab_size:extended_end] = torch.cat(
            embeddings_tensors
        )

        lora_result = lora_embedding(torch.cat(lora_inputs))

        reference_chunks: list[torch.Tensor] = []
        for reference_input, lora_input, lora_id in zip(
            reference_inputs, lora_inputs, prompt_mapping
        ):
            lora = lora_dict[lora_id]
            reference = expanded_embedding(reference_input)
            reference += F.embedding(lora_input, lora.lora_a.T) @ lora.lora_b.T
            reference_chunks.append(reference)

        assert_matches(lora_result, torch.cat(reference_chunks))

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        reference_inputs, _ = draw_inputs(id_to_index, [0])
        lora_inputs = deepcopy(reference_inputs)

        lora_result = lora_embedding(torch.cat(lora_inputs))
        assert_matches(lora_result, expanded_embedding(torch.cat(reference_inputs)))
504
505
506


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])
@pytest.mark.parametrize("stage", STAGES)
def test_lm_head_logits_processor(
    dist_init, num_loras, device, vocab_size, stage
) -> None:
    """Exercise LogitsProcessorWithLoRA._get_logits with a ParallelLMHead.

    For every seed the lora logits are computed with random loras installed
    and a reference is built from the plain LogitsProcessor. The only
    comparison actually asserted is the one after all lora slots have been
    reset: the first `vocab_size` columns of the lora logits must match the
    plain logits.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)
    torch.set_default_device(device)

    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )

    def draw_hidden_states(id_to_index, active_lora_ids, num_inputs):
        """Draw random hidden states and refresh the punica metadata for them."""
        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=active_lora_ids,
            num_inputs=num_inputs,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        punica_wrapper.update_metadata(
            LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage),
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )
        return inputs, prompt_mapping

    for seed in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)

        # LM head with random weights plus the plain and lora processors.
        extended_vocab_size = vocab_size + lora_config.lora_extra_vocab_size
        linear = ParallelLMHead(
            extended_vocab_size,
            1024,
            vocab_size,
            params_dtype=torch.float16,
        )
        linear.weight.data = torch.rand_like(linear.weight.data)
        linear.weight.data[:, vocab_size:] = 0
        logits_processor = LogitsProcessor(extended_vocab_size, vocab_size)
        lora_logits_processor = LogitsProcessorWithLoRA(
            logits_processor, 1024, linear.weight.dtype, linear.weight.device, None
        )
        lora_logits_processor.create_lora_weights(max_loras, lora_config)
        lora_logits_processor.set_mapping(punica_wrapper)

        # NOTE: all the generated loras share the same embeddings tensor.
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_logits_processor,
            layer_weights=linear.weight,
            generate_embeddings_tensor=1024,
        )
        embeddings_tensor = next(iter(lora_dict.values())).embeddings_tensor
        embeddings_tensor_len = embeddings_tensor.shape[0]

        inputs, prompt_mapping = draw_hidden_states(
            id_to_index,
            list(lora_dict.keys()),
            8 * num_loras,  # * 3,
        )
        # This tensor is never read; the draw is kept so the random stream
        # consumed by the rest of the test is unchanged.
        _unused = torch.rand(20, 1024)

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs), lm_head=linear, embedding_bias=None
        )

        # Snapshot of the LM head before the embeddings tensor is written
        # into it; used for the post-reset comparison.
        original_lm_head = deepcopy(linear)

        org_vocab_size = logits_processor.org_vocab_size
        linear.weight[
            org_vocab_size : org_vocab_size + embeddings_tensor_len
        ] = embeddings_tensor

        # Temporarily widen org_vocab_size while the reference is computed.
        logits_processor.org_vocab_size = vocab_size + lora_config.lora_extra_vocab_size
        reference_chunks: list[torch.Tensor] = []
        for hidden, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            reference = logits_processor._get_logits(
                hidden_states=hidden, lm_head=linear, embedding_bias=None
            )
            reference[:, vocab_size + embeddings_tensor_len :] = float("-inf")
            reference += hidden @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            reference_chunks.append(reference)
        # NOTE(review): this reference is never compared with `lora_result`
        # from above - presumably an assert_close is missing here; confirm
        # before adding one.
        expected_result = torch.cat(reference_chunks)
        logits_processor.org_vocab_size = vocab_size

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_logits_processor.reset_lora(slot_idx)

        inputs, _ = draw_hidden_states(id_to_index, [0], 8 * num_loras * 3)

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None,
        )[:, :vocab_size]
        expected_result = logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None,
        )

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
637
638


639
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_replicated(
    dist_init,
    num_loras,
    device,
    stage,
) -> None:
    """Check ReplicatedLinearWithLoRA against a hand-computed reference.

    With random loras installed the output must equal the base linear output
    plus `x @ lora_a.T @ lora_b.T * scaling` for the lora of each input.
    After all slots are reset it must equal the base linear output.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        lora_dtype=torch.float16,
    )

    def run_lora_layer(lora_layer, id_to_index, active_lora_ids):
        """Draw random inputs, refresh the punica metadata, run the layer."""
        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=active_lora_ids,
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        punica_wrapper.update_metadata(
            LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage),
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )
        return inputs, prompt_mapping, lora_layer(torch.cat(inputs))[0]

    def assert_matches(actual, expected):
        """Compare with the tolerances registered for the dtype of `actual`."""
        rtol, atol = TOLERANCES[actual.dtype]
        torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)

    for seed in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)

        linear = ReplicatedLinear(4096, 4096, bias=False, params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        lora_linear = ReplicatedLinearWithLoRA(linear)
        lora_linear.create_lora_weights(max_loras, lora_config)
        # A replicated layer is expected to hold exactly one slice.
        assert lora_linear.n_slices == 1
        assert len(lora_linear.lora_a_stacked) == 1
        assert len(lora_linear.lora_b_stacked) == 1

        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, prompt_mapping, lora_result = run_lora_layer(
            lora_linear, id_to_index, list(lora_dict.keys())
        )

        reference_chunks: list[torch.Tensor] = []
        for hidden, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            reference = linear(hidden)[0]
            reference += hidden @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            reference_chunks.append(reference)

        assert_matches(lora_result, torch.cat(reference_chunks))

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, _, lora_result = run_lora_layer(lora_linear, id_to_index, [0])
        assert_matches(lora_result, linear(torch.cat(inputs))[0])
743
744


745
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_parallel(
    dist_init, num_loras, orientation, fully_shard, device, stage
) -> None:
    """Check the row/column parallel LoRA linear layers against a reference.

    `orientation` selects RowParallelLinear or ColumnParallelLinear and
    `fully_shard` selects the sharded LoRA wrapper. With random loras
    installed the output must equal the base output plus
    `x @ lora_a.T @ lora_b.T * scaling`; after resetting all slots it must
    equal the base output.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        fully_sharded_loras=fully_shard,
        lora_dtype=torch.float16,
    )

    # Pick the base layer class and its LoRA wrapper for this parametrization.
    if orientation == "row":
        base_cls = RowParallelLinear
        lora_cls = (
            RowParallelLinearWithShardedLoRA
            if fully_shard
            else RowParallelLinearWithLoRA
        )
    else:
        base_cls = ColumnParallelLinear
        lora_cls = (
            ColumnParallelLinearWithShardedLoRA
            if fully_shard
            else ColumnParallelLinearWithLoRA
        )

    def run_lora_layer(lora_layer, id_to_index, active_lora_ids):
        """Draw random inputs, refresh the punica metadata, run the layer."""
        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=active_lora_ids,
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        punica_wrapper.update_metadata(
            LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage),
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )
        return inputs, prompt_mapping, lora_layer(torch.cat(inputs))[0]

    def assert_matches(actual, expected):
        """Compare with the tolerances registered for the dtype of `actual`."""
        rtol, atol = TOLERANCES[actual.dtype]
        torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)

    for seed in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)

        linear = base_cls(4096, 4096, bias=False, params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        lora_linear = lora_cls(linear)
        lora_linear.create_lora_weights(max_loras, lora_config)
        # These layers are expected to hold exactly one slice.
        assert lora_linear.n_slices == 1
        assert len(lora_linear.lora_a_stacked) == 1
        assert len(lora_linear.lora_b_stacked) == 1

        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, prompt_mapping, lora_result = run_lora_layer(
            lora_linear, id_to_index, list(lora_dict.keys())
        )

        reference_chunks: list[torch.Tensor] = []
        for hidden, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            reference = linear(hidden)[0]
            reference += hidden @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            reference_chunks.append(reference)

        assert_matches(lora_result, torch.cat(reference_chunks))

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, _, lora_result = run_lora_layer(lora_linear, id_to_index, [0])
        assert_matches(lora_result, linear(torch.cat(inputs))[0])
866
867
868


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_column_parallel_packed(
    dist_init, num_loras, repeats, fully_shard, device, stage
) -> None:
    """Check packed column-parallel LoRA layers against a manual reference.

    For every random seed the test:

    1. builds a base linear layer and its LoRA wrapper (which pair is used
       depends on ``repeats`` and ``fully_shard``),
    2. populates LoRA weights and compares the wrapper's output with the base
       layer's output plus a hand-computed per-slice LoRA delta,
    3. resets every LoRA slot and checks that the wrapper's output then
       matches the base layer's output.

    Args:
        dist_init: Fixture requested for its side effects; not used directly.
        num_loras: Number of LoRAs to populate; also scales the input count.
        repeats: Number of packed slices (1, 2 or 3); selects the layer type.
        fully_shard: Whether to use the ``...WithShardedLoRA`` wrapper classes.
        device: Device used as the torch default and for the created inputs.
        stage: Forwarded to ``LoRAMapping`` as ``is_prefill``.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        fully_sharded_loras=fully_shard,
        lora_dtype=torch.float16,
    )

    def create_column_parallel_packed_layer():
        """Return ``(linear, lora_linear)`` for the current parametrization."""
        if repeats == 2:
            linear = MergedColumnParallelLinear(
                4096, [4096] * repeats, bias=False, params_dtype=torch.float16
            )
            plain_cls = MergedColumnParallelLinearWithLoRA
            sharded_cls = MergedColumnParallelLinearWithShardedLoRA
        else:
            linear = QKVParallelLinear(
                4096, 64, 32, bias=False, params_dtype=torch.float16
            )
            if repeats == 3:
                plain_cls = MergedQKVParallelLinearWithLoRA
                sharded_cls = MergedQKVParallelLinearWithShardedLoRA
            else:
                plain_cls = QKVParallelLinearWithLoRA
                sharded_cls = QKVParallelLinearWithShardedLoRA

        # Randomize the base weight before wrapping, as every branch of the
        # original construction did.
        linear.weight.data = torch.rand_like(linear.weight.data)
        lora_linear = sharded_cls(linear) if fully_shard else plain_cls(linear)

        # The attributes carry no annotations, so they are plain class
        # attributes; the instance is only passed on as ``model_config``.
        @dataclass
        class FakeConfig:
            hidden_size = 4096
            num_key_value_heads = 32
            num_attention_heads = 32

        lora_linear.create_lora_weights(
            max_loras, lora_config, model_config=FakeConfig()
        )
        # One stacked A/B weight per packed slice.
        assert (
            lora_linear.n_slices
            == len(lora_linear.lora_a_stacked)
            == len(lora_linear.lora_b_stacked)
            == repeats
        )

        return linear, lora_linear

    for seed in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)

        linear, lora_linear = create_column_parallel_packed_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, sublora_dict = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
            repeats=repeats,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        # Reference: base output plus each sub-LoRA's delta written into that
        # slice's columns.
        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            result = linear(input_)[0]
            # ``slice_idx`` is deliberately distinct from the seed loop
            # variable, which the original inner loop rebound.
            for slice_idx, sublora in enumerate(sublora_dict[lora_id]):
                # Offsets are multiples of this sub-LoRA's own output width.
                width = sublora.lora_b.shape[0]
                result[:, width * slice_idx : width * (slice_idx + 1)] += (
                    input_ @ sublora.lora_a.T @ sublora.lora_b.T * sublora.scaling
                )
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)

        # After resetting every slot the wrapper must match the base layer.
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
1017
1018


1019
@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize(
    "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS))
)
def test_vocab_parallel_embedding_indices(tp_size, seed):
    """Check per-rank shard indices of ``VocabParallelEmbedding``.

    A random vocabulary (original + added tokens) is split across ``tp_size``
    ranks by patching the tensor-parallel rank/world-size lookups. The test
    asserts that consecutive shards are contiguous, that element counts add up
    to the expected totals, that no token id appears twice, and that the
    sharded-to-full mapping (when present) restores ``0..vocab_size-1``
    followed only by ``-1`` padding entries.
    """
    random.seed(seed)
    vocab_size = random.randint(4000, 64000)
    added_vocab_size = random.randint(0, 1024)
    org_vocab_size = vocab_size - added_vocab_size

    # Where the next shard is expected to start.
    prev_org_end = 0
    prev_added_end = org_vocab_size

    # Totals accumulated over all ranks.
    total_padded = 0
    total_org = 0
    total_added = 0
    vocab_size_padded = -1

    seen_org_tokens: list[int] = []
    seen_added_tokens: list[int] = []
    # Token ids in per-rank order, with -1 standing in for padding slots.
    sharded_token_ids: list[int] = []

    for tp_rank in range(tp_size):
        rank_patch = patch(
            "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank",
            return_value=tp_rank,
        )
        world_size_patch = patch(
            "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size",
            return_value=tp_size,
        )
        with rank_patch, world_size_patch:
            vocab_embedding = VocabParallelEmbedding(
                vocab_size, 1, org_num_embeddings=org_vocab_size
            )
        vocab_size_padded = vocab_embedding.num_embeddings_padded
        shard = vocab_embedding.shard_indices

        # Each shard must begin where the previous one ended.
        assert shard.org_vocab_start_index == prev_org_end
        assert shard.added_vocab_start_index == prev_added_end

        # Accumulate sizes so the totals can be checked after the loop.
        total_padded += shard.num_elements_padded
        total_org += shard.num_org_elements
        total_added += shard.num_added_elements

        org_ids = range(shard.org_vocab_start_index, shard.org_vocab_end_index)
        added_ids = range(
            shard.added_vocab_start_index, shard.added_vocab_end_index
        )
        org_padding = shard.num_org_elements_padded - shard.num_org_elements
        added_padding = shard.num_added_elements_padded - shard.num_added_elements

        # Collected for the overlap check after the loop.
        seen_org_tokens.extend(org_ids)
        seen_added_tokens.extend(added_ids)

        # Per-rank layout: org ids, org padding, added ids, added padding.
        sharded_token_ids.extend(org_ids)
        sharded_token_ids.extend([-1] * org_padding)
        sharded_token_ids.extend(added_ids)
        sharded_token_ids.extend([-1] * added_padding)

        prev_org_end = shard.org_vocab_end_index
        prev_added_end = shard.added_vocab_end_index

    assert total_padded == vocab_size_padded
    assert total_org == org_vocab_size
    assert total_added == added_vocab_size

    # No id may occur twice, and the org/added id sets must not intersect.
    unique_org_tokens = set(seen_org_tokens)
    unique_added_tokens = set(seen_added_tokens)
    assert len(seen_org_tokens) == len(unique_org_tokens)
    assert len(seen_added_tokens) == len(unique_added_tokens)
    assert unique_org_tokens.isdisjoint(unique_added_tokens)

    token_ids_tensor = torch.tensor(sharded_token_ids, dtype=torch.long)
    # Taken from the embedding built for the last rank.
    reindex_mapping = vocab_embedding.get_sharded_to_full_mapping()
    assert reindex_mapping is not None or tp_size == 1
    if reindex_mapping is None:
        return

    reindexed_token_ids = token_ids_tensor[reindex_mapping]
    expected = torch.tensor(list(range(vocab_size)))
    assert reindexed_token_ids[:vocab_size].equal(expected)
    assert torch.all(reindexed_token_ids[vocab_size:] == -1)


def test_get_masked_input_and_mask():
    """Check ``get_masked_input_and_mask`` for tp 1/2/4, with and without padding.

    Every case feeds the token ids ``0..11`` together with one rank's vocab
    bounds and compares the first returned value with a hard-coded expected
    tensor. The second returned value is not checked.
    """
    x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

    # One tuple per rank:
    # (org_vocab_start_index, org_vocab_end_index,
    #  added_vocab_start_index, added_vocab_end_index)
    tp1_shards = [(0, 8, 8, 12)]
    tp2_shards = [(0, 4, 8, 10), (4, 8, 10, 12)]
    tp4_shards = [(0, 2, 8, 9), (2, 4, 9, 10), (4, 6, 10, 11), (6, 8, 11, 12)]

    # (num_org_vocab_padding, per-rank bounds, per-rank expected output)
    cases = [
        # base tp 1 case, no padding: output equals the input
        (0, tp1_shards, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]),
        # tp 2 case, no padding
        (
            0,
            tp2_shards,
            [
                [0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0],
                [0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5],
            ],
        ),
        # tp 4 case, no padding
        (
            0,
            tp4_shards,
            [
                [0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0],
                [0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0],
                [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0],
                [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2],
            ],
        ),
        # base tp 1 case, with padding
        (2, tp1_shards, [[0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]]),
        # tp 2 case, with padding
        (
            2,
            tp2_shards,
            [
                [0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0],
                [0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7],
            ],
        ),
        # tp 4 case, with padding
        (
            2,
            tp4_shards,
            [
                [0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0],
                [0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0],
                [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0],
                [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4],
            ],
        ),
    ]

    for padding, shards, expected_per_rank in cases:
        for bounds, expected in zip(shards, expected_per_rank):
            org_start, org_end, added_start, added_end = bounds
            modified_x, _ = get_masked_input_and_mask(
                x,
                org_vocab_start_index=org_start,
                org_vocab_end_index=org_end,
                added_vocab_start_index=added_start,
                added_vocab_end_index=added_end,
                num_org_vocab_padding=padding,
            )
            assert torch.equal(modified_x, torch.tensor(expected))