Commit fa719e8b authored by Jie

[syncBN]

Replace new_group with torch.distributed.group.WORLD, which avoids creating a new
process group in every iteration.

This should resolve "Training gets stuck when using SyncBN" (#105).
parent 241dd6c4
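
For context: torch.distributed.new_group() is itself a collective call, so creating a group inside every forward pass requires all ranks to reach it in lockstep and can hang training, which is the symptom reported in #105. The change falls back to the already-existing default world group instead. A minimal, hypothetical sketch of the pattern applied in each file below (illustrative only, not the apex source):

    import torch.distributed as dist

    def resolve_process_group(process_group=None):
        # Reuse the existing default (world) group instead of calling
        # dist.new_group() on every iteration, which is a collective and can
        # deadlock if ranks do not all reach it in the same order.
        if process_group is None:
            process_group = dist.group.WORLD
        world_size = dist.get_world_size(process_group)
        return process_group, world_size
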
 import torch
-# Backward compatibility hack around
-# https://github.com/pytorch/pytorch/pull/14767
-if hasattr(torch.distributed, 'get_default_group'):
-    group_creator = torch.distributed.get_default_group
-elif hasattr(torch.distributed, 'new_group'):
-    group_creator = torch.distributed.new_group
-else:
-    group_creator = torch.distributed.deprecated.new_group
 if hasattr(torch.distributed, 'ReduceOp'):
     ReduceOp = torch.distributed.ReduceOp
 elif hasattr(torch.distributed, 'reduce_op'):
......
@@ -2,7 +2,7 @@ import torch
 from torch.autograd.function import Function
 import syncbn

-from apex.parallel import group_creator, ReduceOp
+from apex.parallel import ReduceOp


 class SyncBatchnormFunction(Function):
@@ -16,11 +16,9 @@ class SyncBatchnormFunction(Function):
         mean, var, var_biased = syncbn.welford_mean_var(input)

         if torch.distributed.is_initialized():
-            if process_group:
-                world_size = torch.distributed.get_world_size(process_group)
-            else:
-                process_group = group_creator()
-                world_size = torch.distributed.get_world_size()
+            if not process_group:
+                process_group = torch.distributed.group.WORLD
+            world_size = torch.distributed.get_world_size(process_group)
             mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
             var_all = torch.empty(world_size, var.size(0), dtype=var.dtype, device=var.device)
             mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
......
@@ -3,7 +3,7 @@ from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn import functional as F

 from .sync_batchnorm_kernel import SyncBatchnormFunction
-from apex.parallel import group_creator, ReduceOp
+from apex.parallel import ReduceOp


 class SyncBatchNorm(_BatchNorm):
@@ -63,11 +63,9 @@ class SyncBatchNorm(_BatchNorm):
         else:
             process_group = self.process_group
             world_size = 0
-            if self.process_group:
-                world_size = torch.distributed.get_world_size(process_group)
-            else:
-                process_group = group_creator()
-                world_size = torch.distributed.get_world_size()
+            if not self.process_group:
+                process_group = torch.distributed.group.WORLD
+            world_size = torch.distributed.get_world_size(process_group)
             self.num_batches_tracked += 1
             with torch.no_grad():
                 channel_first_input = input.transpose(0, 1).contiguous()
......
 import torch
 from torch.autograd.function import Function

-from apex.parallel import group_creator, ReduceOp
+from apex.parallel import ReduceOp


 class SyncBatchnormFunction(Function):
......
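
Usage is unchanged for callers: when no process group is passed, statistics are now synchronized over the default world group. A short sketch, assuming torch.distributed has already been initialized (e.g. with the nccl backend, one GPU per rank) and the optional process_group argument of apex.parallel.SyncBatchNorm:

    import torch
    import torch.distributed as dist
    from apex.parallel import SyncBatchNorm

    # Default: batch-norm statistics are reduced across all ranks (the world group).
    bn = SyncBatchNorm(64).cuda()
    out = bn(torch.randn(8, 64, 32, 32, device='cuda'))

    # An explicit group still works and is created once up front rather than
    # per iteration, e.g. to synchronize only within a subset of ranks:
    # subgroup = dist.new_group(ranks=[0, 1])
    # bn = SyncBatchNorm(64, process_group=subgroup).cuda()
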