test_layers.py 53.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import importlib
4
5
6
import random
from copy import deepcopy
from dataclasses import dataclass
7
from typing import Optional
8
from unittest.mock import patch
9

10
import pytest
11
12
13
14
import torch
import torch.nn.functional as F

from vllm.config import LoRAConfig
15
16
17
from vllm.lora.fully_sharded_layers import (
    ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
18
    MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA,
19
    RowParallelLinearWithShardedLoRA)
20
21
# yapf conflicts with isort for this block
# yapf: disable
22
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
23
                              LinearScalingRotaryEmbeddingWithLoRA,
24
25
                              LogitsProcessorWithLoRA, LoRAMapping,
                              MergedColumnParallelLinearWithLoRA,
26
27
                              MergedQKVParallelLinearWithLoRA,
                              QKVParallelLinearWithLoRA,
28
                              ReplicatedLinearWithLoRA,
29
30
                              RowParallelLinearWithLoRA,
                              VocabParallelEmbeddingWithLoRA)
31
# yapf: enable
32
from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights,
33
                              PackedLoRALayerWeights)
34
from vllm.lora.punica_wrapper import get_punica_wrapper
35
36
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
37
                                               QKVParallelLinear,
38
                                               ReplicatedLinear,
39
40
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
41
from vllm.model_executor.layers.rotary_embedding import get_rope
42
from vllm.model_executor.layers.vocab_parallel_embedding import (
43
    ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
44
from vllm.model_executor.utils import set_random_seed
45
from vllm.platforms import current_platform
46
47
48
49
50
51
52
53

from .utils import DummyLoRAManager

# (rtol, atol) per result dtype, passed to torch.testing.assert_close.
TOLERANCES = {
    torch.float16: (5e-3, 5e-3),
    torch.float32: (5e-3, 5e-3),
    torch.bfloat16: (3e-2, 2e-2),
}

# Skip the whole module on platforms other than CUDA-alike and CPU.
pytestmark = pytest.mark.skipif(
    not (current_platform.is_cuda_alike() or current_platform.is_cpu()),
    reason="Backend not supported")

if current_platform.is_cuda_alike():
    # Only cuda:0 when exactly one GPU is visible; otherwise cuda:0 and
    # cuda:1.
    DEVICES = [
        f"cuda:{i}"
        for i in range(1 if torch.cuda.device_count() == 1 else 2)
    ]
else:
    DEVICES = ["cpu"]

# On GPU, the prefill and decode stages launch different Triton kernels, so
# both are verified: True selects the prefill stage, False the decode stage.
STAGES = [True, False]

# Every test in this file runs twice -- once on the V0 engine and once on
# the V1 engine (see run_with_both_engines_lora). To keep CI time flat,
# NUM_RANDOM_SEEDS was halved from its former value of 10.
NUM_RANDOM_SEEDS = 5
# Likewise halved from its former value of 256 to keep CI time flat.
VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
    """Autouse wrapper so every test in this module runs on both engines.

    The engine switch itself comes from the ``run_with_both_engines_lora``
    fixture; this wrapper could be promoted to a conftest.py to cover every
    test in a package.
    """
    # The kernels used by punica_gpu are tied to the engine type, so the
    # module is re-imported before each test.
    from vllm.lora.punica_wrapper import punica_gpu
    importlib.reload(punica_gpu)

    # Drop the cached LoRA pointer tables; CI runs OOM otherwise.
    from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT,
                                                _LORA_B_PTR_DICT)
    for ptr_table in (_LORA_B_PTR_DICT, _LORA_A_PTR_DICT):
        ptr_table.clear()

    yield


def get_random_id_to_index(num_loras: int,
                           num_slots: int,
                           log: bool = True) -> list[Optional[int]]:
    """Build a randomly shuffled lora_id_to_index slot table.

    LoRA ids ``1..num_loras`` are placed in distinct slots drawn with
    ``torch.randperm``; every other slot is left as ``None`` (free).

    Args:
        num_loras: The number of active loras to place.
        num_slots: The number of slots in the table; must be greater than
            or equal to ``num_loras``.
        log: Whether to print the resulting table.

    Returns:
        A list of length ``num_slots`` holding a lora id or ``None``.

    Raises:
        ValueError: If ``num_loras`` exceeds ``num_slots``.
    """
    if num_slots < num_loras:
        raise ValueError(
            f"num_loras is higher than num_slots: {num_loras} > {num_slots}. "
            "num_loras must be less than or equal to num_slots.")

    # randperm yields distinct slots, so no two ids can collide.
    chosen_slots = torch.randperm(num_slots)[:num_loras].tolist()
    owner_of_slot = {
        slot: lora_id
        for lora_id, slot in enumerate(chosen_slots, start=1)
    }
    table: list[Optional[int]] = [
        owner_of_slot.get(slot) for slot in range(num_slots)
    ]

    if log:
        print(f"Created lora_id_to_index mapping: {table}.")

    return table


def populate_loras(
    id_to_index: list[Optional[int]],
    layer: BaseLayerWithLoRA,
    layer_weights: torch.Tensor,
    generate_embeddings_tensor: int = 0,
    repeats: int = 1,
) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
    """Load random LoRA weights into every occupied slot of a LoRA layer.

    Args:
        id_to_index: a list of lora ids; the position of each id is the
            memory slot its matrices are stored in. ``None`` marks a free
            slot, which is left untouched.
        layer: the LoRA layer to populate.
        layer_weights: the PyTorch tensor holding the layer's weights; its
            shape and device are used to generate the random LoRAs.
        generate_embeddings_tensor: forwarded unchanged to
            ``DummyLoRAManager.init_random_lora``. Callers in this file pass
            0 or an embedding width (e.g. 256, 1024) -- presumably 0
            disables embeddings-tensor generation; confirm against the
            helper.
        repeats: only for column-parallel packed layers; the number of
            sub-LoRAs packed together into one LoRA. Each sub-LoRA keeps a
            disjoint ``layer_weights.shape[0] // repeats`` column window of
            its ``lora_b``.

    Returns:
        ``(lora_dict, sublora_dict)``: lora id -> the LoRA loaded into the
        layer, and lora id -> the list of sub-LoRAs it was built from.
    """
    lora_dict: dict[int, LoRALayerWeights] = {}
    sublora_dict: dict[int, list[LoRALayerWeights]] = {}

    for slot_idx, lora_id in enumerate(id_to_index):
        if lora_id is None:
            continue

        window = layer_weights.shape[0] // repeats
        parts: list[LoRALayerWeights] = []
        for part_idx in range(repeats):
            part = DummyLoRAManager(layer_weights.device).init_random_lora(
                module_name=f"fake_{part_idx}",
                weight=layer_weights,
                generate_embeddings_tensor=generate_embeddings_tensor,
            )
            start = window * part_idx
            part.lora_b = part.lora_b[:, start:start + window]
            part.optimize()
            parts.append(part)

        if repeats > 1:
            packed = PackedLoRALayerWeights.pack(parts)
        else:
            packed = parts[0]

        layer.set_lora(
            slot_idx,
            lora_a=packed.lora_a,
            lora_b=packed.lora_b,
            embeddings_tensor=packed.embeddings_tensor,
        )

        lora_dict[lora_id] = packed
        sublora_dict[lora_id] = parts

    return lora_dict, sublora_dict


def create_random_inputs(
    active_lora_ids: list[int],
    num_inputs: int,
    input_size: tuple[int, ...],
    input_range: tuple[float, float],
    input_type: torch.dtype = torch.int,
    device: torch.device = "cuda"
) -> tuple[list[torch.Tensor], list[int], list[int]]:
    """Creates random inputs together with their LoRA mappings.

    Args:
        active_lora_ids: lora IDs of active lora weights; each input is
            assigned one of them via ``random.choice``.
        num_inputs: the number of inputs to create.
        input_size: the shape of each individual input. ``input_size[0]``
            is the number of tokens recorded in ``index_mapping`` for it.
        input_range: the range of values to include in the input.
            input_range[0] <= possible input values < input_range[1]
        input_type: the dtype of the values. ``torch.int`` draws integers
            with ``torch.randint``; any other dtype draws uniform floats.
        device: the device to allocate the inputs on (a ``torch.device`` or
            a device string such as ``"cpu"``).

    Returns:
        ``(inputs, index_mapping, prompt_mapping)``: the list of input
        tensors, the per-token lora id (``input_size[0]`` entries per input)
        and the per-input lora id.
    """
    low, high = input_range

    inputs: list[torch.Tensor] = []
    index_mapping: list[int] = []
    prompt_mapping: list[int] = []

    for _ in range(num_inputs):
        if input_type == torch.int:
            inputs.append(
                torch.randint(low=int(low),
                              high=int(high),
                              size=input_size,
                              device=device))
        else:
            # Scale by the width of the range (not by ``high``) so that the
            # values land in [low, high) as documented above.
            inputs.append(
                torch.rand(size=input_size, dtype=input_type, device=device) *
                (high - low) + low)

        lora_id = random.choice(active_lora_ids)
        index_mapping += [lora_id] * input_size[0]
        prompt_mapping += [lora_id]

    return inputs, index_mapping, prompt_mapping


def check_punica_wrapper(punica_wrapper) -> bool:
    """Return True iff ``punica_wrapper`` is exactly this platform's class.

    CUDA-alike platforms expect ``PunicaWrapperGPU``, CPU expects
    ``PunicaWrapperCPU``; on any other platform the result is False. The
    check is an exact type match, so subclasses do not qualify.
    """
    if current_platform.is_cuda_alike():
        from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
        expected_cls = PunicaWrapperGPU
    elif current_platform.is_cpu():
        from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU
        expected_cls = PunicaWrapperCPU
    else:
        return False
    return type(punica_wrapper) is expected_cls


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
    """Compare VocabParallelEmbeddingWithLoRA with a reference computation.

    With LoRAs loaded, each prompt's output must equal the base embedding
    plus ``F.embedding(ids, lora_a) @ lora_b`` for that prompt's LoRA. After
    every slot is reset, the output must equal the base embedding.
    """
    # For multi-GPU testing of Triton kernels the CUDA device has to be set
    # explicitly, see: https://github.com/triton-lang/triton/issues/2925
    # The other tests in this file do the same.
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)

    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def build_layers():
        # Random base embedding; rows at or past vocab_size are zeroed.
        base = VocabParallelEmbedding(vocab_size, 256)
        base.weight.data = torch.rand_like(base.weight.data)
        base.weight.data[vocab_size:, :] = 0

        wrapped = VocabParallelEmbeddingWithLoRA(base)
        wrapped.create_lora_weights(max_loras, lora_config)
        return base, wrapped

    def sample_and_register(active_ids, id_to_index):
        # Draw random token batches and push the matching LoRA mapping into
        # the punica wrapper.
        batches, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=active_ids,
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        mapping = LoRAMapping(index_mapping,
                              prompt_mapping,
                              is_prefill=stage)
        punica_wrapper.update_metadata(mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)
        return batches, prompt_mapping

    def check_close(actual, expected):
        rtol, atol = TOLERANCES[actual.dtype]
        torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)

    for seed in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        embedding, lora_embedding = build_layers()
        lora_embedding.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=embedding.weight.T,
        )

        batches, prompt_mapping = sample_and_register(
            list(lora_dict.keys()), id_to_index)
        lora_result = lora_embedding(torch.cat(batches))

        reference_parts: list[torch.Tensor] = []
        for tokens, lora_id in zip(batches, prompt_mapping):
            weights = lora_dict[lora_id]
            part = embedding(tokens)
            part += F.embedding(tokens, weights.lora_a) @ weights.lora_b
            reference_parts.append(part)
        check_close(lora_result, torch.cat(reference_parts))

        # Once every slot is reset, the LoRA layer must match the base layer.
        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        batches, _ = sample_and_register([0], id_to_index)
        check_close(lora_embedding(torch.cat(batches)),
                    embedding(torch.cat(batches)))


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
                                        vocab_size, stage) -> None:
    """Check LoRA embeddings whose LoRAs also add extra-vocabulary rows.

    The reference is an expanded embedding that stacks every LoRA's extra
    rows after ``vocab_size``. The LoRA layer receives ids that are local to
    each LoRA's extra range; the reference receives the matching global ids.
    After every slot is reset, both must agree on in-vocabulary inputs.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)

    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)
    extra_vocab = lora_config.lora_extra_vocab_size

    def build_layers():
        base = VocabParallelEmbedding(vocab_size, 256)
        base_weights = torch.rand_like(base.weight.data)
        base.weight.data = base_weights
        base.weight.data[vocab_size:, :] = 0

        expanded = VocabParallelEmbedding(vocab_size + extra_vocab * max_loras,
                                          256,
                                          org_num_embeddings=vocab_size)
        expanded.weight.data[:vocab_size, :] = base_weights
        # The LoRA layer modifies its base embedding in place, so it wraps a
        # deep copy and the reference stays untouched.
        wrapped = VocabParallelEmbeddingWithLoRA(deepcopy(expanded))
        wrapped.create_lora_weights(max_loras, lora_config)
        return expanded, wrapped

    def sample_and_register(active_ids, id_to_index):
        # Draw random token batches and push the matching LoRA mapping into
        # the punica wrapper.
        batches, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=active_ids,
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        punica_wrapper.update_metadata(
            LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage),
            id_to_index, max_loras, vocab_size, extra_vocab)
        return batches, prompt_mapping

    def check_close(actual, expected):
        rtol, atol = TOLERANCES[actual.dtype]
        torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)

    for seed in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        expanded_embedding, lora_embedding = build_layers()
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=torch.zeros((256, vocab_size + extra_vocab)),
            generate_embeddings_tensor=256,
        )
        lora_embedding.set_mapping(punica_wrapper)

        # One extra-row table per LoRA, ordered by lora id; all share a shape.
        extra_tables = [
            lora_dict[lora_id].embeddings_tensor
            for lora_id in sorted(lora_dict)
        ]
        table_len = extra_tables[0].shape[0]
        table_shape = extra_tables[0].shape
        # Unoccupied slots get all-zero tables.
        missing = max_loras - len(extra_tables)
        extra_tables += [torch.zeros(table_shape) for _ in range(missing)]

        global_ids, prompt_mapping = sample_and_register(
            list(lora_dict.keys()), id_to_index)
        local_ids = deepcopy(global_ids)

        # Force the last two tokens of every prompt into the extra-vocab
        # range so that path is always exercised.
        for g_ids, l_ids, lora_id in zip(global_ids, local_ids,
                                         prompt_mapping):
            first_row = vocab_size + (lora_id - 1) * table_len
            g_ids[-1] = first_row
            l_ids[-1] = vocab_size
            g_ids[-2] = first_row + table_len - 1
            l_ids[-2] = vocab_size + table_len - 1

        expanded_embedding.weight[vocab_size:vocab_size +
                                  (table_len *
                                   max_loras)] = torch.cat(extra_tables)

        lora_result = lora_embedding(torch.cat(local_ids))

        reference_parts: list[torch.Tensor] = []
        for g_ids, l_ids, lora_id in zip(global_ids, local_ids,
                                         prompt_mapping):
            weights = lora_dict[lora_id]
            part = expanded_embedding(g_ids)
            part += F.embedding(l_ids, weights.lora_a) @ weights.lora_b
            reference_parts.append(part)
        check_close(lora_result, torch.cat(reference_parts))

        # Once every slot is reset, the LoRA layer must match the reference.
        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        global_ids, _ = sample_and_register([0], id_to_index)
        local_ids = deepcopy(global_ids)
        check_close(lora_embedding(torch.cat(local_ids)),
                    expanded_embedding(torch.cat(global_ids)))


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])
@pytest.mark.parametrize("stage", STAGES)
def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
                                  stage) -> None:
    """Compare LogitsProcessorWithLoRA logits with a reference computation.

    With LoRAs loaded, each prompt's logits must equal the base logits with
    that prompt's LoRA extra-vocab embeddings spliced in right after
    ``vocab_size``, the remaining extra-vocab columns set to ``-inf``, plus
    the LoRA A/B delta. After every slot is reset, the in-vocabulary logits
    must equal the plain LogitsProcessor output.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def _pretest():
        linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
                                1024,
                                vocab_size,
                                params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        # Zero the extra-vocabulary rows. Dim 0 is the vocab axis: the
        # splice below indexes linear.weight[org_vocab_size:...] by row.
        linear.weight.data[vocab_size:, :] = 0
        logits_processor = LogitsProcessor(
            vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
        lora_logits_processor = LogitsProcessorWithLoRA(
            logits_processor, 1024, linear.weight.dtype, linear.weight.device,
            None)
        lora_logits_processor.create_lora_weights(max_loras, lora_config)

        return linear, logits_processor, lora_logits_processor

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, logits_processor, lora_logits_processor = _pretest()
        lora_logits_processor.set_mapping(punica_wrapper)
        # populate_loras generates every LoRA independently, so each may
        # carry its own embeddings tensor; the reference below uses each
        # prompt's own tensor.
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_logits_processor,
            layer_weights=linear.weight,
            generate_embeddings_tensor=1024,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=8 * num_loras,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=linear,
            embedding_bias=None)

        # Untouched copy of the LM head for the reset check below.
        original_lm_head = deepcopy(linear)

        # Reference: temporarily widen the processor so the extra-vocab
        # columns are kept, then splice each prompt's LoRA embeddings into
        # the LM head right after the original vocabulary.
        org_vocab_size = logits_processor.org_vocab_size
        logits_processor.org_vocab_size = (vocab_size +
                                           lora_config.lora_extra_vocab_size)
        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            embeddings_tensor = lora.embeddings_tensor
            embeddings_tensor_len = embeddings_tensor.shape[0]
            linear.weight[org_vocab_size:org_vocab_size +
                          embeddings_tensor_len] = embeddings_tensor
            result = logits_processor._get_logits(hidden_states=input_,
                                                  lm_head=linear,
                                                  embedding_bias=None)
            # Extra-vocab columns this LoRA does not define are expected to
            # be -inf.
            result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)
        logits_processor.org_vocab_size = org_vocab_size

        # Previously the reference above was computed but never compared.
        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_logits_processor.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=8 * num_loras * 3,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None)[:, :vocab_size]
        expected_result = logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_linear_replicated(dist_init, num_loras, device, stage,
                           bias_enabled) -> None:
    """Compare ReplicatedLinearWithLoRA with a reference computation.

    With LoRAs loaded, each row must equal ``linear(x)[0]`` plus
    ``x @ lora_a @ lora_b * scaling`` for that row's LoRA. After every slot
    is reset, the output must equal the base layer's output.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16,
                             bias_enabled=bias_enabled)

    def build_layers():
        base = ReplicatedLinear(4096,
                                4096,
                                bias=False,
                                params_dtype=torch.float16)
        base.weight.data = torch.rand_like(base.weight.data)
        wrapped = ReplicatedLinearWithLoRA(base)

        wrapped.create_lora_weights(max_loras, lora_config)

        # A replicated layer holds exactly one slice.
        assert wrapped.n_slices == 1
        assert len(wrapped.lora_a_stacked) == 1
        assert len(wrapped.lora_b_stacked) == 1
        if bias_enabled:
            assert len(wrapped.lora_bias_stacked) == wrapped.n_slices
        else:
            assert wrapped.lora_bias_stacked is None
        return base, wrapped

    def sample_and_register(active_ids, id_to_index):
        # Draw random fp16 rows and push the matching LoRA mapping into the
        # punica wrapper.
        batches, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=active_ids,
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        punica_wrapper.update_metadata(
            LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage),
            id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size)
        return batches, prompt_mapping

    def check_close(actual, expected):
        rtol, atol = TOLERANCES[actual.dtype]
        torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)

    for seed in range(NUM_RANDOM_SEEDS):
        set_random_seed(seed)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = build_layers()
        assert torch.equal(linear.weight, lora_linear.weight)

        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        batches, prompt_mapping = sample_and_register(
            list(lora_dict.keys()), id_to_index)
        lora_result = lora_linear(torch.cat(batches))[0]

        reference_parts: list[torch.Tensor] = []
        for row, lora_id in zip(batches, prompt_mapping):
            weights = lora_dict[lora_id]
            part = linear(row)[0]
            part += row @ weights.lora_a @ weights.lora_b * weights.scaling
            reference_parts.append(part)
        check_close(lora_result, torch.cat(reference_parts))

        # Once every slot is reset, the LoRA layer must match the base layer.
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        batches, _ = sample_and_register([0], id_to_index)
        check_close(lora_linear(torch.cat(batches))[0],
                    linear(torch.cat(batches))[0])


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"])
740
@pytest.mark.parametrize("fully_shard", [True, False])
741
@pytest.mark.parametrize("device", DEVICES)
742
@pytest.mark.parametrize("stage", STAGES)
743
@pytest.mark.parametrize("bias_enabled", [True, False])
744
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
745
                         device, stage, bias_enabled) -> None:
746

747
748
749
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

750
    max_loras = 8
751
    torch.set_default_device(device)
752
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
753
    assert check_punica_wrapper(punica_wrapper)
754
755
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
756
                             fully_sharded_loras=fully_shard,
757
758
                             lora_dtype=torch.float16,
                             bias_enabled=bias_enabled)
759
760
761

    def create_random_linear_parallel_layer():
        if orientation == "row":
762
763
764
765
            linear = RowParallelLinear(4096,
                                       4096,
                                       bias=False,
                                       params_dtype=torch.float16)
766
            linear.weight.data = torch.rand_like(linear.weight.data)
767
768
            lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard
                           else RowParallelLinearWithShardedLoRA(linear))
769
        else:
770
771
772
773
            linear = ColumnParallelLinear(4096,
                                          4096,
                                          bias=False,
                                          params_dtype=torch.float16)
774
            linear.weight.data = torch.rand_like(linear.weight.data)
775
776
777
            lora_linear = (ColumnParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           ColumnParallelLinearWithShardedLoRA(linear))
778
        lora_linear.create_lora_weights(max_loras, lora_config)
779
780
781
782
783
784
        assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
            lora_linear.lora_b_stacked) == 1)
        if bias_enabled:
            assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
        else:
            assert lora_linear.lora_bias_stacked is None
785
786
        return linear, lora_linear

787
    for i in range(NUM_RANDOM_SEEDS):
788
789
790
791
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_parallel_layer()
792
        assert torch.equal(linear.weight, lora_linear.weight)
793
        lora_linear.set_mapping(punica_wrapper)
794
795
796
797
798
799
800
801
802
803
804
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
805
            input_type=torch.float16,
806
            device=device)
807
808
809
810
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
811
812
813
814
815
816
817
818
819
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

820
        expected_results: list[torch.Tensor] = []
821
822
823
824
825
826
827
828
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
829
830
831
832
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)
833
834
835
836
837
838
839
840
841
842
843

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
844
            input_type=torch.float16,
845
            device=device)
846
847
848
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
849

850
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
851
852
853
854
855
856
                                       512, lora_config.lora_extra_vocab_size)

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
857
858
859
860
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)
861
862
863
864


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
                                device, stage, bias_enabled) -> None:
    """Compare packed column-parallel LoRA layers against a manual reference.

    ``repeats`` selects the base layer and wrapper:

    * 2 -> ``MergedColumnParallelLinear`` with two 4096-wide outputs.
    * 3 -> ``QKVParallelLinear`` wrapped as a merged-QKV LoRA.
    * 1 -> ``QKVParallelLinear`` wrapped as a single-slice QKV LoRA.

    The wrapper must have ``repeats`` slices. In the reference, each
    sub-LoRA's delta is added to its own column range of the base output.
    After every slot is reset, the wrapper must match the base layer alone.
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             fully_sharded_loras=fully_shard,
                             lora_dtype=torch.float16,
                             bias_enabled=bias_enabled)

    def create_column_parallel_packed_layer():
        """Build a random packed base layer and its LoRA wrapper."""
        if repeats == 2:
            linear = MergedColumnParallelLinear(4096, [4096] * repeats,
                                                bias=False,
                                                params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (MergedColumnParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           MergedColumnParallelLinearWithShardedLoRA(linear))
        elif repeats == 3:
            linear = QKVParallelLinear(4096,
                                       64,
                                       32,
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (MergedQKVParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           MergedQKVParallelLinearWithShardedLoRA(linear))
        else:
            linear = QKVParallelLinear(4096,
                                       64,
                                       32,
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (QKVParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           QKVParallelLinearWithShardedLoRA(linear))

        # Minimal stand-in for the model config passed to
        # ``create_lora_weights``.
        @dataclass
        class FakeConfig:
            hidden_size = 4096
            num_key_value_heads = 32
            num_attention_heads = 32

        n_slices = repeats
        lora_linear.create_lora_weights(max_loras,
                                        lora_config,
                                        model_config=FakeConfig())
        assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
            lora_linear.lora_b_stacked) == n_slices)
        if bias_enabled:
            assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
        else:
            assert lora_linear.lora_bias_stacked is None

        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)

        linear, lora_linear = create_column_parallel_packed_layer()
        # The wrapper must expose the very same base weights.
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, sublora_dict = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
            repeats=repeats,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        # Reference: base output plus each sub-LoRA's delta, added to the
        # column range of that sub-LoRA's slice. The slice width is taken
        # from the sub-LoRA's ``lora_b``. ``slice_idx`` is distinct from the
        # seed loop's ``i`` to avoid shadowing it.
        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            result = linear(input_)[0]
            subloras = sublora_dict[lora_id]
            for slice_idx, sublora in enumerate(subloras):
                width = sublora.lora_b.shape[1]
                start = width * slice_idx
                result[:, start:start + width] += (
                    input_ @ sublora.lora_a @ sublora.lora_b *
                    sublora.scaling)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Reset every slot, then route all inputs to LoRA id 0. The wrapper
        # must reproduce the base layer's output.
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)

        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 8])
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("scaling_factors", [(1.0, ), (4.0, ), (4.0, 8.0),
                                             (6.0, 1.0)])
@pytest.mark.parametrize("max_position", [11, 4096, 32768])
@pytest.mark.parametrize("is_neox_style", [True, False])
@pytest.mark.parametrize("rotary_dim", [None, 32])
@pytest.mark.parametrize("head_size", [32, 108])
@pytest.mark.parametrize("seq_len", [11, 1024])
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
                    reason="Only CUDA backends are supported")
def test_rotary_embedding_long_context(dist_init, num_loras, device,
                                       scaling_factors, max_position,
                                       is_neox_style, rotary_dim, head_size,
                                       seq_len) -> None:
    """Check that the long-context LoRA rotary embedding matches linear RoPE.

    A ``LinearScalingRotaryEmbeddingWithLoRA`` is built around a plain RoPE
    and configured with ``scaling_factors``. Its per-factor cache offsets are
    checked to be laid out back to back (``scaling_factor * max_position``
    apart). Its outputs are then compared with a ``"linear"`` RoPE built
    directly from the same factors.
    """
    dtype = torch.float16
    max_loras = 8
    seed = 0
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             long_lora_scaling_factors=scaling_factors,
                             lora_dtype=dtype)

    if rotary_dim is None:
        rotary_dim = head_size
    base = 10000
    batch_size = 5 * num_loras
    num_heads = 7

    # Verify lora is equivalent to linear scaling rotary embedding.
    rope = get_rope(
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
    )
    lora_rope = LinearScalingRotaryEmbeddingWithLoRA(rope)
    lora_rope.set_mapping(punica_wrapper)
    lora_rope.create_lora_weights(max_loras, lora_config)
    linear_rope = get_rope(head_size, rotary_dim, max_position, base,
                           is_neox_style, {
                               "rope_type": "linear",
                               "factor": scaling_factors
                           })
    linear_rope = linear_rope.to(dtype=dtype)
    id_to_index = get_random_id_to_index(num_loras, max_loras)
    # NOTE(review): the mapping comes from batch_size inputs whose
    # input_size[0] is 1, while ``positions`` below holds
    # batch_size * seq_len tokens. Confirm that the per-token long-LoRA
    # offsets line up with the tokens as intended.
    _, index_mapping, prompt_mapping = create_random_inputs(
        active_lora_ids=[0],
        num_inputs=batch_size,
        input_size=(1, max_position),
        input_range=(0, lora_config.lora_extra_vocab_size),
        input_type=torch.float16,
        device=device)

    lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
    long_lora_context = LongContextLoRAContext(list(scaling_factors),
                                               rotary_dim)

    next_expected_offset = 0
    # Make sure the offset is correct.
    scaling_factor_to_offset = lora_rope.scaling_factor_to_offset
    for scaling_factor, offset in scaling_factor_to_offset.items():
        assert offset == next_expected_offset
        next_expected_offset += scaling_factor * max_position

    # LoRA id k uses the cache section of scaling_factors[k].
    for lora_id, scaling_factor in enumerate(scaling_factors):
        long_lora_context.offsets_by_lora_id[lora_id] = (
            scaling_factor_to_offset.get(scaling_factor, 0))
    punica_wrapper.update_metadata(
        lora_mapping,
        id_to_index,
        max_loras,
        512,
        lora_config.lora_extra_vocab_size,
        long_lora_context=long_lora_context,
    )

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)
    # Give each rope its own copy of the inputs. If a rotary implementation
    # rotates in place and returns its arguments, shared inputs would make
    # the two results alias the same tensors and the comparison meaningless.
    ref_q, ref_k = linear_rope(positions, query.clone(), key.clone())
    actual_q, actual_k = lora_rope(positions, query.clone(), key.clone())

    # A bare ``torch.allclose`` call discards its result and can never fail
    # the test, so assert explicitly with the file's dtype tolerances.
    rtol, atol = TOLERANCES[dtype]
    torch.testing.assert_close(actual_q, ref_q, rtol=rtol, atol=atol)
    torch.testing.assert_close(actual_k, ref_k, rtol=rtol, atol=atol)


@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize(
    "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS)))
def test_vocab_parallel_embedding_indices(tp_size, seed):
    """Check the per-rank shard layout of ``VocabParallelEmbedding``.

    An embedding with a random vocab size is built once per simulated TP
    rank, with the TP rank and world size patched. The test checks that:

    * the original and added vocab ranges are contiguous across ranks;
    * their per-rank sizes add up to the padded, original, and added totals;
    * no token id appears twice or in both ranges.

    Finally, ``get_sharded_to_full_mapping`` of the last embedding must map
    the concatenated shard layout back to ``0..vocab_size-1``, with every
    padding slot (-1) placed after the real ids.
    """
    random.seed(seed)
    vocab_size = random.randint(4000, 64000)
    added_vocab_size = random.randint(0, 1024)
    org_vocab_size = vocab_size - added_vocab_size
    last_org_vocab_end_index = 0
    # Added-vocab ids start right after the original vocab.
    last_added_vocab_end_index = org_vocab_size
    computed_vocab_size = 0
    computed_org_vocab_size = 0
    computed_added_vocab_size = 0
    vocab_size_padded = -1

    all_org_tokens: list[int] = []
    all_added_tokens: list[int] = []
    # Concatenated shard layout of all ranks, with -1 marking padding slots.
    token_ids: list[int] = []

    for tp_rank in range(tp_size):
        with patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank",
                return_value=tp_rank
        ), patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size",
                return_value=tp_size):
            vocab_embedding = VocabParallelEmbedding(
                vocab_size, 1, org_num_embeddings=org_vocab_size)
        vocab_size_padded = vocab_embedding.num_embeddings_padded
        shard_indices = vocab_embedding.shard_indices
        # Assert that the ranges are contiguous
        assert shard_indices.org_vocab_start_index == last_org_vocab_end_index
        assert (shard_indices.added_vocab_start_index ==
                last_added_vocab_end_index)

        # Accumulate per-rank sizes. The totals are checked after the loop.
        computed_vocab_size += shard_indices.num_elements_padded
        computed_org_vocab_size += shard_indices.num_org_elements
        computed_added_vocab_size += shard_indices.num_added_elements

        org_range = range(shard_indices.org_vocab_start_index,
                          shard_indices.org_vocab_end_index)
        added_range = range(shard_indices.added_vocab_start_index,
                            shard_indices.added_vocab_end_index)

        # Collected for the overlap checks after the loop.
        all_org_tokens.extend(org_range)
        all_added_tokens.extend(added_range)

        # This rank's slice of the layout: original ids, original padding,
        # added ids, added padding.
        token_ids.extend(org_range)
        token_ids.extend([-1] * (shard_indices.num_org_elements_padded -
                                 shard_indices.num_org_elements))
        token_ids.extend(added_range)
        token_ids.extend([-1] * (shard_indices.num_added_elements_padded -
                                 shard_indices.num_added_elements))

        last_org_vocab_end_index = shard_indices.org_vocab_end_index
        last_added_vocab_end_index = shard_indices.added_vocab_end_index

    assert computed_vocab_size == vocab_size_padded
    assert computed_org_vocab_size == org_vocab_size
    assert computed_added_vocab_size == added_vocab_size

    # Ensure that the ranges are not overlapping
    assert len(all_org_tokens) == len(set(all_org_tokens))
    assert len(all_added_tokens) == len(set(all_added_tokens))
    assert not set(all_org_tokens).intersection(set(all_added_tokens))

    token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
    reindex_mapping = vocab_embedding.get_sharded_to_full_mapping()
    assert reindex_mapping is not None or tp_size == 1
    if reindex_mapping is not None:
        reindexed_token_ids = token_ids_tensor[reindex_mapping]
        expected = torch.arange(vocab_size)
        assert reindexed_token_ids[:vocab_size].equal(expected)
        assert torch.all(reindexed_token_ids[vocab_size:] == -1)


def test_get_masked_input_and_mask():
    """Verify the shard-local remapping done by ``get_masked_input_and_mask``.

    The input is the fixed 12-token tensor ``0..11``: ids 0-7 form the
    original vocab and ids 8-11 the added vocab in the tp 1 cases. Each case
    gives one rank's shard boundaries for tp 1, 2 or 4, with
    ``num_org_vocab_padding`` of 0 or 2. It also gives the exact remapped
    tensor that must come back. Only the remapped ids are checked; the
    returned mask is ignored.
    """
    x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

    # Each row: (org_vocab_start_index, org_vocab_end_index,
    #            added_vocab_start_index, added_vocab_end_index,
    #            num_org_vocab_padding, expected remapped ids)
    cases = [
        # base tp 1 case, no padding: the input comes back unchanged
        (0, 8, 8, 12, 0, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]),
        # tp 2 case, no padding (rank 0, rank 1)
        (0, 4, 8, 10, 0, [0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]),
        (4, 8, 10, 12, 0, [0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]),
        # tp 4 case, no padding (ranks 0-3)
        (0, 2, 8, 9, 0, [0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]),
        (2, 4, 9, 10, 0, [0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]),
        (4, 6, 10, 11, 0, [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]),
        (6, 8, 11, 12, 0, [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]),
        # base tp 1 case, with padding
        (0, 8, 8, 12, 2, [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]),
        # tp 2 case, with padding (rank 0, rank 1)
        (0, 4, 8, 10, 2, [0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]),
        (4, 8, 10, 12, 2, [0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]),
        # tp 4 case, with padding (ranks 0-3)
        (0, 2, 8, 9, 2, [0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]),
        (2, 4, 9, 10, 2, [0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]),
        (4, 6, 10, 11, 2, [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]),
        (6, 8, 11, 12, 2, [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]),
    ]

    for (org_start, org_end, added_start, added_end, padding,
         expected) in cases:
        remapped, _ = get_masked_input_and_mask(
            x,
            org_vocab_start_index=org_start,
            org_vocab_end_index=org_end,
            added_vocab_start_index=added_start,
            added_vocab_end_index=added_end,
            num_org_vocab_padding=padding)
        assert torch.equal(remapped, torch.tensor(expected))