utils.py 451 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch.distributed as dist


def get_torch_default_comm():
    """Return the default ``torch.distributed`` process group.

    Two lookups are tried so the helper works across PyTorch releases:

    1. ``dist.distributed_c10d._get_default_group()``.
    2. The module attribute ``dist.distributed_c10d._default_pg``, accepted
       only when it is not ``None``.

    Returns:
        The default process group object found by the first lookup that
        succeeds.

    Raises:
        RuntimeError: if neither lookup yields a process group (either this
            PyTorch build provides neither entry point, or the default process
            group has not been initialized). The failure from the first lookup
            is chained as ``__cause__`` so the root cause is not lost.
    """
    first_error = None
    try:
        return dist.distributed_c10d._get_default_group()
    except (AttributeError, AssertionError, RuntimeError, ValueError) as err:
        # AttributeError: the accessor does not exist in this PyTorch build.
        # The other types are presumably what different releases raise when the
        # default group is not initialized — TODO confirm against supported
        # PyTorch versions. Keep the error for chaining instead of printing it.
        first_error = err

    try:
        comm = dist.distributed_c10d._default_pg
    except AttributeError:
        # Attribute not present in this PyTorch build; fall through to the error.
        comm = None
    if comm is not None:
        return comm

    raise RuntimeError(
        'Unsupported PyTorch version or default process group not initialized'
    ) from first_error