utils.py 1017 Bytes
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
r'''
Utils to play with PyTorch.
'''
4
5
6
import torch.distributed as dist


Rick Ho's avatar
Rick Ho committed
7
8
# pylint: disable=broad-except
# pylint: disable=protected-access
9
def get_torch_default_comm():
    r'''
    Return PyTorch's default process group object.

    Fast MoE's C code implements customized communication operators and
    therefore needs the NCCL communicator, which PyTorch does not expose
    publicly. A hacking class of `ProcessGroupNCCL` on the C side receives
    the default process group and digs the communicator out of it. Because
    PyTorch's private interface changes between releases, several private
    access paths are probed in turn, and the first one that works wins.

    Returns:
        The result of `dist.distributed_c10d._get_default_group()` if that
        call succeeds; otherwise the value of
        `dist.distributed_c10d._default_pg` when it is not `None`.

    Raises:
        RuntimeError: if neither private access path yields a group.
    '''
    # First probe: the private accessor function. Any failure (missing
    # attribute, error raised by the call, ...) falls through to the next
    # probe instead of propagating.
    try:
        return dist.distributed_c10d._get_default_group()
    except Exception:
        pass

    # Second probe: the private module-level variable. A failed lookup is
    # treated the same as an unset (`None`) value.
    try:
        default_group = dist.distributed_c10d._default_pg
    except Exception:
        default_group = None
    if default_group is not None:
        return default_group

    raise RuntimeError('Unsupported PyTorch version')