import logging
import os
import random

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from mmcv.runner import get_dist_info


def init_dist(launcher, backend='nccl', **kwargs):
    """Initialize the distributed environment for the given launcher."""
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError('Invalid launcher type: {}'.format(launcher))


def _init_dist_pytorch(backend, **kwargs):
    # TODO: use local_rank instead of rank % num_gpus
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_mpi(backend, **kwargs):
    raise NotImplementedError


def _init_dist_slurm(backend, **kwargs):
    # Derive rank and world size from SLURM environment variables, and pick
    # this process's GPU by round-robin over the local device count.
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)
    # SLURM_NODELIST may be a compressed range such as 'prefix-[3-6,10]';
    # keep only the first node name so it can serve as the master address.
    # 1000 acts as a "not found" sentinel larger than any real node list.
    if '[' in node_list:
        beg = node_list.find('[')
        pos1 = node_list.find('-', beg)
        if pos1 < 0:
            pos1 = 1000
        pos2 = node_list.find(',', beg)
        if pos2 < 0:
            pos2 = 1000
        node_list = node_list[:min(pos1, pos2)].replace('[', '')
    # Convert the node name to an IP address. This assumes host names of the
    # form '<8-char site prefix><a>-<b>-<c>-<d>', e.g. 'SH-IDC1-10-5-30-62'.
    addr = node_list[8:].replace('-', '.')
    os.environ['MASTER_PORT'] = str(kwargs['port'])  # a 'port' kwarg is required
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['RANK'] = str(proc_id)
    if backend == 'nccl':
        dist.init_process_group(backend='nccl')
    else:
        dist.init_process_group(
            backend='gloo', rank=proc_id, world_size=ntasks)
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    return rank, world_size


def set_random_seed(seed):
    """Seed Python, NumPy and PyTorch (all GPUs) for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_root_logger(log_level=logging.INFO):
    logger = logging.getLogger()
    if not logger.hasHandlers():
        logging.basicConfig(
            format='%(asctime)s - %(levelname)s - %(message)s',
            level=log_level)
    # Only rank 0 logs at the requested level; all other ranks are silenced
    # down to errors so multi-process training does not duplicate output.
    rank, _ = get_dist_info()
    if rank != 0:
        logger.setLevel('ERROR')
    return logger
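

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module): a minimal
# training-script prologue wiring these helpers together. The argument names
# (`--launcher`, `--seed`) and the port value are assumptions, not an API
# defined by this file.
#
#     import argparse
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--launcher', default='pytorch',
#                         choices=['pytorch', 'slurm', 'mpi'])
#     parser.add_argument('--seed', type=int, default=0)
#     args = parser.parse_args()
#
#     # 'port' is forwarded to _init_dist_slurm via **kwargs and is required
#     # for the 'slurm' launcher; the PyTorch launcher ignores it.
#     init_dist(args.launcher, backend='nccl', port=29500)
#     set_random_seed(args.seed)
#     logger = get_root_logger(logging.INFO)
#     logger.info('Distributed environment initialized')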