Commit e49dca6e authored by root's avatar root
Browse files

fixing it to work properly in multi-node environment

parent 598fbc88
......@@ -120,7 +120,7 @@ bn_opt = optim.SGD(bn.parameters(), lr=1.0)
for group_num in (range(args.world_size//args.group_size)):
group_ids = range(group_num*args.group_size, (group_num+1)*args.group_size)
cur_group = torch.distributed.new_group(ranks=group_ids)
if (args.local_rank//args.group_size == group_num):
if (torch.distributed.get_rank()//args.group_size == group_num):
group = cur_group
sbn = apex.parallel.SyncBatchNorm(feature_size, process_group=group).cuda()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment