# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import Sequence, Tuple

import torch


def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator.

    Arguments:
        numerator: value being divided.
        denominator: value that must divide ``numerator`` with no remainder.

    Raises:
        AssertionError: if ``numerator % denominator`` is not zero. The
            check is an ``assert`` statement, so it is skipped when Python
            runs with assertions disabled (``-O``).
    """
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator)


def divide(numerator, denominator):
    """Return ``numerator // denominator`` after a divisibility check.

    The check is delegated to ``ensure_divisibility``, which asserts that
    ``numerator % denominator == 0`` before the floor division is done.
    """
    ensure_divisibility(numerator, denominator)
    quotient = numerator // denominator
    return quotient


def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
    """ Split a tensor along its last dimension.

        Arguments:
            tensor: input tensor.
            num_partitions: number of partitions to split the tensor
            contiguous_split_chunks: If True, make each chunk contiguous
                                     in memory.

        Returns:
            A sequence of Tensors, each spanning
            ``tensor.size(-1) // num_partitions`` entries of the last
            dimension. When ``contiguous_split_chunks`` is True the result
            is a tuple of contiguous copies; otherwise it is whatever
            ``torch.split`` returns.

        Raises:
            AssertionError: if the last dimension is not divisible by
                ``num_partitions`` (raised by ``divide``).
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # NOTE: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list


def get_pp_indices(num_hidden_layers: int, pp_rank: int,
                   pp_size: int) -> Tuple[int, int]:
    """Try to evenly distribute layers across partitions.
    If the number of layers is not divisible by the number of partitions,
    the last partition will have the remaining layers.

    Arguments:
        num_hidden_layers: total number of layers to distribute.
        pp_rank: index of the partition whose range is requested.
        pp_size: total number of partitions.

    Returns:
        A ``(start_layer, end_layer)`` tuple. The ``end_layer`` of one rank
        equals the ``start_layer`` of the next, so ``end_layer`` acts as an
        exclusive bound.
    """
    layers_per_partition = num_hidden_layers // pp_size
    start_layer = pp_rank * layers_per_partition
    end_layer = start_layer + layers_per_partition

    # The final rank absorbs whatever the floor division left over.
    if pp_rank == pp_size - 1:
        end_layer = num_hidden_layers

    return (start_layer, end_layer)