Commit 904d875a authored by Kai Chen

modify distributed training api and use coalesced all_reduce

parent 15e9d026
import os
+from collections import OrderedDict
import torch
+import torch.multiprocessing as mp
import torch.distributed as dist
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.nn.utils import clip_grad
from mmcv.torchpack import Hook, OptimizerHook
__all__ = [
-    'init_dist', 'average_gradients', 'broadcast_params', 'DistOptimizerHook',
-    'DistSamplerSeedHook'
+    'init_dist', 'reduce_grads', 'DistOptimizerHook', 'DistSamplerSeedHook'
]
-def init_dist(world_size,
-              rank,
-              backend='gloo',
-              master_ip='127.0.0.1',
-              port=29500):
+def init_dist(launcher, backend='nccl', **kwargs):
+    if mp.get_start_method(allow_none=True) is None:
+        mp.set_start_method('spawn')
+    if launcher == 'pytorch':
+        _init_dist_pytorch(backend, **kwargs)
+    elif launcher == 'mpi':
+        _init_dist_mpi(backend, **kwargs)
+    elif launcher == 'slurm':
+        _init_dist_slurm(backend, **kwargs)
+    else:
+        raise ValueError('Invalid launcher type: {}'.format(launcher))
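Illustrative usage sketch, not part of this commit: with the 'pytorch' launcher, init_dist assumes the process was started by something like python -m torch.distributed.launch (or an equivalent scheduler) that has already exported RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT, so init_process_group can fall back on the default env:// rendezvous.

import torch.distributed as dist

init_dist('pytorch', backend='nccl')  # binds this process to GPU RANK % num_gpus
print('rank {} / world size {}'.format(dist.get_rank(), dist.get_world_size()))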
+def _init_dist_pytorch(backend, **kwargs):
+    # TODO: use local_rank instead of rank % num_gpus
+    rank = int(os.environ['RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
-    os.environ['MASTER_ADDR'] = master_ip
-    os.environ['MASTER_PORT'] = str(port)
-    if backend == 'nccl':
-        dist.init_process_group(backend='nccl')
-    else:
-        dist.init_process_group(
-            backend='gloo', rank=rank, world_size=world_size)
+    dist.init_process_group(backend=backend, **kwargs)
+def _init_dist_mpi(backend, **kwargs):
+    raise NotImplementedError
-def average_gradients(model):
-    for param in model.parameters():
-        if param.requires_grad and not (param.grad is None):
-            dist.all_reduce(param.grad.data)
+def _init_dist_slurm(backend, **kwargs):
+    raise NotImplementedError
-def broadcast_params(model):
-    for p in model.state_dict().values():
-        dist.broadcast(p, 0)
+# modified from https://github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py#L9
+def coalesce_all_reduce(tensors):
+    buckets = OrderedDict()
+    for tensor in tensors:
+        tp = tensor.type()
+        if tp not in buckets:
+            buckets[tp] = []
+        buckets[tp].append(tensor)
+    for tp in buckets:
+        bucket = buckets[tp]
+        coalesced = _flatten_dense_tensors(bucket)
+        dist.all_reduce(coalesced)
+        coalesced /= dist.get_world_size()
+        for buf, synced in zip(bucket,
+                               _unflatten_dense_tensors(coalesced, bucket)):
+            buf.copy_(synced)
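For clarity, a minimal self-contained sketch (not part of this commit) of the flatten/unflatten round trip that coalesce_all_reduce relies on; doubling the flat buffer stands in for the all_reduce-plus-division step, so no process group is needed to run it.

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

bucket = [torch.ones(2, 3), torch.zeros(5)]   # same type -> would share one bucket
flat = _flatten_dense_tensors(bucket)         # one contiguous 1-D buffer
flat *= 2                                     # placeholder for all_reduce + '/= world_size'
for buf, synced in zip(bucket, _unflatten_dense_tensors(flat, bucket)):
    buf.copy_(synced)                         # write the reduced values back in place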
+def reduce_grads(model, coalesce=True):
+    grads = [
+        param.grad.data for param in model.parameters()
+        if param.requires_grad and param.grad is not None
+    ]
+    if coalesce:
+        coalesce_all_reduce(grads)
+    else:
+        for tensor in grads:
+            dist.all_reduce(tensor)
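A hedged usage sketch, not part of the commit, assuming init_dist above has already set up the process group and a CUDA device is available: reduce_grads averages gradients across ranks right after backward(), issuing one all_reduce per dtype bucket instead of one per parameter tensor; DistOptimizerHook below does exactly this inside the Runner loop.

model = torch.nn.Linear(4, 2).cuda()            # toy model, gradients live on the GPU
loss = model(torch.randn(8, 4).cuda()).sum()
loss.backward()
reduce_grads(model, coalesce=True)              # gradients now hold the cross-rank average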
class DistOptimizerHook(OptimizerHook):
+    def __init__(self, grad_clip=None, coalesce=True):
+        self.grad_clip = grad_clip
+        self.coalesce = coalesce
    def after_train_iter(self, runner):
        runner.optimizer.zero_grad()
        runner.outputs['loss'].backward()
-        average_gradients(runner.model)
+        reduce_grads(runner.model, self.coalesce)
        if self.grad_clip is not None:
            clip_grad.clip_grad_norm_(
                filter(lambda p: p.requires_grad, runner.model.parameters()),
......
from functools import partial
+from mmcv.torchpack import get_dist_info
from torch.utils.data import DataLoader
from .collate import collate
......@@ -11,10 +12,9 @@ def build_dataloader(dataset,
                     workers_per_gpu,
                     num_gpus,
                     dist=True,
-                     world_size=1,
-                     rank=0,
                     **kwargs):
    if dist:
+        rank, world_size = get_dist_info()
        sampler = DistributedGroupSampler(dataset, imgs_per_gpu, world_size,
                                          rank)
        batch_size = imgs_per_gpu
......
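A hedged usage sketch, not part of the commit: distributed callers now pass only the per-GPU batch settings plus a dist flag, and rank/world_size are recovered inside build_dataloader via get_dist_info() instead of being threaded through the call; train_dataset stands for whatever dataset object the config builds.

train_loader = build_dataloader(train_dataset, imgs_per_gpu=2,
                                workers_per_gpu=2, num_gpus=1, dist=True)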
......@@ -121,8 +121,7 @@ log_config = dict(
# yapf:enable
# runtime settings
total_epochs = 12
-device_ids = range(8)
-dist_params = dict(backend='nccl', port='29500', master_ip='127.0.0.1')
+dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/fpn_faster_rcnn_r50_1x'
load_from = None
......
......@@ -134,8 +134,7 @@ log_config = dict(
# yapf:enable
# runtime settings
total_epochs = 12
-device_ids = range(8)
-dist_params = dict(backend='nccl', port='29500', master_ip='127.0.0.1')
+dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/fpn_mask_rcnn_r50_1x'
load_from = None
......
......@@ -100,8 +100,7 @@ log_config = dict(
# yapf:enable
# runtime settings
total_epochs = 12
-device_ids = range(8)
-dist_params = dict(backend='gloo', port='29500', master_ip='127.0.0.1')
+dist_params = dict(backend='gloo')
log_level = 'INFO'
work_dir = './work_dirs/fpn_rpn_r50_1x'
load_from = None
......
......@@ -39,9 +39,7 @@ def batch_processor(model, data, train_mode):
    loss, log_vars = parse_losses(losses)
    outputs = dict(
-        loss=loss / args.world_size,
-        log_vars=log_vars,
-        num_samples=len(data['img'].data))
+        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))
    return outputs
......@@ -54,61 +52,65 @@ def parse_args():
        action='store_true',
        help='whether to add a validate phase')
    parser.add_argument(
-        '--dist', action='store_true', help='use distributed training or not')
-    parser.add_argument('--world-size', default=1, type=int)
-    parser.add_argument('--rank', default=0, type=int)
+        '--gpus', type=int, default=1, help='number of gpus to use')
+    parser.add_argument(
+        '--launcher',
+        choices=['none', 'pytorch', 'slurm', 'mpi'],
+        default='none',
+        help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    return args
-args = parse_args()
def main():
    # get config from file
+    args = parse_args()
    cfg = Config.fromfile(args.config)
-    cfg.update(world_size=args.world_size, rank=args.rank)
+    cfg.update(gpus=args.gpus)
    # init distributed environment if necessary
-    if args.dist:
-        print('Enable distributed training.')
-        init_dist(args.world_size, args.rank, **cfg.dist_params)
-    else:
+    if args.launcher == 'none':
+        dist = False
+        print('Disabled distributed training.')
+    else:
+        dist = True
+        print('Enabled distributed training.')
+        init_dist(args.launcher, **cfg.dist_params)
    # prepare data loaders
    train_dataset = obj_from_dict(cfg.data.train, datasets)
    data_loaders = [
-        build_dataloader(
-            train_dataset, cfg.data.imgs_per_gpu, cfg.data.workers_per_gpu,
-            len(cfg.device_ids), args.dist, cfg.world_size, cfg.rank)
+        build_dataloader(train_dataset, cfg.data.imgs_per_gpu,
+                         cfg.data.workers_per_gpu, cfg.gpus, dist)
    ]
    if args.validate:
        val_dataset = obj_from_dict(cfg.data.val, datasets)
        data_loaders.append(
-            build_dataloader(
-                val_dataset, cfg.data.imgs_per_gpu, cfg.data.workers_per_gpu,
-                len(cfg.device_ids), args.dist, cfg.world_size, cfg.rank))
+            build_dataloader(val_dataset, cfg.data.imgs_per_gpu,
+                             cfg.data.workers_per_gpu, cfg.gpus, dist))
    # build model
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
-    if args.dist:
+    if dist:
        model = MMDistributedDataParallel(
-            model, device_ids=[cfg.rank], broadcast_buffers=False).cuda()
+            model,
+            device_ids=[torch.cuda.current_device()],
+            broadcast_buffers=False).cuda()
    else:
-        model = MMDataParallel(model, device_ids=cfg.device_ids).cuda()
+        model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
    # build runner
    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
                    cfg.log_level)
    # register hooks
    optimizer_config = DistOptimizerHook(
-        **cfg.optimizer_config) if args.dist else cfg.optimizer_config
+        **cfg.optimizer_config) if dist else cfg.optimizer_config
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)
-    if args.dist:
+    if dist:
        runner.register_hook(DistSamplerSeedHook())
    if cfg.resume_from:
......
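To tie the pieces together, a hedged sketch of how the reworked entry point is typically driven; this is not part of the commit, and the script path and config file below are illustrative placeholders. The commands are given as comments since the rest of this page is Python.

# Single node, 8 GPUs, non-distributed (MMDataParallel path):
#   python tools/train.py <config_file> --gpus 8
#
# Distributed, 8 processes (MMDistributedDataParallel path); torch.distributed.launch
# exports RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT and passes --local_rank per process:
#   python -m torch.distributed.launch --nproc_per_node=8 \
#       tools/train.py <config_file> --launcher pytorch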