Unverified Commit 8421cfb4 authored by mcarilli, committed by GitHub

Merge pull request #113 from NVIDIA/sbn_issue

[syncBN]
parents 241dd6c4 fa719e8b
import torch

# Backward compatibility hack around
# https://github.com/pytorch/pytorch/pull/14767
if hasattr(torch.distributed, 'get_default_group'):
    group_creator = torch.distributed.get_default_group
elif hasattr(torch.distributed, 'new_group'):
    group_creator = torch.distributed.new_group
else:
    group_creator = torch.distributed.deprecated.new_group

if hasattr(torch.distributed, 'ReduceOp'):
    ReduceOp = torch.distributed.ReduceOp
elif hasattr(torch.distributed, 'reduce_op'):
    ReduceOp = torch.distributed.reduce_op
else:
    ReduceOp = torch.distributed.deprecated.reduce_op
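The shim above resolves whichever distributed API the installed PyTorch exposes. A minimal, purely illustrative sketch of how the resulting aliases are typically consumed (assumes a process group has been initialized; the helper name is made up for this example):

import torch
from apex.parallel import ReduceOp, group_creator

def allreduce_mean(tensor):
    # Hypothetical helper: sum a tensor across every rank of the default
    # process group, then divide by the world size to get the mean.
    group = group_creator()
    torch.distributed.all_reduce(tensor, op=ReduceOp.SUM, group=group)
    tensor /= torch.distributed.get_world_size(group)
    return tensor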
@@ -2,7 +2,7 @@ import torch
 from torch.autograd.function import Function
 import syncbn
-from apex.parallel import group_creator, ReduceOp
+from apex.parallel import ReduceOp

 class SyncBatchnormFunction(Function):
@@ -16,11 +16,9 @@ class SyncBatchnormFunction(Function):
         mean, var, var_biased = syncbn.welford_mean_var(input)

         if torch.distributed.is_initialized():
-            if process_group:
-                world_size = torch.distributed.get_world_size(process_group)
-            else:
-                process_group = group_creator()
-                world_size = torch.distributed.get_world_size()
+            if not process_group:
+                process_group = torch.distributed.group.WORLD
+            world_size = torch.distributed.get_world_size(process_group)
             mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
             var_all = torch.empty(world_size, var.size(0), dtype=var.dtype, device=var.device)
             mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
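For context, the per-rank mean and biased variance gathered into mean_all / var_all still have to be folded into global statistics. A rough sketch of that combination, assuming equal element counts on every rank (function name and shapes are illustrative; apex's actual reduction happens in the syncbn CUDA extension):

import torch

def combine_welford_stats(mean_all, var_biased_all):
    # mean_all, var_biased_all: [world_size, C] tensors filled by all_gather.
    # With equal counts, the global mean is the average of per-rank means.
    mean = mean_all.mean(dim=0)
    # Recover E[x^2] per rank as var_biased + mean^2, average it, then
    # subtract the squared global mean to get the global biased variance.
    second_moment = (var_biased_all + mean_all * mean_all).mean(dim=0)
    var_biased = second_moment - mean * mean
    return mean, var_biased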
@@ -3,7 +3,7 @@ from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn import functional as F
 from .sync_batchnorm_kernel import SyncBatchnormFunction
-from apex.parallel import group_creator, ReduceOp
+from apex.parallel import ReduceOp

 class SyncBatchNorm(_BatchNorm):
@@ -63,11 +63,9 @@ class SyncBatchNorm(_BatchNorm):
         else:
             process_group = self.process_group
             world_size = 0
-            if self.process_group:
-                world_size = torch.distributed.get_world_size(process_group)
-            else:
-                process_group = group_creator()
-                world_size = torch.distributed.get_world_size()
+            if not self.process_group:
+                process_group = torch.distributed.group.WORLD
+            world_size = torch.distributed.get_world_size(process_group)
             self.num_batches_tracked += 1
             with torch.no_grad():
                 channel_first_input = input.transpose(0, 1).contiguous()
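With this change a caller no longer needs the group_creator shim: leaving process_group unset makes the layer synchronize over torch.distributed.group.WORLD, while an explicit group restricts synchronization to its ranks. A hypothetical usage sketch (assumes torch.distributed has already been initialized and the usual apex.parallel.SyncBatchNorm constructor signature):

import torch
import apex.parallel

# Default: statistics are reduced across every rank (group.WORLD).
bn_world = apex.parallel.SyncBatchNorm(64).cuda()

# Optional: restrict the reduction to a subset of ranks.
pair_group = torch.distributed.new_group(ranks=[0, 1])
bn_pair = apex.parallel.SyncBatchNorm(64, process_group=pair_group).cuda()

out = bn_world(torch.randn(8, 64, 32, 32, device='cuda'))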
 import torch
 from torch.autograd.function import Function
-from apex.parallel import group_creator, ReduceOp
+from apex.parallel import ReduceOp

 class SyncBatchnormFunction(Function):