Commit 904d875a authored by Kai Chen

modify distributed training api and use coalesced all_reduce

parent 15e9d026
 import os
+from collections import OrderedDict

 import torch
 import torch.multiprocessing as mp
 import torch.distributed as dist
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from torch.nn.utils import clip_grad

 from mmcv.torchpack import Hook, OptimizerHook

 __all__ = [
-    'init_dist', 'average_gradients', 'broadcast_params', 'DistOptimizerHook',
-    'DistSamplerSeedHook'
+    'init_dist', 'reduce_grads', 'DistOptimizerHook', 'DistSamplerSeedHook'
 ]


-def init_dist(world_size,
-              rank,
-              backend='gloo',
-              master_ip='127.0.0.1',
-              port=29500):
+def init_dist(launcher, backend='nccl', **kwargs):
     if mp.get_start_method(allow_none=True) is None:
         mp.set_start_method('spawn')
+    if launcher == 'pytorch':
+        _init_dist_pytorch(backend, **kwargs)
+    elif launcher == 'mpi':
+        _init_dist_mpi(backend, **kwargs)
+    elif launcher == 'slurm':
+        _init_dist_slurm(backend, **kwargs)
+    else:
+        raise ValueError('Invalid launcher type: {}'.format(launcher))
+
+
+def _init_dist_pytorch(backend, **kwargs):
+    # TODO: use local_rank instead of rank % num_gpus
+    rank = int(os.environ['RANK'])
     num_gpus = torch.cuda.device_count()
     torch.cuda.set_device(rank % num_gpus)
-    os.environ['MASTER_ADDR'] = master_ip
-    os.environ['MASTER_PORT'] = str(port)
-    if backend == 'nccl':
-        dist.init_process_group(backend='nccl')
-    else:
-        dist.init_process_group(
-            backend='gloo', rank=rank, world_size=world_size)
+    dist.init_process_group(backend=backend, **kwargs)


-def average_gradients(model):
-    for param in model.parameters():
-        if param.requires_grad and not (param.grad is None):
-            dist.all_reduce(param.grad.data)
+def _init_dist_mpi(backend, **kwargs):
+    raise NotImplementedError


-def broadcast_params(model):
-    for p in model.state_dict().values():
-        dist.broadcast(p, 0)
+def _init_dist_slurm(backend, **kwargs):
+    raise NotImplementedError
+
+
+# modified from https://github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py#L9
+def coalesce_all_reduce(tensors):
+    buckets = OrderedDict()
+    for tensor in tensors:
+        tp = tensor.type()
+        if tp not in buckets:
+            buckets[tp] = []
+        buckets[tp].append(tensor)
+
+    for tp in buckets:
+        bucket = buckets[tp]
+        coalesced = _flatten_dense_tensors(bucket)
+        dist.all_reduce(coalesced)
+        coalesced /= dist.get_world_size()
+        for buf, synced in zip(bucket,
+                               _unflatten_dense_tensors(coalesced, bucket)):
+            buf.copy_(synced)
+
+
+def reduce_grads(model, coalesce=True):
+    grads = [
+        param.grad.data for param in model.parameters()
+        if param.requires_grad and param.grad is not None
+    ]
+    if coalesce:
+        coalesce_all_reduce(grads)
+    else:
+        for tensor in grads:
+            dist.all_reduce(tensor.div_(dist.get_world_size()))


 class DistOptimizerHook(OptimizerHook):

+    def __init__(self, grad_clip=None, coalesce=True):
+        self.grad_clip = grad_clip
+        self.coalesce = coalesce
+
     def after_train_iter(self, runner):
         runner.optimizer.zero_grad()
         runner.outputs['loss'].backward()
-        average_gradients(runner.model)
+        reduce_grads(runner.model, self.coalesce)
         if self.grad_clip is not None:
             clip_grad.clip_grad_norm_(
                 filter(lambda p: p.requires_grad, runner.model.parameters()),
......
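A note on the coalesced all_reduce above: _flatten_dense_tensors only accepts tensors of a single type, which is why coalesce_all_reduce buckets gradients by tensor type before flattening. The following is a minimal sketch of the same idea, assuming the process group is already initialized (e.g. via init_dist) and that all gradients share one dtype; the helper name allreduce_and_average is illustrative and not part of this commit.

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def allreduce_and_average(tensors):
    # Flatten everything into one contiguous buffer so a single all_reduce
    # call is issued instead of one call per parameter.
    flat = _flatten_dense_tensors(tensors)
    dist.all_reduce(flat)
    flat.div_(dist.get_world_size())  # average across ranks
    # Copy the averaged values back into the original gradient tensors.
    for buf, synced in zip(tensors, _unflatten_dense_tensors(flat, tensors)):
        buf.copy_(synced)

With gradients averaged this way, DistOptimizerHook only needs to call reduce_grads between backward() and the optimizer step.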
 from functools import partial

+from mmcv.torchpack import get_dist_info
 from torch.utils.data import DataLoader

 from .collate import collate
@@ -11,10 +12,9 @@ def build_dataloader(dataset,
                      workers_per_gpu,
                      num_gpus,
                      dist=True,
-                     world_size=1,
-                     rank=0,
                      **kwargs):
     if dist:
+        rank, world_size = get_dist_info()
         sampler = DistributedGroupSampler(dataset, imgs_per_gpu, world_size,
                                           rank)
         batch_size = imgs_per_gpu
......
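The dataloader change drops world_size and rank from the signature and looks them up inside the function via get_dist_info(). A rough stand-in using only plain PyTorch (DistributedSampler instead of the project's DistributedGroupSampler, no custom collate), just to illustrate the calling convention; build_toy_dataloader is a made-up name:

import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def build_toy_dataloader(dataset, imgs_per_gpu, workers_per_gpu, num_gpus,
                         use_dist=True, **kwargs):
    if use_dist:
        # Rank and world size come from the process group, not from the
        # caller -- mirrors the get_dist_info() change above.
        rank, world_size = dist.get_rank(), dist.get_world_size()
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
        batch_size = imgs_per_gpu
        num_workers = workers_per_gpu
    else:
        sampler = None
        batch_size = num_gpus * imgs_per_gpu
        num_workers = num_gpus * workers_per_gpu
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                      shuffle=(sampler is None), num_workers=num_workers,
                      **kwargs)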
@@ -121,8 +121,7 @@ log_config = dict(
 # yapf:enable
 # runtime settings
 total_epochs = 12
-device_ids = range(8)
-dist_params = dict(backend='nccl', port='29500', master_ip='127.0.0.1')
+dist_params = dict(backend='nccl')
 log_level = 'INFO'
 work_dir = './work_dirs/fpn_faster_rcnn_r50_1x'
 load_from = None
......
@@ -134,8 +134,7 @@ log_config = dict(
 # yapf:enable
 # runtime settings
 total_epochs = 12
-device_ids = range(8)
-dist_params = dict(backend='nccl', port='29500', master_ip='127.0.0.1')
+dist_params = dict(backend='nccl')
 log_level = 'INFO'
 work_dir = './work_dirs/fpn_mask_rcnn_r50_1x'
 load_from = None
......
@@ -100,8 +100,7 @@ log_config = dict(
 # yapf:enable
 # runtime settings
 total_epochs = 12
-device_ids = range(8)
-dist_params = dict(backend='gloo', port='29500', master_ip='127.0.0.1')
+dist_params = dict(backend='gloo')
 log_level = 'INFO'
 work_dir = './work_dirs/fpn_rpn_r50_1x'
 load_from = None
......
@@ -39,9 +39,7 @@ def batch_processor(model, data, train_mode):
     loss, log_vars = parse_losses(losses)

     outputs = dict(
-        loss=loss / args.world_size,
-        log_vars=log_vars,
-        num_samples=len(data['img'].data))
+        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))

     return outputs
@@ -54,61 +52,65 @@ def parse_args():
         action='store_true',
         help='whether to add a validate phase')
     parser.add_argument(
-        '--dist', action='store_true', help='use distributed training or not')
-    parser.add_argument('--world-size', default=1, type=int)
-    parser.add_argument('--rank', default=0, type=int)
+        '--gpus', type=int, default=1, help='number of gpus to use')
+    parser.add_argument(
+        '--launcher',
+        choices=['none', 'pytorch', 'slurm', 'mpi'],
+        default='none',
+        help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
     return args


-args = parse_args()
-
-
 def main():
-    # get config from file
+    args = parse_args()
     cfg = Config.fromfile(args.config)
-    cfg.update(world_size=args.world_size, rank=args.rank)
+    cfg.update(gpus=args.gpus)

     # init distributed environment if necessary
-    if args.dist:
-        print('Enable distributed training.')
-        init_dist(args.world_size, args.rank, **cfg.dist_params)
-    else:
+    if args.launcher == 'none':
+        dist = False
         print('Disabled distributed training.')
+    else:
+        dist = True
+        print('Enabled distributed training.')
+        init_dist(args.launcher, **cfg.dist_params)

     # prepare data loaders
     train_dataset = obj_from_dict(cfg.data.train, datasets)
     data_loaders = [
-        build_dataloader(
-            train_dataset, cfg.data.imgs_per_gpu, cfg.data.workers_per_gpu,
-            len(cfg.device_ids), args.dist, cfg.world_size, cfg.rank)
+        build_dataloader(train_dataset, cfg.data.imgs_per_gpu,
+                         cfg.data.workers_per_gpu, cfg.gpus, dist)
     ]
     if args.validate:
         val_dataset = obj_from_dict(cfg.data.val, datasets)
         data_loaders.append(
-            build_dataloader(
-                val_dataset, cfg.data.imgs_per_gpu, cfg.data.workers_per_gpu,
-                len(cfg.device_ids), args.dist, cfg.world_size, cfg.rank))
+            build_dataloader(val_dataset, cfg.data.imgs_per_gpu,
+                             cfg.data.workers_per_gpu, cfg.gpus, dist))

     # build model
     model = build_detector(
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
-    if args.dist:
+    if dist:
         model = MMDistributedDataParallel(
-            model, device_ids=[cfg.rank], broadcast_buffers=False).cuda()
+            model,
+            device_ids=[torch.cuda.current_device()],
+            broadcast_buffers=False).cuda()
     else:
-        model = MMDataParallel(model, device_ids=cfg.device_ids).cuda()
+        model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()

     # build runner
     runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
                     cfg.log_level)

     # register hooks
     optimizer_config = DistOptimizerHook(
-        **cfg.optimizer_config) if args.dist else cfg.optimizer_config
+        **cfg.optimizer_config) if dist else cfg.optimizer_config
     runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                    cfg.checkpoint_config, cfg.log_config)
-    if args.dist:
+    if dist:
        runner.register_hook(DistSamplerSeedHook())
     if cfg.resume_from:
......
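With the new launcher-based API, a single-node multi-GPU run is expected to be started through PyTorch's launch utility, which sets RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT and passes --local_rank -- exactly the environment _init_dist_pytorch relies on. Below is a standalone sanity-check script (hypothetical, not part of this commit) that follows the same initialization flow; the launch command and the 2-GPU setup are assumptions.

# toy_dist_check.py -- illustrative only.
# Launch with (2 GPUs assumed):
#   python -m torch.distributed.launch --nproc_per_node=2 toy_dist_check.py
import argparse
import os

import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)  # injected by the launcher
args = parser.parse_args()

# Same flow as _init_dist_pytorch: read the rank from the environment,
# pin this process to one GPU, then join the NCCL process group.
rank = int(os.environ['RANK'])
torch.cuda.set_device(rank % torch.cuda.device_count())
dist.init_process_group(backend='nccl')

# Each rank contributes its own rank id; all_reduce sums them across ranks.
x = torch.ones(1).cuda() * rank
dist.all_reduce(x)
print('rank {}: sum over ranks = {}'.format(rank, x.item()))

Launching the training script with --launcher pytorch under the same utility makes init_dist take the corresponding code path.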