Commit 37cd5dfd authored by root

moved process group creation into apex so it can be called by users

parent e49dca6e
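The helper added below is exposed from apex.parallel so users can build the subgroup themselves and pass it to the existing ``convert_syncbn_model`` entry point. A minimal usage sketch (assuming torch.distributed is already initialized, e.g. via torch.distributed.launch; ``model`` and ``group_size`` are placeholder names supplied by the user):

    >>> # torch.distributed must be initialized first (e.g. init_process_group with the nccl backend)
    >>> import apex
    >>> group = apex.parallel.create_syncbn_process_group(group_size)
    >>> model = apex.parallel.convert_syncbn_model(model, process_group=group)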
@@ -51,3 +51,42 @@ def convert_syncbn_model(module, process_group=None, channel_last=False):
# TODO(jie) should I delete model explicitly?
del module
return mod
def create_syncbn_process_group(group_size):
    '''
    Creates process groups to be used for syncbn of a given ``group_size`` and returns
    the process group that the current GPU participates in.

    ``group_size`` must divide the total number of GPUs (world_size).
    A ``group_size`` of 0 is treated as world_size; in that case ``None`` is returned.
    A ``group_size`` of 1 is equivalent to using non-sync bn, but still carries the sync overhead.

    Args:
        group_size (int): number of GPUs to collaborate for sync bn

    Example::
        >>> # model is an instance of torch.nn.Module
        >>> import apex
        >>> group = apex.parallel.create_syncbn_process_group(group_size)
    '''
    if group_size == 0:
        return None

    world_size = torch.distributed.get_world_size()
    assert world_size >= group_size
    assert world_size % group_size == 0

    group = None
    for group_num in range(world_size // group_size):
        group_ids = range(group_num * group_size, (group_num + 1) * group_size)
        cur_group = torch.distributed.new_group(ranks=group_ids)
        if torch.distributed.get_rank() // group_size == group_num:
            group = cur_group
            # cannot return early here: every process must participate in the creation of all subgroups

    assert group is not None
    return group
@@ -38,12 +38,6 @@ except:
print("This is a multi-gpu test. To run it please use 'python -m torch.distributed.launch --nproc_per_node=<num gpus> test_groups.py <more options>'")
exit(1)
if args.group_size==0:
args.group_size = args.world_size
assert(args.world_size >= args.group_size)
assert(args.world_size % args.group_size == 0)
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
@@ -116,14 +110,7 @@ for param in bn.parameters():
param.grad = param.grad / args.group_size
bn_opt = optim.SGD(bn.parameters(), lr=1.0)
# create process groups and pick the group this gpu participates
for group_num in (range(args.world_size//args.group_size)):
group_ids = range(group_num*args.group_size, (group_num+1)*args.group_size)
cur_group = torch.distributed.new_group(ranks=group_ids)
if (torch.distributed.get_rank()//args.group_size == group_num):
group = cur_group
sbn = apex.parallel.SyncBatchNorm(feature_size, process_group=group).cuda()
sbn = apex.parallel.SyncBatchNorm(feature_size, process_group=apex.parallel.create_syncbn_process_group(args.group_size)).cuda()
sbn.momentum = 1.0
sbn.weight.data = weight_t.clone()
sbn.bias.data = bias_t.clone()
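As an illustration, the updated test could be launched as below (the ``--group_size`` flag name is assumed from ``args.group_size`` above; the value must divide the number of processes per node):

    python -m torch.distributed.launch --nproc_per_node=8 test_groups.py --group_size=4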