test_layers.py 43.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import random
from copy import deepcopy
from dataclasses import dataclass
7
from typing import Optional
8
from unittest.mock import patch
9

10
import pytest
11
12
13
import torch
import torch.nn.functional as F

14
from vllm.config.lora import LoRAConfig
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from vllm.lora.layers import (
    BaseLayerWithLoRA,
    ColumnParallelLinearWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    LogitsProcessorWithLoRA,
    LoRAMapping,
    MergedColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    QKVParallelLinearWithLoRA,
    QKVParallelLinearWithShardedLoRA,
    ReplicatedLinearWithLoRA,
    RowParallelLinearWithLoRA,
    RowParallelLinearWithShardedLoRA,
    VocabParallelEmbeddingWithLoRA,
)
32
from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights
33
from vllm.lora.punica_wrapper import get_punica_wrapper
34
35
36
37
38
39
40
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
41
from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
from vllm.model_executor.layers.vocab_parallel_embedding import (
43
44
45
46
    ParallelLMHead,
    VocabParallelEmbedding,
    get_masked_input_and_mask,
)
47
from vllm.model_executor.utils import set_random_seed
48
from vllm.platforms import current_platform
49
50
51
52
53
54
55
56

from .utils import DummyLoRAManager

# (rtol, atol) handed to torch.testing.assert_close, keyed by the dtype of the
# LoRA layer's output; bfloat16 gets looser bounds.
TOLERANCES = {
    torch.float16: (5e-3, 5e-3),
    torch.float32: (5e-3, 5e-3),
    torch.bfloat16: (3e-2, 2e-2),
}

# Only CUDA-like and CPU backends are exercised; skip the module elsewhere.
pytestmark = pytest.mark.skipif(
    not current_platform.is_cuda_alike() and not current_platform.is_cpu(),
    reason="Backend not supported",
)

if current_platform.is_cuda_alike():
    # Use cuda:0 and cuda:1 unless exactly one GPU is visible.
    DEVICES = [
        f"cuda:{gpu}" for gpu in range(1 if torch.cuda.device_count() == 1 else 2)
    ]
else:
    DEVICES = ["cpu"]

# Values for the `stage` parameter: True = prefill, False = decode.
STAGES = [True, False]

# Number of seeds each test loops over via set_random_seed.
NUM_RANDOM_SEEDS = 2

VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 2
75
76
77


@pytest.fixture(autouse=True)
def clean_cache_reset_device(reset_default_device):
    """Empty the Triton LoRA pointer caches before every test.

    Release any memory we might be holding on to; CI runs OOM otherwise.
    ``reset_default_device`` is requested only so that fixture runs as well
    (it is defined outside this file).
    """
    from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT

    for pointer_cache in (_LORA_B_PTR_DICT, _LORA_A_PTR_DICT):
        pointer_cache.clear()

    yield

87

88
89
90
@pytest.fixture(autouse=True)
def skip_cuda_with_stage_false(request):
    """Skip the ``stage=False`` (decode) variant of a test on CUDA-like platforms.

    On cuda-like platforms, we use the same kernels for prefill and decode
    stage, and 'stage' is generally ignored, so we only need to test once.
    Test items that are not parametrized (no ``callspec``) or that have no
    ``stage`` parameter run normally.
    """
    if current_platform.is_cuda_alike():
        # Non-parametrized items have no ``callspec``; look it up defensively
        # instead of wrapping the check in a catch-all ``except``.
        callspec = getattr(request.node, "callspec", None)
        params = getattr(callspec, "params", None) or {}
        if params.get("stage") is False:
            pytest.skip("Skip test when stage=False")
    yield


107
108
109
def get_random_id_to_index(
    num_loras: int, num_slots: int, log: bool = True
) -> list[Optional[int]]:
    """Creates a random lora_id_to_index mapping.

    LoRA ids ``1..num_loras`` are placed into distinct slots chosen with
    ``torch.randperm`` (so the result follows torch's RNG seed); every other
    slot is ``None``.

    Args:
        num_loras: The number of active loras in the mapping. Must be >= 0.
        num_slots: The number of slots in the mapping. Must be greater than
            or equal to num_loras.
        log: Whether to print the resulting mapping.

    Returns:
        A list of length ``num_slots`` whose entry ``i`` is the id of the LoRA
        stored in slot ``i``, or ``None`` if that slot is free.

    Raises:
        ValueError: If ``num_loras > num_slots`` or ``num_loras < 0``.
    """
    if num_loras > num_slots:
        raise ValueError(
            f"num_loras is higher than num_slots: {num_loras} > {num_slots}. "
            "num_loras must be less than or equal to num_slots."
        )
    if num_loras < 0:
        # A negative count would turn the slice below into ``[:-k]`` and
        # silently produce a mapping with the wrong number of loras.
        raise ValueError(f"num_loras must be non-negative, got {num_loras}.")

    slots: list[Optional[int]] = [None] * num_slots
    # The first num_loras entries of a random permutation are distinct slots.
    random_slot_selections = torch.randperm(num_slots)[:num_loras].tolist()
    for lora_id, slot_idx in enumerate(random_slot_selections, start=1):
        slots[slot_idx] = lora_id

    if log:
        print(f"Created lora_id_to_index mapping: {slots}.")

    return slots


def populate_loras(
    id_to_index: list[Optional[int]],
    layer: BaseLayerWithLoRA,
    layer_weights: torch.Tensor,
    generate_embeddings_tensor: int = 0,
    repeats: int = 1,
) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
    """Load freshly generated random LoRA weights into each occupied slot.

    Args:
        id_to_index: lora ids indexed by memory slot; a ``None`` entry marks
            a free slot, which is left untouched.
        layer: the LoRA layer whose slots are filled through ``set_lora``.
        layer_weights: the base layer's weight tensor. It is passed as
            ``weight`` to ``DummyLoRAManager.init_random_lora`` and its
            device is used for the manager.
        generate_embeddings_tensor: forwarded unchanged to
            ``init_random_lora`` (callers pass 0 or a width such as 256 or
            1024; its exact meaning is defined by DummyLoRAManager).
        repeats: number of sub-LoRAs composed into one packed LoRA; only set
            for column parallel packed layers. Sub-LoRA ``i`` keeps rows
            ``[i * n, (i + 1) * n)`` of its ``lora_b``, where
            ``n = layer_weights.shape[0] // repeats``.

    Returns:
        ``(lora_dict, sublora_dict)``: lora id -> the LoRA set on the layer
        (packed when ``repeats > 1``), and lora id -> its sub-LoRAs.
    """
    lora_dict: dict[int, LoRALayerWeights] = {}
    sublora_dict: dict[int, list[LoRALayerWeights]] = {}

    for slot_idx, lora_id in enumerate(id_to_index):
        if lora_id is None:
            continue  # free slot

        rows_per_sublora = layer_weights.shape[0] // repeats
        parts: list[LoRALayerWeights] = []
        for part_idx in range(repeats):
            part = DummyLoRAManager(layer_weights.device).init_random_lora(
                module_name=f"fake_{part_idx}",
                weight=layer_weights,
                generate_embeddings_tensor=generate_embeddings_tensor,
            )
            row_start = rows_per_sublora * part_idx
            row_stop = row_start + rows_per_sublora
            part.lora_b = part.lora_b[row_start:row_stop, :]
            part.optimize()
            parts.append(part)

        if repeats > 1:
            lora = PackedLoRALayerWeights.pack(parts)
        else:
            lora = parts[0]

        layer.set_lora(
            slot_idx,
            lora_a=lora.lora_a,
            lora_b=lora.lora_b,
            embeddings_tensor=lora.embeddings_tensor,
        )

        lora_dict[lora_id] = lora
        sublora_dict[lora_id] = parts

    return lora_dict, sublora_dict


def create_random_inputs(
    active_lora_ids: list[int],
    num_inputs: int,
    input_size: tuple[int, ...],
    input_range: tuple[float, float],
    input_type: torch.dtype = torch.int,
    device: torch.device = "cuda",
) -> tuple[list[torch.Tensor], list[int], list[int]]:
    """Creates random inputs.

    Args:
        active_lora_ids: lora IDs of active lora weights; each input is
            assigned one of them with ``random.choice``.
        num_inputs: the number of inputs to create.
        input_size: the size of each individual input.
        input_range: the range of values to include in the input.
            input_range[0] <= possible input values < input_range[1]
        input_type: the type of values in the input. ``torch.int`` draws
            integers with ``torch.randint``; any other dtype draws uniform
            floats of that dtype with ``torch.rand``.
        device: device on which the input tensors are allocated.

    Returns:
        ``(inputs, index_mapping, prompt_mapping)``: ``inputs`` holds
        ``num_inputs`` tensors of shape ``input_size``; ``prompt_mapping``
        holds the lora id chosen for each input; ``index_mapping`` repeats
        that id ``input_size[0]`` times per input.
    """
    low, high = input_range

    inputs: list[torch.Tensor] = []
    index_mapping: list[int] = []
    prompt_mapping: list[int] = []

    for _ in range(num_inputs):
        if input_type == torch.int:
            inputs.append(
                torch.randint(
                    low=int(low), high=int(high), size=input_size, device=device
                )
            )
        else:
            # Scale by the width of the range (not by ``high``) so values land
            # in [low, high) as documented; identical to before when low == 0.
            inputs.append(
                torch.rand(size=input_size, dtype=input_type, device=device)
                * (high - low)
                + low
            )

        lora_id = random.choice(active_lora_ids)
        index_mapping += [lora_id] * input_size[0]
        prompt_mapping += [lora_id]

    return inputs, index_mapping, prompt_mapping


243
244
245
246
247
def check_punica_wrapper(punica_wrapper) -> bool:
    """Return True iff ``punica_wrapper`` is exactly the current platform's wrapper.

    CUDA-like platforms expect ``PunicaWrapperGPU`` and CPU expects
    ``PunicaWrapperCPU``. The check is an exact ``type`` match, so subclasses
    do not count. Any other platform yields False.
    """
    if current_platform.is_cuda_alike():
        from vllm.lora.punica_wrapper.punica_gpu import (
            PunicaWrapperGPU as expected_cls,
        )
    elif current_platform.is_cpu():
        from vllm.lora.punica_wrapper.punica_cpu import (
            PunicaWrapperCPU as expected_cls,
        )
    else:
        return False
    return type(punica_wrapper) is expected_cls


256
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
    """VocabParallelEmbeddingWithLoRA matches a dense per-input reference.

    For each seed, random LoRAs are loaded into random slots. The layer output
    is compared against ``base(x) + F.embedding(x, lora_a.T) @ lora_b.T``
    computed input by input. Every slot is then reset, and the layer must
    reproduce the plain base embedding.
    """
    # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
    # device, see: https://github.com/triton-lang/triton/issues/2925
    # Same below.
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)
    torch.set_default_device(device)

    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )

    def build_layers():
        # Random base embedding; rows at index >= vocab_size are zeroed.
        base = VocabParallelEmbedding(vocab_size, 256)
        base.weight.data = torch.rand_like(base.weight.data)
        base.weight.data[vocab_size:, :] = 0
        wrapped = VocabParallelEmbeddingWithLoRA(base)
        wrapped.create_lora_weights(max_loras, lora_config)
        return base, wrapped

    def sample_and_activate(active_ids, slot_map):
        # Draw token batches and push the matching LoRA mapping into punica.
        batch, token_map, prompt_map = create_random_inputs(
            active_lora_ids=active_ids,
            num_inputs=num_loras * 3,
            input_size=(200,),
            input_range=(1, vocab_size),
            device=device,
        )
        punica_wrapper.update_metadata(
            LoRAMapping(token_map, prompt_map, is_prefill=stage),
            slot_map,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )
        return batch, prompt_map

    def assert_matches(actual, expected):
        rtol, atol = TOLERANCES[actual.dtype]
        torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)

    for seed in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed)

        slot_map = get_random_id_to_index(num_loras, max_loras)
        base, wrapped = build_layers()
        wrapped.set_mapping(punica_wrapper)
        loras_by_id, _ = populate_loras(
            slot_map,
            layer=wrapped,
            layer_weights=base.weight.T,
        )

        batch, prompt_map = sample_and_activate(list(loras_by_id.keys()), slot_map)
        actual = wrapped(torch.cat(batch))
        reference = []
        for tokens, lora_id in zip(batch, prompt_map):
            lora = loras_by_id[lora_id]
            out = base(tokens)
            out += F.embedding(tokens, lora.lora_a.T) @ lora.lora_b.T
            reference.append(out)
        assert_matches(actual, torch.cat(reference))

        # Check that resetting the lora weights succeeds: with every slot
        # cleared (and inputs mapped to id 0, which populate_loras never
        # assigns), the layer must equal the base embedding.
        for slot in range(max_loras):
            wrapped.reset_lora(slot)
        batch, _ = sample_and_activate([0], slot_map)
        actual = wrapped(torch.cat(batch))
        assert_matches(actual, base(torch.cat(batch)))
356
357
358


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings_with_new_embeddings(
    dist_init, num_loras, device, vocab_size, stage
) -> None:
    """VocabParallelEmbeddingWithLoRA with per-LoRA extra-vocab embeddings.

    The LoRA layer receives "local" token ids, where each LoRA's extra tokens
    start at ``vocab_size``. The reference is an expanded embedding that
    receives "global" ids, in which LoRA ``k`` owns ``n`` rows starting at
    ``vocab_size + (k - 1) * n``.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)
    torch.set_default_device(device)

    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )
    extra_vocab = lora_config.lora_extra_vocab_size

    def build_layers():
        base = VocabParallelEmbedding(vocab_size, 256)
        base_weights = torch.rand_like(base.weight.data)
        base.weight.data = base_weights
        base.weight.data[vocab_size:, :] = 0
        expanded = VocabParallelEmbedding(
            vocab_size + extra_vocab * max_loras,
            256,
            org_num_embeddings=vocab_size,
        )
        expanded.weight.data[:vocab_size, :] = base_weights
        # Wrap a deepcopy: the wrapped embedding is modified in place, and
        # ``expanded`` must stay untouched to serve as the reference.
        wrapped = VocabParallelEmbeddingWithLoRA(deepcopy(expanded))
        wrapped.create_lora_weights(max_loras, lora_config)
        return expanded, wrapped

    def sample_and_activate(active_ids, slot_map):
        # Draw token batches and push the matching LoRA mapping into punica.
        batch, token_map, prompt_map = create_random_inputs(
            active_lora_ids=active_ids,
            num_inputs=num_loras * 3,
            input_size=(200,),
            input_range=(1, vocab_size),
            device=device,
        )
        punica_wrapper.update_metadata(
            LoRAMapping(token_map, prompt_map, is_prefill=stage),
            slot_map,
            max_loras,
            vocab_size,
            extra_vocab,
        )
        return batch, prompt_map

    def assert_matches(actual, expected):
        rtol, atol = TOLERANCES[actual.dtype]
        torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)

    for seed in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed)

        slot_map = get_random_id_to_index(num_loras, max_loras)
        expanded, wrapped = build_layers()
        loras_by_id, _ = populate_loras(
            slot_map,
            layer=wrapped,
            layer_weights=torch.zeros((256, vocab_size + extra_vocab)),
            generate_embeddings_tensor=256,
        )
        wrapped.set_mapping(punica_wrapper)

        # Extra-vocab tables ordered by lora id (ids run 1..num_loras), padded
        # with zero tables for unoccupied slots. All share one shape.
        extra_tables = [
            loras_by_id[lid].embeddings_tensor for lid in sorted(loras_by_id)
        ]
        rows_per_lora = extra_tables[0].shape[0]
        extra_tables += [
            torch.zeros(extra_tables[0].shape)
            for _ in range(max_loras - len(extra_tables))
        ]

        global_ids, prompt_map = sample_and_activate(
            list(loras_by_id.keys()), slot_map
        )
        local_ids = deepcopy(global_ids)

        # Force the last two tokens of every input onto the first and last
        # extra-vocab row of its LoRA, so that path is always exercised.
        for glob, loc, lora_id in zip(global_ids, local_ids, prompt_map):
            first_row = vocab_size + (lora_id - 1) * rows_per_lora
            glob[-1] = first_row
            loc[-1] = vocab_size
            glob[-2] = first_row + rows_per_lora - 1
            loc[-2] = vocab_size + rows_per_lora - 1

        expanded.weight[vocab_size : vocab_size + rows_per_lora * max_loras] = (
            torch.cat(extra_tables)
        )

        actual = wrapped(torch.cat(local_ids))
        reference = []
        for glob, loc, lora_id in zip(global_ids, local_ids, prompt_map):
            lora = loras_by_id[lora_id]
            out = expanded(glob)
            out += F.embedding(loc, lora.lora_a.T) @ lora.lora_b.T
            reference.append(out)
        assert_matches(actual, torch.cat(reference))

        # Check that resetting the lora weights succeeds: afterwards the
        # layer must equal the expanded reference embedding.
        for slot in range(max_loras):
            wrapped.reset_lora(slot)
        global_ids, _ = sample_and_activate([0], slot_map)
        local_ids = deepcopy(global_ids)
        actual = wrapped(torch.cat(local_ids))
        assert_matches(actual, expanded(torch.cat(global_ids)))
499
500
501


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])
@pytest.mark.parametrize("stage", STAGES)
def test_lm_head_logits_processor(
    dist_init, num_loras, device, vocab_size, stage
) -> None:
    """LogitsProcessorWithLoRA matches a dense reference, then resets cleanly.

    The reference is the plain LogitsProcessor evaluated on an lm head whose
    rows ``[vocab_size, vocab_size + n)`` hold the LoRAs' shared embeddings
    tensor (``n`` rows). Every logit past that range is set to ``-inf``, and
    ``x @ lora_a.T @ lora_b.T * scaling`` is added on top.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)

    lora_config = LoRAConfig(
        max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
    )

    def _pretest():
        """Build a random lm head, its base processor and the LoRA wrapper."""
        linear = ParallelLMHead(
            vocab_size + lora_config.lora_extra_vocab_size,
            1024,
            vocab_size,
            params_dtype=torch.float16,
        )
        linear.weight.data = torch.rand_like(linear.weight.data)
        # NOTE(review): assuming the weight is (vocab rows, 1024 hidden), this
        # zeroes hidden *columns* >= vocab_size (a no-op unless
        # vocab_size < 1024). The embedding tests zero padded *rows* with
        # ``[vocab_size:, :]`` instead. Confirm which one is intended.
        linear.weight.data[:, vocab_size:] = 0
        logits_processor = LogitsProcessor(
            vocab_size + lora_config.lora_extra_vocab_size, vocab_size
        )
        lora_logits_processor = LogitsProcessorWithLoRA(
            logits_processor, 1024, linear.weight.dtype, linear.weight.device, None
        )
        lora_logits_processor.create_lora_weights(max_loras, lora_config)

        return linear, logits_processor, lora_logits_processor

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, logits_processor, lora_logits_processor = _pretest()
        lora_logits_processor.set_mapping(punica_wrapper)
        # NOTE: all the generated loras share the same embeddings tensor.
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_logits_processor,
            layer_weights=linear.weight,
            generate_embeddings_tensor=1024,
        )
        embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor
        embeddings_tensor_len = embeddings_tensor.shape[0]

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=8 * num_loras,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs), lm_head=linear, embedding_bias=None
        )

        # Keep an unmodified copy for the reset check below; ``linear`` is
        # patched in place to build the reference.
        original_lm_head = deepcopy(linear)

        linear.weight[
            logits_processor.org_vocab_size : logits_processor.org_vocab_size
            + embeddings_tensor_len
        ] = embeddings_tensor

        # Temporarily widen the base processor so its logits include the
        # extra-vocab rows; restored right after the reference is built.
        logits_processor.org_vocab_size = vocab_size + lora_config.lora_extra_vocab_size
        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = logits_processor._get_logits(
                hidden_states=input_, lm_head=linear, embedding_bias=None
            )
            result[:, vocab_size + embeddings_tensor_len :] = float("-inf")
            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)
        logits_processor.org_vocab_size = vocab_size

        # Compare the LoRA logits (base and extra-vocab range) against the
        # reference; without this the reference above is dead code.
        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)

        # Check that resetting the lora weights succeeds
        for slot_idx in range(max_loras):
            lora_logits_processor.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=8 * num_loras * 3,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )

        # Only the first vocab_size logits of the LoRA output are compared.
        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None,
        )[:, :vocab_size]
        expected_result = logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None,
        )

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)
632
633


634
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_replicated(
    dist_init,
    num_loras,
    device,
    stage,
) -> None:
    """ReplicatedLinearWithLoRA matches ``base(x) + x @ lora_a.T @ lora_b.T * scaling``.

    After every slot is reset, the layer must reproduce the base linear.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        lora_dtype=torch.float16,
    )

    def build_layers():
        base = ReplicatedLinear(4096, 4096, bias=False, params_dtype=torch.float16)
        base.weight.data = torch.rand_like(base.weight.data)
        wrapped = ReplicatedLinearWithLoRA(base)
        wrapped.create_lora_weights(max_loras, lora_config)
        # A replicated layer carries exactly one LoRA slice.
        assert (
            wrapped.n_slices
            == len(wrapped.lora_a_stacked)
            == len(wrapped.lora_b_stacked)
            == 1
        )
        return base, wrapped

    def sample_and_activate(active_ids, slot_map):
        # Draw single-row fp16 inputs and push the LoRA mapping into punica.
        batch, token_map, prompt_map = create_random_inputs(
            active_lora_ids=active_ids,
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        punica_wrapper.update_metadata(
            LoRAMapping(token_map, prompt_map, is_prefill=stage),
            slot_map,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )
        return batch, prompt_map

    def assert_matches(actual, expected):
        rtol, atol = TOLERANCES[actual.dtype]
        torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)

    for seed in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed)

        slot_map = get_random_id_to_index(num_loras, max_loras)
        base, wrapped = build_layers()
        assert torch.equal(base.weight, wrapped.weight)
        wrapped.set_mapping(punica_wrapper)
        loras_by_id, _ = populate_loras(
            slot_map,
            layer=wrapped,
            layer_weights=base.weight,
        )

        batch, prompt_map = sample_and_activate(list(loras_by_id.keys()), slot_map)
        # Layer calls return a tuple; element 0 is the output tensor.
        actual = wrapped(torch.cat(batch))[0]
        reference = []
        for x, lora_id in zip(batch, prompt_map):
            lora = loras_by_id[lora_id]
            out = base(x)[0]
            out += x @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            reference.append(out)
        assert_matches(actual, torch.cat(reference))

        # Check that resetting the lora weights succeeds
        for slot in range(max_loras):
            wrapped.reset_lora(slot)
        batch, _ = sample_and_activate([0], slot_map)
        actual = wrapped(torch.cat(batch))[0]
        assert_matches(actual, base(torch.cat(batch))[0])
738
739


740
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_parallel(
    dist_init, num_loras, orientation, fully_shard, device, stage
) -> None:
    """Row/column parallel LoRA linears match a dense reference, then reset.

    ``orientation == "row"`` tests RowParallelLinear; any other value tests
    ColumnParallelLinear. ``fully_shard`` selects the ShardedLoRA wrapper.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        fully_sharded_loras=fully_shard,
        lora_dtype=torch.float16,
    )

    if orientation == "row":
        base_cls = RowParallelLinear
        plain_cls = RowParallelLinearWithLoRA
        sharded_cls = RowParallelLinearWithShardedLoRA
    else:
        base_cls = ColumnParallelLinear
        plain_cls = ColumnParallelLinearWithLoRA
        sharded_cls = ColumnParallelLinearWithShardedLoRA

    def build_layers():
        base = base_cls(4096, 4096, bias=False, params_dtype=torch.float16)
        base.weight.data = torch.rand_like(base.weight.data)
        wrapped = sharded_cls(base) if fully_shard else plain_cls(base)
        wrapped.create_lora_weights(max_loras, lora_config)
        # A plain (unpacked) parallel linear carries exactly one LoRA slice.
        assert (
            wrapped.n_slices
            == len(wrapped.lora_a_stacked)
            == len(wrapped.lora_b_stacked)
            == 1
        )
        return base, wrapped

    def sample_and_activate(active_ids, slot_map):
        # Draw single-row fp16 inputs and push the LoRA mapping into punica.
        batch, token_map, prompt_map = create_random_inputs(
            active_lora_ids=active_ids,
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        punica_wrapper.update_metadata(
            LoRAMapping(token_map, prompt_map, is_prefill=stage),
            slot_map,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )
        return batch, prompt_map

    def assert_matches(actual, expected):
        rtol, atol = TOLERANCES[actual.dtype]
        torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)

    for seed in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed)

        slot_map = get_random_id_to_index(num_loras, max_loras)
        base, wrapped = build_layers()
        assert torch.equal(base.weight, wrapped.weight)
        wrapped.set_mapping(punica_wrapper)
        loras_by_id, _ = populate_loras(
            slot_map,
            layer=wrapped,
            layer_weights=base.weight,
        )

        batch, prompt_map = sample_and_activate(list(loras_by_id.keys()), slot_map)
        # Layer calls return a tuple; element 0 is the output tensor.
        actual = wrapped(torch.cat(batch))[0]
        reference = []
        for x, lora_id in zip(batch, prompt_map):
            lora = loras_by_id[lora_id]
            out = base(x)[0]
            out += x @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            reference.append(out)
        assert_matches(actual, torch.cat(reference))

        # Check that resetting the lora weights succeeds
        for slot in range(max_loras):
            wrapped.reset_lora(slot)
        batch, _ = sample_and_activate([0], slot_map)
        actual = wrapped(torch.cat(batch))[0]
        assert_matches(actual, base(torch.cat(batch))[0])
861
862
863


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_column_parallel_packed(
    dist_init, num_loras, repeats, fully_shard, device, stage
) -> None:
    """Check packed column-parallel LoRA layers against a reference computation.

    ``repeats`` selects the layer under test:

    * 2 -> ``MergedColumnParallelLinear`` wrapped by
      ``MergedColumnParallelLinearWithLoRA``
    * 3 -> ``QKVParallelLinear`` wrapped by ``MergedQKVParallelLinearWithLoRA``
    * otherwise -> ``QKVParallelLinear`` wrapped by ``QKVParallelLinearWithLoRA``

    ``fully_shard`` swaps in the corresponding ``...WithShardedLoRA`` wrapper.

    For every random seed the wrapped layer's output is compared with the base
    layer's output plus, per packed slice, ``x @ A.T @ B.T * scaling``. Then all
    LoRA slots are reset and the wrapped layer must match the base layer.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        fully_sharded_loras=fully_shard,
        lora_dtype=torch.float16,
    )

    def create_column_parallel_packed_layer():
        """Build a random base layer and its LoRA wrapper for ``repeats``."""
        if repeats == 2:
            linear = MergedColumnParallelLinear(
                4096, [4096] * repeats, bias=False, params_dtype=torch.float16
            )
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (
                MergedColumnParallelLinearWithLoRA(linear)
                if not fully_shard
                else MergedColumnParallelLinearWithShardedLoRA(linear)
            )
        elif repeats == 3:
            linear = QKVParallelLinear(
                4096, 64, 32, bias=False, params_dtype=torch.float16
            )
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (
                MergedQKVParallelLinearWithLoRA(linear)
                if not fully_shard
                else MergedQKVParallelLinearWithShardedLoRA(linear)
            )
        else:
            linear = QKVParallelLinear(
                4096, 64, 32, bias=False, params_dtype=torch.float16
            )
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (
                QKVParallelLinearWithLoRA(linear)
                if not fully_shard
                else QKVParallelLinearWithShardedLoRA(linear)
            )

        # Minimal stand-in for the model config handed to create_lora_weights;
        # presumably only these attributes are read -- TODO confirm.
        @dataclass
        class FakeConfig:
            hidden_size = 4096
            num_key_value_heads = 32
            num_attention_heads = 32

        n_slices = repeats
        lora_linear.create_lora_weights(
            max_loras, lora_config, model_config=FakeConfig()
        )
        # One stacked A/B buffer per packed slice.
        assert (
            lora_linear.n_slices
            == len(lora_linear.lora_a_stacked)
            == len(lora_linear.lora_b_stacked)
            == n_slices
        )

        return linear, lora_linear

    for seed in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)

        linear, lora_linear = create_column_parallel_packed_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, sublora_dict = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
            repeats=repeats,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            result = linear(input_)[0]
            subloras = sublora_dict[lora_id]
            # Each sub-LoRA contributes only to its own contiguous output slice.
            # `slice_idx` must not reuse the seed loop's variable name.
            for slice_idx, sublora in enumerate(subloras):
                slice_size = sublora.lora_b.shape[0]
                result[:, slice_size * slice_idx : slice_size * (slice_idx + 1)] += (
                    input_ @ sublora.lora_a.T @ sublora.lora_b.T * sublora.scaling
                )
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)

        # Check that resetting the lora weights succeeds: with every slot
        # cleared, the wrapped layer must reproduce the base layer exactly.
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)


@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize(
    "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS))
)
def test_vocab_parallel_embedding_indices(tp_size, seed):
    """Check VocabParallelEmbedding shard index ranges across simulated TP ranks.

    A random vocabulary is split into original and added tokens. For each
    simulated TP rank (rank and world size are patched), a
    ``VocabParallelEmbedding`` is built and the test checks that:

    * each rank's org/added ranges start where the previous rank's ended;
    * per-rank element counts sum to the padded/org/added totals;
    * no token id is owned by two shards;
    * ``get_sharded_to_full_mapping()`` reorders the concatenated sharded
      layout (padding slots marked ``-1``) into ``0..vocab_size-1``, followed
      only by padding.
    """
    random.seed(seed)
    vocab_size = random.randint(4000, 64000)
    added_vocab_size = random.randint(0, 1024)
    org_vocab_size = vocab_size - added_vocab_size
    # Where the next rank's ranges must begin: org ids start at 0, added ids
    # start immediately after the original vocabulary.
    last_org_vocab_end_index = 0
    last_added_vocab_end_index = org_vocab_size
    computed_vocab_size = 0
    computed_org_vocab_size = 0
    computed_added_vocab_size = 0
    vocab_size_padded = -1

    all_org_tokens: list[int] = []
    all_added_tokens: list[int] = []
    # Concatenation of every rank's local layout:
    # [org ids, org padding (-1), added ids, added padding (-1)].
    token_ids: list[int] = []

    for tp_rank in range(tp_size):
        with (
            patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank",
                return_value=tp_rank,
            ),
            patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size",
                return_value=tp_size,
            ),
        ):
            vocab_embedding = VocabParallelEmbedding(
                vocab_size, 1, org_num_embeddings=org_vocab_size
            )
        vocab_size_padded = vocab_embedding.num_embeddings_padded
        shard_indices = vocab_embedding.shard_indices
        # Assert that the ranges are contiguous
        assert shard_indices.org_vocab_start_index == last_org_vocab_end_index
        assert shard_indices.added_vocab_start_index == last_added_vocab_end_index

        # Accumulate per-shard sizes; compared against the totals after the loop.
        computed_vocab_size += shard_indices.num_elements_padded
        computed_org_vocab_size += shard_indices.num_org_elements
        computed_added_vocab_size += shard_indices.num_added_elements

        org_range = range(
            shard_indices.org_vocab_start_index, shard_indices.org_vocab_end_index
        )
        added_range = range(
            shard_indices.added_vocab_start_index,
            shard_indices.added_vocab_end_index,
        )

        # Collected to check that the ranges are not overlapping
        all_org_tokens.extend(org_range)
        all_added_tokens.extend(added_range)

        token_ids.extend(org_range)
        token_ids.extend(
            [-1]
            * (shard_indices.num_org_elements_padded - shard_indices.num_org_elements)
        )
        token_ids.extend(added_range)
        token_ids.extend(
            [-1]
            * (
                shard_indices.num_added_elements_padded
                - shard_indices.num_added_elements
            )
        )

        last_org_vocab_end_index = shard_indices.org_vocab_end_index
        last_added_vocab_end_index = shard_indices.added_vocab_end_index

    assert computed_vocab_size == vocab_size_padded
    assert computed_org_vocab_size == org_vocab_size
    assert computed_added_vocab_size == added_vocab_size

    # Ensure that the ranges are not overlapping
    assert len(all_org_tokens) == len(set(all_org_tokens))
    assert len(all_added_tokens) == len(set(all_added_tokens))
    assert not set(all_org_tokens).intersection(set(all_added_tokens))

    token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
    # Uses the last rank's instance; assumes the mapping does not depend on the
    # rank -- TODO confirm.
    reindex_mapping = vocab_embedding.get_sharded_to_full_mapping()
    assert reindex_mapping is not None or tp_size == 1
    if reindex_mapping is not None:
        reindexed_token_ids = token_ids_tensor[reindex_mapping]
        expected = torch.arange(vocab_size)
        assert reindexed_token_ids[:vocab_size].equal(expected)
        assert torch.all(reindexed_token_ids[vocab_size:] == -1)


def test_get_masked_input_and_mask():
    """Check that get_masked_input_and_mask remaps token ids for each TP shard.

    ``x`` holds ids 0-11: ids 0-7 form the original vocab and ids 8-11 the
    added vocab. Each case lists, per rank, the shard ranges
    ``(org_start, org_end, added_start, added_end)`` together with
    ``num_org_vocab_padding``, and the expected remapped input. Only the first
    return value (the remapped input) is checked; the second is ignored.
    """
    x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

    def _masked(org_start, org_end, added_start, added_end, padding):
        """Call get_masked_input_and_mask on ``x`` for one shard; drop the mask."""
        modified_x, _ = get_masked_input_and_mask(
            x,
            org_vocab_start_index=org_start,
            org_vocab_end_index=org_end,
            added_vocab_start_index=added_start,
            added_vocab_end_index=added_end,
            num_org_vocab_padding=padding,
        )
        return modified_x

    # Per-rank (org_start, org_end, added_start, added_end) shard ranges.
    tp1_shards = [(0, 8, 8, 12)]
    tp2_shards = [(0, 4, 8, 10), (4, 8, 10, 12)]
    tp4_shards = [(0, 2, 8, 9), (2, 4, 9, 10), (4, 6, 10, 11), (6, 8, 11, 12)]

    # (description, shards, num_org_vocab_padding, expected output per rank)
    cases = [
        # base tp 1 case, no padding: ids are unchanged
        ("tp 1, no padding", tp1_shards, 0, [list(range(12))]),
        (
            "tp 2, no padding",
            tp2_shards,
            0,
            [
                [0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0],
                [0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5],
            ],
        ),
        (
            "tp 4, no padding",
            tp4_shards,
            0,
            [
                [0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0],
                [0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0],
                [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0],
                [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2],
            ],
        ),
        (
            "tp 1, with padding",
            tp1_shards,
            2,
            [[0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]],
        ),
        (
            "tp 2, with padding",
            tp2_shards,
            2,
            [
                [0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0],
                [0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7],
            ],
        ),
        (
            "tp 4, with padding",
            tp4_shards,
            2,
            [
                [0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0],
                [0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0],
                [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0],
                [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4],
            ],
        ),
    ]

    for desc, shards, padding, expected_per_rank in cases:
        # Guard against a malformed table silently skipping ranks.
        assert len(shards) == len(expected_per_rank), desc
        for rank, (shard, expected) in enumerate(zip(shards, expected_per_rank)):
            modified_x = _masked(*shard, padding)
            assert torch.equal(modified_x, torch.tensor(expected)), (
                f"{desc}, rank {rank}: got {modified_x.tolist()}, expected {expected}"
            )