# Copyright (c) Open-MMLab. All rights reserved.
import functools
import os
import subprocess

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from mmcv.utils import TORCH_VERSION


def init_dist(launcher, backend='nccl', **kwargs):
    """Initialize the distributed environment with the specified launcher.

    Args:
        launcher (str): Launcher type, one of 'pytorch', 'mpi' or 'slurm'.
        backend (str): Backend of torch.distributed. Defaults to 'nccl'.
        kwargs: Arguments passed through to the launcher-specific initializer.
    """
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError(f'Invalid launcher type: {launcher}')
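
# Typical entry point: a training script picks the launcher from its CLI and
# calls ``init_dist`` once before building data loaders and models. A minimal
# sketch (the ``--launcher`` argument and the 'none' choice are illustrative,
# not part of this module):
#
#   parser = argparse.ArgumentParser()
#   parser.add_argument(
#       '--launcher',
#       choices=['none', 'pytorch', 'slurm', 'mpi'],
#       default='none')
#   args = parser.parse_args()
#   if args.launcher != 'none':
#       init_dist(args.launcher, backend='nccl')
#   rank, world_size = get_dist_info()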


def _init_dist_pytorch(backend, **kwargs):
    # TODO: use local_rank instead of rank % num_gpus
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)
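
# The 'pytorch' launcher expects ``RANK`` (and the other rendezvous variables)
# to be set by an external launcher such as ``torch.distributed.launch``,
# which spawns one process per GPU. An illustrative command line (the script
# name is hypothetical):
#
#   python -m torch.distributed.launch --nproc_per_node=8 \
#       train.py --launcher pytorch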


def _init_dist_mpi(backend, **kwargs):
    raise NotImplementedError


def _init_dist_slurm(backend, port=None):
    """Initialize slurm distributed training environment.

    If argument ``port`` is not specified, the master port will be taken from
    the environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not set
    either, the default port ``29500`` will be used.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')
    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' in os.environ:
        pass  # use MASTER_PORT in the environment variable
    else:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)
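
# Under the 'slurm' launcher the ``SLURM_*`` variables read above are provided
# by Slurm itself, so the job is started with ``srun`` rather than
# ``torch.distributed.launch``. An illustrative submission (partition name and
# script are placeholders):
#
#   srun -p <partition> --ntasks=16 --gres=gpu:8 --ntasks-per-node=8 \
#       python train.py --launcher slurm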


def get_dist_info():
    """Return ``(rank, world_size)`` of the current process.

    Falls back to ``(0, 1)`` when torch.distributed is unavailable or not
    yet initialized.
    """
    if TORCH_VERSION < '1.0':
        initialized = dist._initialized
    else:
        if dist.is_available():
            initialized = dist.is_initialized()
        else:
            initialized = False
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size


def master_only(func):
    """Decorate a function so that it only executes on the master process
    (rank 0); on other ranks the call is skipped and returns ``None``."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper
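

# Minimal smoke test: in a single, non-distributed process ``get_dist_info``
# falls back to rank 0 and world size 1, so a ``master_only`` function still
# runs. The helper below is only for illustration and executes solely when
# this module is run directly.
if __name__ == '__main__':

    @master_only
    def _print_on_master(msg):
        print(msg)

    rank, world_size = get_dist_info()
    _print_on_master(f'rank={rank}, world_size={world_size}')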