# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
from typing import List

import torch
from deepspeed import comm as dist
from deepspeed.utils import logger
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad
from deepspeed.ops.adam import FusedAdam
from deepspeed.utils.nvtx import instrument_w_nvtx
from deepspeed.accelerator import get_accelerator


def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
    data_parallel_size = int(dist.get_world_size())
    parameter_parallel_size = parameter_parallel_size or data_parallel_size
    logger.info("data_parallel_size: %s, parameter_parallel_size: %s", data_parallel_size, parameter_parallel_size)
    assert data_parallel_size % parameter_parallel_size == 0, \
        'world size should be divisible by parameter parallel size'
    rank = dist.get_rank()
    my_group = None
    for i in range(data_parallel_size // parameter_parallel_size):
        ranks = range(i * parameter_parallel_size, (i + 1) * parameter_parallel_size)
        group = dist.new_group(ranks)
        if rank in ranks:
            my_group = group
    return my_group
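
# Hypothetical usage sketch (not part of the original module): with a world size of 8
# and parameter_parallel_size=4, ranks 0-3 and 4-7 form two parameter-parallel groups,
# and each rank receives the handle of the group it belongs to:
#
#   group = _initialize_parameter_parallel_groups(parameter_parallel_size=4)
#   # on rank 5, `group` is the process group spanning ranks {4, 5, 6, 7}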


class ZeRORuntimeException(Exception):
    pass


ZERO_SUPPORTED_OPTIMIZERS = [
    torch.optim.Adam, torch.optim.AdamW, FusedAdam, DeepSpeedCPUAdam, torch.optim.Adagrad, DeepSpeedCPUAdagrad
]

# Add apex FusedAdam to supported list if apex is installed
try:
    import apex
    if hasattr(apex, 'optimizers') and hasattr(apex.optimizers, 'FusedAdam'):
        ZERO_SUPPORTED_OPTIMIZERS.append(apex.optimizers.FusedAdam)
except ImportError:
    pass


def is_zero_supported_optimizer(optimizer):
    if dist.get_rank() == 0:
        logger.info(f'Checking ZeRO support for optimizer={optimizer.__class__.__name__} type={type(optimizer)}')
    return type(optimizer) in ZERO_SUPPORTED_OPTIMIZERS
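
# Hypothetical usage sketch (assumes deepspeed.comm/torch.distributed is already
# initialized, since is_zero_supported_optimizer logs only on rank 0):
#
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#   if not is_zero_supported_optimizer(optimizer):
#       raise ZeRORuntimeException(f"{type(optimizer)} is not supported by ZeRO")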


def get_lst_from_rank0(lst: List[int]) -> List[int]:
    """
    Broadcast rank 0's copy of ``lst`` to all ranks and return it as a list.

    NOTE: creates both communication and synchronization overhead so should be used
    sparingly
    """
    lst_tensor = torch.tensor(
        lst if dist.get_rank() == 0 else [-1] * len(lst),
        dtype=int,
        # device=get_accelerator().current_device_name(),
        device=torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"])),
        requires_grad=False,
    )
    dist.broadcast(lst_tensor, src=0, async_op=False)

    return list(lst_tensor.cpu().numpy())
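
# Hypothetical usage sketch (assumes LOCAL_RANK is set and the process group is
# initialized): every rank passes its own list and gets back rank 0's copy.
#
#   local_sizes = [p.numel() for p in module.parameters()]
#   rank0_sizes = get_lst_from_rank0(local_sizes)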


@instrument_w_nvtx
def assert_ints_same_as_other_ranks(ints: List[int]) -> None:
    """
    NOTE: creates both communication and synchronization overhead so should be
    used sparingly

    takes a list of ints from each rank and ensures that they are the same
    across ranks, throwing an exception if they are not.
    """
    rank0_ints = get_lst_from_rank0(ints)
    if ints != rank0_ints:
        raise RuntimeError(f"disagreement between rank0 and rank{dist.get_rank()}: "
                           f"rank0: {rank0_ints}, rank{dist.get_rank()}: {ints}")