# Copyright 2023 The vLLM team.

# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import List, Sequence

import torch

def ensure_divisibility(numerator, denominator):
    """Assert that ``numerator`` is evenly divisible by ``denominator``.

    Raises:
        AssertionError: if the division leaves a remainder.
    """
    remainder = numerator % denominator
    assert remainder == 0, f"{numerator} is not divisible by {denominator}"


def divide(numerator, denominator):
    """Divide ``numerator`` by ``denominator``, asserting the division is exact.

    Returns:
        The integer quotient ``numerator // denominator``.

    Raises:
        AssertionError: if ``denominator`` does not evenly divide ``numerator``.
    """
    ensure_divisibility(numerator, denominator)
    quotient = numerator // denominator
    return quotient

def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
    """ Split a tensor along its last dimension.

        Arguments:
            tensor: input tensor.
            num_partitions: number of partitions to split the tensor into;
                            must evenly divide the size of the last dimension.
            contiguous_split_chunks: If True, make each chunk contiguous
                                     in memory.

        Returns:
            A sequence (tuple) of tensors. Both code paths return tuples —
            ``torch.split`` yields a tuple of views — hence the
            ``Sequence`` return annotation rather than ``List``.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    # divide() asserts exact divisibility of the last dimension.
    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # NOTE: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list


class VocabUtility:
    """ Helpers for partitioning a vocabulary across `world_size` ranks.

        Each rank owns one contiguous chunk of the vocabulary; the
        returned ranges are half-open intervals [first, last).
    """

    @staticmethod
    def vocab_range_from_per_partition_vocab_size(
            per_partition_vocab_size: int, rank: int) -> Sequence[int]:
        """Return the [first, last) vocab indices owned by `rank`
        when every partition has `per_partition_vocab_size` entries."""
        first = rank * per_partition_vocab_size
        last = first + per_partition_vocab_size
        return first, last

    @staticmethod
    def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
                                           world_size: int) -> Sequence[int]:
        """Return the [first, last) vocab indices owned by `rank` after
        splitting `global_vocab_size` evenly into `world_size` partitions."""
        chunk_size = divide(global_vocab_size, world_size)
        return VocabUtility.vocab_range_from_per_partition_vocab_size(
            chunk_size, rank)