utils.py
r"""
Utilities for working with PyTorch.
"""
import torch
import torch.distributed as dist


# pylint: disable=broad-except
# pylint: disable=protected-access
def get_torch_default_comm():
    r"""
    Get PyTorch's default communicator (process group).

    The NCCL communicator is needed so that Fast MoE can perform its
    customized communication operators in C code. However, the communicator is
    not a publicly available variable. Therefore, a hacked subclass of
    `ProcessGroupNCCL` in Fast MoE's C code takes the default process group
    (`_default_pg`) and digs the communicator out of the object. As PyTorch's
    private interface changes from time to time, different hacking techniques
    are tried one by one to stay compatible with various PyTorch versions.
    """
    try:
        # Newer PyTorch versions expose the default group via this private
        # helper.
        comm = dist.distributed_c10d._get_default_group()
        return comm
    except Exception as _:
        pass
    try:
        # Older PyTorch versions stored the default group in a module-level
        # variable instead.
        comm = dist.distributed_c10d._default_pg
        if comm is not None:
            return comm
    except Exception as _:
        pass
    raise RuntimeError("Unsupported PyTorch version")
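

# Usage sketch, assuming `torch.distributed` has already been initialized
# with the NCCL backend (e.g. `dist.init_process_group("nccl")` under
# `torchrun`); `_demo_default_comm` is a hypothetical name for illustration.
def _demo_default_comm():
    comm = get_torch_default_comm()
    # The returned object behaves as a regular process group, so it can be
    # passed as the `group` argument of any collective.
    dist.barrier(group=comm)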


def get_rank_0_in_comm(comm):
    r"""
    Return the global rank of the process that has rank 0 within `comm`.
    """
    world_size = dist.get_world_size(comm)
    # Every member of `comm` contributes its global rank.
    x = torch.tensor([dist.get_rank()], dtype=torch.int64, device='cuda')
    ys = [torch.empty_like(x) for _ in range(world_size)]
    dist.all_gather(ys, x, group=comm)
    # The gathered list is ordered by rank within `comm`, so the first entry
    # belongs to the group's root.
    root_rank = ys[0].item()
    return root_rank
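

# Usage sketch: broadcast from whichever process is rank 0 inside `comm`.
# `get_rank_0_in_comm` returns that process's global rank, which is what
# `dist.broadcast` expects as its `src` argument. Assumes an initialized NCCL
# process group and a CUDA device; `_demo_broadcast_from_root` is a
# hypothetical name for illustration.
def _demo_broadcast_from_root(comm):
    root = get_rank_0_in_comm(comm)
    payload = torch.zeros(1, device='cuda')
    dist.broadcast(payload, src=root, group=comm)
    return payload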