test_layers.py 53.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
import random
from copy import deepcopy
from dataclasses import dataclass
6
from typing import Optional
7
from unittest.mock import patch
8

9
import pytest
10
11
12
13
import torch
import torch.nn.functional as F

from vllm.config import LoRAConfig
14
15
16
from vllm.lora.fully_sharded_layers import (
    ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
17
    MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA,
18
    RowParallelLinearWithShardedLoRA)
19
20
# yapf conflicts with isort for this block
# yapf: disable
21
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
22
                              LinearScalingRotaryEmbeddingWithLoRA,
23
24
                              LogitsProcessorWithLoRA, LoRAMapping,
                              MergedColumnParallelLinearWithLoRA,
25
26
                              MergedQKVParallelLinearWithLoRA,
                              QKVParallelLinearWithLoRA,
27
                              ReplicatedLinearWithLoRA,
28
29
                              RowParallelLinearWithLoRA,
                              VocabParallelEmbeddingWithLoRA)
30
# yapf: enable
31
from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights,
32
                              PackedLoRALayerWeights)
33
from vllm.lora.punica_wrapper import get_punica_wrapper
34
35
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
36
                                               QKVParallelLinear,
37
                                               ReplicatedLinear,
38
39
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
40
from vllm.model_executor.layers.rotary_embedding import get_rope
41
from vllm.model_executor.layers.vocab_parallel_embedding import (
42
    ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
43
from vllm.model_executor.utils import set_random_seed
44
from vllm.platforms import current_platform
45
46
47
48
49
50
51
52

from .utils import DummyLoRAManager

# Comparison tolerances keyed by result dtype. Each value is an
# (rtol, atol) pair; the tests unpack it in that order and forward it to
# torch.testing.assert_close.
TOLERANCES = {
    torch.float16: (5e-3, 5e-3),
    torch.float32: (5e-3, 5e-3),
    torch.bfloat16: (3e-2, 2e-2),
}
53
54
55
56
57
58

# Module-wide marker: skip every test here unless the current platform is
# CUDA-alike or CPU.
pytestmark = pytest.mark.skipif(
    not (current_platform.is_cuda_alike() or current_platform.is_cpu()),
    reason="Backend not supported")

# Devices the tests are parametrized over: "cuda:0" alone when exactly one
# CUDA device is reported, "cuda:0" and "cuda:1" otherwise; just "cpu" on
# platforms that are not CUDA-alike.
if current_platform.is_cuda_alike():
    DEVICES = [
        f"cuda:{i}"
        for i in range(1 if torch.cuda.device_count() == 1 else 2)
    ]
else:
    DEVICES = ["cpu"]
61

62
# For GPU, different Triton kernels are launched for the prefill and decode
# stages, so both must be tested: prefill stage (True), decode stage (False).
STAGES = [True, False]
65

66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# With the inclusion of V1 tests (see run_with_both_engines_lora), the tests
# in this file run twice: once with the V0 engine and once with the V1 engine.
# NUM_RANDOM_SEEDS was previously 10; it was halved when the V1 tests were
# added in order to maintain the CI test times.
NUM_RANDOM_SEEDS = 5
# VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS was previously 256; it was
# halved when the V1 tests were added in order to maintain the CI test times.
VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
    """Autouse wrapper that runs each test in this module with both engines.

    This can be promoted to conftest.py to apply to every test in a
    package. Before yielding to the test, the Triton LoRA pointer caches
    are emptied.
    """
    from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT,
                                                _LORA_B_PTR_DICT)

    # Release any memory we might be holding on to; CI runs OOM otherwise.
    for ptr_cache in (_LORA_B_PTR_DICT, _LORA_A_PTR_DICT):
        ptr_cache.clear()

    yield

92
93
94

def get_random_id_to_index(num_loras: int,
                           num_slots: int,
                           log: bool = True) -> list[Optional[int]]:
    """Build a random lora_id_to_index mapping.

    LoRA ids 1..num_loras are each placed in a distinct, randomly chosen
    slot; every remaining slot holds None.

    Args:
        num_loras: The number of active loras in the mapping.
        num_slots: The number of slots in the mapping. Must be greater
            than or equal to num_loras.
        log: Whether to print the resulting mapping.

    Returns:
        A list of length num_slots whose entry at index i is the lora id
        stored in slot i, or None if the slot is free.

    Raises:
        ValueError: If num_loras exceeds num_slots.
    """
    if num_loras > num_slots:
        raise ValueError(
            f"num_loras is higher than num_slots: {num_loras} > {num_slots}. "
            "num_loras must be less than or equal to num_slots.")

    mapping: list[Optional[int]] = [None] * num_slots

    # The first num_loras entries of a random permutation give distinct slots.
    chosen_slots = torch.randperm(num_slots)[:num_loras].tolist()
    for offset, slot in enumerate(chosen_slots):
        mapping[slot] = offset + 1  # lora ids are 1-based

    if log:
        print(f"Created lora_id_to_index mapping: {mapping}.")

    return mapping


def populate_loras(
    id_to_index: list[Optional[int]],
    layer: BaseLayerWithLoRA,
    layer_weights: torch.Tensor,
    generate_embeddings_tensor: int = 0,
    repeats: int = 1,
) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
    """Fill the occupied slots of a LoRA layer with random LoRA weights.

    Args:
        id_to_index: a list of lora ids. The index of the lora id
            represents which memory slot the lora matrices are
            stored in. A None value indicates a free slot.
        layer: the LoRA layer to populate.
        layer_weights: the PyTorch tensor containing the layer's
            weights.
        generate_embeddings_tensor: forwarded unchanged to
            DummyLoRAManager.init_random_lora (presumably the size of the
            embeddings tensor to generate, with 0 meaning none; confirm
            against that helper).
        repeats: must only be set for column parallel packed
            layers. Indicates the number of loras to compose
            together to create a single lora layer.

    Returns:
        A pair (lora_dict, sublora_dict): the first maps each lora id to
        the weights installed in the layer, the second maps each lora id
        to the list of subloras those weights were built from.
    """
    # lora id -> weights installed in the layer.
    lora_dict: dict[int, LoRALayerWeights] = {}
    # lora id -> the individual subloras behind those weights.
    sublora_dict: dict[int, list[LoRALayerWeights]] = {}

    for slot_idx, lora_id in enumerate(id_to_index):
        if lora_id is None:
            # Free slot: nothing to install.
            continue

        sublora_len = layer_weights.shape[0] // repeats
        subloras: list[LoRALayerWeights] = []
        for part in range(repeats):
            sublora = DummyLoRAManager(layer_weights.device).init_random_lora(
                module_name=f"fake_{part}",
                weight=layer_weights,
                generate_embeddings_tensor=generate_embeddings_tensor,
            )
            # Keep only this part's column range of lora_b.
            col_start = sublora_len * part
            col_end = sublora_len * (part + 1)
            sublora.lora_b = sublora.lora_b[:, col_start:col_end]
            sublora.optimize()
            subloras.append(sublora)

        if repeats > 1:
            lora = PackedLoRALayerWeights.pack(subloras)
        else:
            lora = subloras[0]

        layer.set_lora(
            slot_idx,
            lora_a=lora.lora_a,
            lora_b=lora.lora_b,
            embeddings_tensor=lora.embeddings_tensor,
        )

        lora_dict[lora_id] = lora
        sublora_dict[lora_id] = subloras

    return lora_dict, sublora_dict


def create_random_inputs(
    active_lora_ids: list[int],
    num_inputs: int,
    input_size: tuple[int, ...],
    input_range: tuple[float, float],
    input_type: torch.dtype = torch.int,
    device: torch.device = "cuda"
) -> tuple[list[torch.Tensor], list[int], list[int]]:
    """Creates random inputs.

    Args:
        active_lora_ids: lora IDs of active lora weights.
        num_inputs: the number of inputs to create.
        input_size: the size of each individual input.
        input_range: the range of values to include in the input.
            input_range[0] <= possible input values < input_range[1]
        input_type: the type of values in the input. torch.int selects
            integer sampling via torch.randint; any other dtype selects
            uniform sampling via torch.rand.
        device: the device the input tensors are created on.

    Returns:
        A tuple (inputs, index_mapping, prompt_mapping). inputs holds
        num_inputs tensors of shape input_size. prompt_mapping holds the
        randomly chosen lora id of each input, and index_mapping repeats
        that id input_size[0] times per input.
    """

    low, high = input_range

    inputs: list[torch.Tensor] = []
    index_mapping: list[int] = []
    prompt_mapping: list[int] = []

    for _ in range(num_inputs):
        if input_type == torch.int:
            inputs.append(
                torch.randint(low=int(low),
                              high=int(high),
                              size=input_size,
                              device=device))
        else:
            # torch.rand samples from [0, 1); scale by the width of the range
            # (not by `high`) so the result lies in [low, high).
            inputs.append(
                torch.rand(size=input_size, dtype=input_type, device=device) *
                (high - low) + low)

        lora_id = random.choice(active_lora_ids)
        index_mapping += [lora_id] * input_size[0]
        prompt_mapping += [lora_id]

    return inputs, index_mapping, prompt_mapping


228
229
230
231
232
def check_punica_wrapper(punica_wrapper) -> bool:
    """Return True if punica_wrapper is exactly the platform's wrapper class.

    The expected class is PunicaWrapperGPU on CUDA-alike platforms and
    PunicaWrapperCPU on CPU. The comparison is on the exact type, so
    subclasses do not match. Any other platform yields False.
    """
    if current_platform.is_cuda_alike():
        from vllm.lora.punica_wrapper.punica_gpu import (
            PunicaWrapperGPU as expected_cls)
    elif current_platform.is_cpu():
        from vllm.lora.punica_wrapper.punica_cpu import (
            PunicaWrapperCPU as expected_cls)
    else:
        return False

    return type(punica_wrapper) is expected_cls


241
242
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
    """Check VocabParallelEmbeddingWithLoRA against a manual reference.

    For each seed the LoRA layer output is compared with the base embedding
    output plus ``F.embedding(input, lora_a) @ lora_b``. All LoRA slots are
    then reset and the layer must match the base embedding alone.
    """
    # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
    # device, see: https://github.com/triton-lang/triton/issues/2925
    # Same below.
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def build_layers():
        # Base embedding with random weights, plus its LoRA wrapper.
        base = VocabParallelEmbedding(vocab_size, 256)
        base.weight.data = torch.rand_like(base.weight.data)
        base.weight.data[vocab_size:, :] = 0
        wrapped = VocabParallelEmbeddingWithLoRA(base)
        wrapped.create_lora_weights(max_loras, lora_config)
        return base, wrapped

    def activate_mapping(index_mapping, prompt_mapping, id_to_index):
        # Push the per-token / per-prompt lora ids into the punica wrapper.
        mapping = LoRAMapping(index_mapping,
                              prompt_mapping,
                              is_prefill=stage)
        punica_wrapper.update_metadata(mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)

    def check_close(actual, expected):
        rtol, atol = TOLERANCES[actual.dtype]
        torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)

    for seed in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        embedding, lora_embedding = build_layers()
        lora_embedding.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=embedding.weight.T,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict),
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        activate_mapping(index_mapping, prompt_mapping, id_to_index)

        lora_result = lora_embedding(torch.cat(inputs))

        # Reference: base embedding plus the LoRA delta, one input at a time.
        per_input_expected: list[torch.Tensor] = []
        for tokens, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            reference = embedding(tokens)
            after_a = F.embedding(tokens, lora.lora_a)
            reference += after_a @ lora.lora_b
            per_input_expected.append(reference)

        check_close(lora_result, torch.cat(per_input_expected))

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        activate_mapping(index_mapping, prompt_mapping, id_to_index)

        lora_result = lora_embedding(torch.cat(inputs))
        expected_result = embedding(torch.cat(inputs))

        check_close(lora_result, expected_result)
341
342
343


@torch.inference_mode()
# @pytest.mark.skip(
#     reason="Fails when loras are in any slot other than the first.")
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
                                        vocab_size, stage) -> None:
    """Check VocabParallelEmbeddingWithLoRA with per-LoRA extra embeddings.

    The reference is an expanded embedding whose rows past vocab_size are
    filled with the concatenated embeddings tensors of all slots. The last
    two tokens of every input are forced into the extended range. After the
    comparison, all slots are reset and the layer must match the expanded
    embedding alone.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def build_layers():
        base = VocabParallelEmbedding(vocab_size, 256)
        base_weights = torch.rand_like(base.weight.data)
        base.weight.data = base_weights
        base.weight.data[vocab_size:, :] = 0
        expanded = VocabParallelEmbedding(
            vocab_size + lora_config.lora_extra_vocab_size * max_loras,
            256,
            org_num_embeddings=vocab_size)
        expanded.weight.data[:vocab_size, :] = base_weights
        # We need to deepcopy the embedding as it will be modified
        # in place
        wrapped = VocabParallelEmbeddingWithLoRA(deepcopy(expanded))
        wrapped.create_lora_weights(max_loras, lora_config)
        return expanded, wrapped

    def activate_mapping(index_mapping, prompt_mapping, id_to_index):
        mapping = LoRAMapping(index_mapping,
                              prompt_mapping,
                              is_prefill=stage)
        punica_wrapper.update_metadata(mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)

    def check_close(actual, expected):
        rtol, atol = TOLERANCES[actual.dtype]
        torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)

    for seed in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        expanded_embedding, lora_embedding = build_layers()
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=torch.zeros(
                (256, vocab_size + lora_config.lora_extra_vocab_size)),
            generate_embeddings_tensor=256,
        )

        lora_embedding.set_mapping(punica_wrapper)

        # All embeddings tensors have the same shape.
        embeddings_tensors = [
            lora_dict[lora_id].embeddings_tensor
            for lora_id in sorted(lora_dict)
        ]
        tensor_shape = embeddings_tensors[0].shape
        embeddings_tensor_len = tensor_shape[0]

        # Add empty embeddings_tensors for unoccupied lora slots.
        num_missing = max_loras - len(embeddings_tensors)
        embeddings_tensors.extend(
            torch.zeros(tensor_shape) for _ in range(num_missing))

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict),
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        activate_mapping(index_mapping, prompt_mapping, id_to_index)

        original_inputs = deepcopy(inputs)

        # Force some of the inputs to be in the extended embeddings range
        # to guarantee that their behavior is tested.
        for shifted, original, lora_id in zip(inputs, original_inputs,
                                              prompt_mapping):
            embedding_id = lora_id - 1
            first_row = embedding_id * embeddings_tensor_len
            last_row = (embedding_id + 1) * embeddings_tensor_len - 1
            shifted[-1] = vocab_size + first_row
            original[-1] = vocab_size
            shifted[-2] = vocab_size + last_row
            original[-2] = vocab_size + embeddings_tensor_len - 1

        extra_end = vocab_size + embeddings_tensor_len * max_loras
        expanded_embedding.weight[vocab_size:extra_end] = torch.cat(
            embeddings_tensors)

        lora_result = lora_embedding(torch.cat(original_inputs))

        per_input_expected: list[torch.Tensor] = []
        for shifted, original, lora_id in zip(inputs, original_inputs,
                                              prompt_mapping):
            lora = lora_dict[lora_id]
            reference = expanded_embedding(shifted)
            after_a = F.embedding(original, lora.lora_a)
            reference += after_a @ lora.lora_b
            per_input_expected.append(reference)

        check_close(lora_result, torch.cat(per_input_expected))

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        original_inputs = deepcopy(inputs)
        activate_mapping(index_mapping, prompt_mapping, id_to_index)

        lora_result = lora_embedding(torch.cat(original_inputs))
        expected_result = expanded_embedding(torch.cat(inputs))

        check_close(lora_result, expected_result)
482
483
484
485


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])
@pytest.mark.parametrize("stage", STAGES)
def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
                                  stage) -> None:
    """Exercise LogitsProcessorWithLoRA._get_logits with and without LoRAs.

    Phase 1 runs the LoRA logits processor with populated slots and builds
    a manual reference result. Phase 2 resets every slot and asserts that
    the LoRA processor (restricted to the first vocab_size columns) matches
    the plain LogitsProcessor on the original LM head.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def _pretest():
        linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
                                1024,
                                vocab_size,
                                params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        # NOTE(review): this slices the second (size-1024) dimension, so it
        # only zeroes anything when vocab_size < 1024. If the intent was to
        # zero the extra-vocab rows it should be [vocab_size:, :] - confirm.
        linear.weight.data[:, vocab_size:] = 0
        logits_processor = LogitsProcessor(
            vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
        lora_logits_processor = LogitsProcessorWithLoRA(
            logits_processor, 1024, linear.weight.dtype, linear.weight.device,
            None)
        lora_logits_processor.create_lora_weights(max_loras, lora_config)

        return linear, logits_processor, lora_logits_processor

    def _update_metadata(index_mapping, prompt_mapping, id_to_index):
        # Push the per-token / per-prompt lora ids into the punica wrapper.
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, logits_processor, lora_logits_processor = _pretest()
        lora_logits_processor.set_mapping(punica_wrapper)
        # NOTE: all the generated loras share the same embeddings tensor.
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_logits_processor,
            layer_weights=linear.weight,
            generate_embeddings_tensor=1024,
        )
        embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor
        embeddings_tensor_len = embeddings_tensor.shape[0]

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=8 * num_loras,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        _update_metadata(index_mapping, prompt_mapping, id_to_index)

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=linear,
            embedding_bias=None)

        # Keep an unmodified copy of the LM head for the reset check below;
        # `linear` itself is modified to build the reference result.
        original_lm_head = deepcopy(linear)

        extra_start = logits_processor.org_vocab_size
        linear.weight[extra_start:extra_start +
                      embeddings_tensor_len] = embeddings_tensor

        # Temporarily widen org_vocab_size on the plain processor while the
        # reference logits are computed; it is restored right after.
        logits_processor.org_vocab_size = (vocab_size +
                                           lora_config.lora_extra_vocab_size)
        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = logits_processor._get_logits(hidden_states=input_,
                                                  lm_head=linear,
                                                  embedding_bias=None)
            result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)
        logits_processor.org_vocab_size = vocab_size

        # NOTE(review): `lora_result` and `expected_result` from this phase
        # are never compared, so the populated-LoRA path is not actually
        # verified here. A torch.testing.assert_close(...) is presumably
        # missing - confirm the reference still matches the layer (e.g. its
        # handling of -inf padding) before enabling it.

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_logits_processor.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=8 * num_loras * 3,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        _update_metadata(index_mapping, prompt_mapping, id_to_index)

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None)[:, :vocab_size]
        expected_result = logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)
616
617


618
619
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_linear_replicated(dist_init, num_loras, device, stage,
                           bias_enabled) -> None:
    """Check ReplicatedLinearWithLoRA against a manual reference.

    For each seed the LoRA layer output is compared with the base linear
    output plus ``input @ lora_a @ lora_b * scaling``. All LoRA slots are
    then reset and the layer must match the base linear layer alone.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16,
                             bias_enabled=bias_enabled)

    def build_layers():
        base = ReplicatedLinear(4096,
                                4096,
                                bias=False,
                                params_dtype=torch.float16)
        base.weight.data = torch.rand_like(base.weight.data)
        wrapped = ReplicatedLinearWithLoRA(base)

        wrapped.create_lora_weights(max_loras, lora_config)
        # A replicated layer has exactly one slice, hence one stacked A/B.
        assert (wrapped.n_slices == len(wrapped.lora_a_stacked) == len(
            wrapped.lora_b_stacked) == 1)
        if not bias_enabled:
            assert wrapped.lora_bias_stacked is None
        else:
            assert len(wrapped.lora_bias_stacked) == wrapped.n_slices
        return base, wrapped

    def activate_mapping(index_mapping, prompt_mapping, id_to_index):
        mapping = LoRAMapping(index_mapping,
                              prompt_mapping,
                              is_prefill=stage)
        punica_wrapper.update_metadata(
            mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

    def check_close(actual, expected):
        rtol, atol = TOLERANCES[actual.dtype]
        torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)

    for seed in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = build_layers()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        activate_mapping(index_mapping, prompt_mapping, id_to_index)

        lora_result = lora_linear(torch.cat(inputs))[0]

        # Reference: base linear output plus the scaled LoRA delta.
        per_input_expected: list[torch.Tensor] = []
        for hidden, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            reference = linear(hidden)[0]
            reference += hidden @ lora.lora_a @ lora.lora_b * lora.scaling
            per_input_expected.append(reference)

        check_close(lora_result, torch.cat(per_input_expected))

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        activate_mapping(index_mapping, prompt_mapping, id_to_index)

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        check_close(lora_result, expected_result)
730
731


732
733
734
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
                         device, stage, bias_enabled) -> None:
    """Compare row/column parallel LoRA linear layers with a reference.

    For every random seed a base parallel linear layer is wrapped in its
    LoRA counterpart (the sharded variant when ``fully_shard`` is set),
    random LoRAs are populated, and the wrapped layer's output is compared
    with the base output plus ``input @ lora_a @ lora_b * scaling``.
    Afterwards every LoRA slot is reset and the wrapped layer must match the
    base layer again.
    """

    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             fully_sharded_loras=fully_shard,
                             lora_dtype=torch.float16,
                             bias_enabled=bias_enabled)

    def create_random_linear_parallel_layer():
        """Return ``(linear, lora_linear)`` with random base weights."""
        if orientation == "row":
            linear = RowParallelLinear(4096,
                                       4096,
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard
                           else RowParallelLinearWithShardedLoRA(linear))
        else:
            linear = ColumnParallelLinear(4096,
                                          4096,
                                          bias=False,
                                          params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (ColumnParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           ColumnParallelLinearWithShardedLoRA(linear))
        lora_linear.create_lora_weights(max_loras, lora_config)
        # These layers are not packed, so exactly one slice is expected.
        assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
            lora_linear.lora_b_stacked) == 1)
        if bias_enabled:
            assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
        else:
            assert lora_linear.lora_bias_stacked is None
        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_parallel_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        # NOTE(review): 512 is presumably the vocab size argument of
        # update_metadata - confirm against its signature.
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        # Reference: base layer output plus the explicit low-rank update of
        # the LoRA assigned to each input.
        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)

        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       512, lora_config.lora_extra_vocab_size)

        # With every slot reset, the wrapper must reproduce the base layer.
        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)
856
857
858
859


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
                                device, stage, bias_enabled) -> None:
    """Compare packed column-parallel LoRA layers with a reference.

    ``repeats`` selects the layer under test: 2 uses a merged column-parallel
    layer, 3 a merged QKV layer, and any other value a plain QKV layer. For
    every random seed the wrapped layer's output is compared with the base
    output plus each sub-LoRA's ``input @ lora_a @ lora_b * scaling`` added to
    its own output column range. Afterwards every LoRA slot is reset and the
    wrapped layer must match the base layer again.
    """

    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             fully_sharded_loras=fully_shard,
                             lora_dtype=torch.float16,
                             bias_enabled=bias_enabled)

    def create_column_parallel_packed_layer():
        """Return ``(linear, lora_linear)`` for the requested ``repeats``."""
        if repeats == 2:
            linear = MergedColumnParallelLinear(4096, [4096] * repeats,
                                                bias=False,
                                                params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (MergedColumnParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           MergedColumnParallelLinearWithShardedLoRA(linear))
        elif repeats == 3:
            linear = QKVParallelLinear(4096,
                                       64,
                                       32,
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (MergedQKVParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           MergedQKVParallelLinearWithShardedLoRA(linear))
        else:
            linear = QKVParallelLinear(4096,
                                       64,
                                       32,
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = QKVParallelLinearWithLoRA(
                linear
            ) if not fully_shard else QKVParallelLinearWithShardedLoRA(linear)

        # Minimal stand-in passed as ``model_config`` to create_lora_weights.
        @dataclass
        class FakeConfig:
            hidden_size = 4096
            num_key_value_heads = 32
            num_attention_heads = 32

        n_slices = repeats
        lora_linear.create_lora_weights(max_loras,
                                        lora_config,
                                        model_config=FakeConfig())
        assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
            lora_linear.lora_b_stacked) == n_slices)
        if bias_enabled:
            assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
        else:
            assert lora_linear.lora_bias_stacked is None
        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)

        linear, lora_linear = create_column_parallel_packed_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, sublora_dict = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
            repeats=repeats,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)

        # NOTE(review): 512 is presumably the vocab size argument of
        # update_metadata - confirm against its signature.
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        # Reference: each sub-LoRA contributes only to its own column range
        # of the base output.
        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            result = linear(input_)[0]
            subloras = sublora_dict[lora_id]
            # A distinct index name avoids shadowing the seed loop's ``i``.
            for slice_idx, sublora in enumerate(subloras):
                width = sublora.lora_b.shape[1]
                result[:, width * slice_idx:width * (slice_idx + 1)] += (
                    input_ @ sublora.lora_a @ sublora.lora_b *
                    sublora.scaling)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Reset every slot; the wrapper must then reproduce the base layer.
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 8])
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("scaling_factors", [(1.0, ), (4.0, ), (4.0, 8.0),
                                             (6.0, 1.0)])
@pytest.mark.parametrize("max_position", [11, 4096, 32768])
@pytest.mark.parametrize("is_neox_style", [True, False])
@pytest.mark.parametrize("rotary_dim", [None, 32])
@pytest.mark.parametrize("head_size", [32, 108])
@pytest.mark.parametrize("seq_len", [11, 1024])
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
                    reason="Only CUDA backends are supported")
def test_rotary_embedding_long_context(dist_init, num_loras, device,
                                       scaling_factors, max_position,
                                       is_neox_style, rotary_dim, head_size,
                                       seq_len) -> None:
    """Compare the LoRA rotary embedding with a linear-scaling rope.

    A base rope is wrapped in ``LinearScalingRotaryEmbeddingWithLoRA`` and a
    reference rope is built with ``rope_type="linear"`` and the same scaling
    factors. The test checks the wrapper's per-factor offsets, then applies
    both ropes to the same random query/key tensors and requires the outputs
    to be close.
    """
    dtype = torch.float16
    max_loras = 8
    seed = 0
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             long_lora_scaling_factors=scaling_factors,
                             lora_dtype=dtype)

    if rotary_dim is None:
        rotary_dim = head_size
    base = 10000
    batch_size = 5 * num_loras
    num_heads = 7

    # Verify lora is equivalent to linear scaling rotary embedding.
    rope = get_rope(
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
    )
    lora_rope = LinearScalingRotaryEmbeddingWithLoRA(rope)
    lora_rope.set_mapping(punica_wrapper)
    lora_rope.create_lora_weights(max_loras, lora_config)
    linear_rope = get_rope(head_size, rotary_dim, max_position, base,
                           is_neox_style, {
                               "rope_type": "linear",
                               "factor": scaling_factors
                           })
    linear_rope = linear_rope.to(dtype=dtype)
    id_to_index = get_random_id_to_index(num_loras, max_loras)
    # Only the mappings are needed; the generated inputs are discarded.
    _, index_mapping, prompt_mapping = create_random_inputs(
        active_lora_ids=[0],
        num_inputs=batch_size,
        input_size=(1, max_position),
        input_range=(0, lora_config.lora_extra_vocab_size),
        input_type=torch.float16,
        device=device)

    lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
    long_lora_context = LongContextLoRAContext(list(scaling_factors),
                                               rotary_dim)

    next_expected_offset = 0
    # Make sure the offset is correct.
    scaling_factor_to_offset = lora_rope.scaling_factor_to_offset
    for scaling_factor, offset in scaling_factor_to_offset.items():
        assert offset == next_expected_offset
        next_expected_offset += scaling_factor * max_position

    # LoRA id ``i`` gets the offset of the i-th scaling factor (0 if absent).
    for i, factor in enumerate(scaling_factors):
        long_lora_context.offsets_by_lora_id[i] = scaling_factor_to_offset.get(
            factor, 0)
    # NOTE(review): 512 is presumably the vocab size argument of
    # update_metadata - confirm against its signature.
    punica_wrapper.update_metadata(
        lora_mapping,
        id_to_index,
        max_loras,
        512,
        lora_config.lora_extra_vocab_size,
        long_lora_context=long_lora_context,
    )

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)
    ref_q, ref_k = linear_rope(positions, query, key)
    actual_q, actual_k = lora_rope(positions, query, key)

    # The comparison results used to be discarded, so nothing was verified;
    # assert them so a mismatch actually fails the test.
    torch.testing.assert_close(actual_q, ref_q)
    torch.testing.assert_close(actual_k, ref_k)
1109
1110
1111


@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize(
    "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS)))
def test_vocab_parallel_embedding_indices(tp_size, seed):
    """Check the shard index ranges of ``VocabParallelEmbedding``.

    A random vocabulary split (original + added tokens) is drawn from
    ``seed``. For each tensor-parallel rank an embedding is built with the
    rank and world size patched, and its ``shard_indices`` are checked to be
    contiguous with the previous rank's, non-overlapping, and to sum to the
    expected sizes. Finally the sharded-to-full mapping must reorder the
    concatenated per-rank token ids into ``0..vocab_size-1`` followed by
    padding (-1).
    """
    random.seed(seed)
    vocab_size = random.randint(4000, 64000)
    added_vocab_size = random.randint(0, 1024)
    org_vocab_size = vocab_size - added_vocab_size
    last_org_vocab_end_index = 0
    last_added_vocab_end_index = org_vocab_size
    computed_vocab_size = 0
    computed_org_vocab_size = 0
    computed_added_vocab_size = 0
    vocab_size_padded = -1

    all_org_tokens: list[int] = []
    all_added_tokens: list[int] = []
    token_ids: list[int] = []

    for tp_rank in range(tp_size):
        # Pretend to be rank ``tp_rank`` of ``tp_size`` while constructing.
        with patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank",
                return_value=tp_rank
        ), patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size",
                return_value=tp_size):
            vocab_embedding = VocabParallelEmbedding(
                vocab_size, 1, org_num_embeddings=org_vocab_size)
        vocab_size_padded = vocab_embedding.num_embeddings_padded
        shard_indices = vocab_embedding.shard_indices
        # Assert that the ranges are contiguous
        assert shard_indices.org_vocab_start_index == last_org_vocab_end_index
        assert (shard_indices.added_vocab_start_index ==
                last_added_vocab_end_index)

        # Ensure that we are not exceeding the vocab size
        computed_vocab_size += shard_indices.num_elements_padded
        computed_org_vocab_size += shard_indices.num_org_elements
        computed_added_vocab_size += shard_indices.num_added_elements

        # Ensure that the ranges are not overlapping
        all_org_tokens.extend(
            range(shard_indices.org_vocab_start_index,
                  shard_indices.org_vocab_end_index))
        all_added_tokens.extend(
            range(shard_indices.added_vocab_start_index,
                  shard_indices.added_vocab_end_index))

        # Record this shard's layout: original ids, padding (-1), added ids,
        # padding (-1).
        token_ids.extend(
            range(shard_indices.org_vocab_start_index,
                  shard_indices.org_vocab_end_index))
        token_ids.extend([-1] * (shard_indices.num_org_elements_padded -
                                 shard_indices.num_org_elements))
        token_ids.extend(
            range(shard_indices.added_vocab_start_index,
                  shard_indices.added_vocab_end_index))
        token_ids.extend([-1] * (shard_indices.num_added_elements_padded -
                                 shard_indices.num_added_elements))

        last_org_vocab_end_index = shard_indices.org_vocab_end_index
        last_added_vocab_end_index = shard_indices.added_vocab_end_index

    assert computed_vocab_size == vocab_size_padded
    assert computed_org_vocab_size == org_vocab_size
    assert computed_added_vocab_size == added_vocab_size

    # Ensure that the ranges are not overlapping
    assert len(all_org_tokens) == len(set(all_org_tokens))
    assert len(all_added_tokens) == len(set(all_added_tokens))
    assert not set(all_org_tokens).intersection(set(all_added_tokens))

    token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
    # Uses the embedding built for the last rank in the loop above.
    reindex_mapping = vocab_embedding.get_sharded_to_full_mapping()
    assert reindex_mapping is not None or tp_size == 1
    if reindex_mapping is not None:
        reindexed_token_ids = token_ids_tensor[reindex_mapping]
        expected = torch.tensor(list(range(0, vocab_size)))
        assert reindexed_token_ids[:vocab_size].equal(expected)
        assert torch.all(reindexed_token_ids[vocab_size:] == -1)


def test_get_masked_input_and_mask():
    """Table-driven check of ``get_masked_input_and_mask``'s remapped ids.

    Token ids 0..11 are passed through the function for tensor-parallel
    layouts of 1, 2 and 4 shards, first without and then with two padding
    slots after the original vocabulary. Only the first returned value (the
    remapped input) is compared; the mask is ignored.
    """
    token_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

    def remap(org_range, added_range, padding):
        # Returns only the remapped ids for one shard.
        remapped, _ = get_masked_input_and_mask(
            token_ids,
            org_vocab_start_index=org_range[0],
            org_vocab_end_index=org_range[1],
            added_vocab_start_index=added_range[0],
            added_vocab_end_index=added_range[1],
            num_org_vocab_padding=padding)
        return remapped

    # Each scenario: (padding, [(org range, added range, expected ids), ...])
    # with one entry per tensor-parallel rank.
    scenarios = [
        # base tp 1 case, no padding
        (0, [
            ((0, 8), (8, 12), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]),
        ]),
        # tp 2 case, no padding
        (0, [
            ((0, 4), (8, 10), [0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]),
            ((4, 8), (10, 12), [0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]),
        ]),
        # tp 4 case, no padding
        (0, [
            ((0, 2), (8, 9), [0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]),
            ((2, 4), (9, 10), [0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]),
            ((4, 6), (10, 11), [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]),
            ((6, 8), (11, 12), [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]),
        ]),
        # base tp 1 case, with padding
        (2, [
            ((0, 8), (8, 12), [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]),
        ]),
        # tp 2 case, with padding
        (2, [
            ((0, 4), (8, 10), [0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]),
            ((4, 8), (10, 12), [0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]),
        ]),
        # tp 4 case, with padding
        (2, [
            ((0, 2), (8, 9), [0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]),
            ((2, 4), (9, 10), [0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]),
            ((4, 6), (10, 11), [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]),
            ((6, 8), (11, 12), [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]),
        ]),
    ]

    for padding, shards in scenarios:
        # Compute every rank's output first, then verify them in rank order.
        outputs = [
            remap(org_range, added_range, padding)
            for org_range, added_range, _ in shards
        ]
        for output, (_, _, expected_ids) in zip(outputs, shards):
            assert torch.equal(output, torch.tensor(expected_ids))