Commit a99e1875 authored by Michael Carilli

Fix deprecation warnings for ReduceOp

parent 35891b28
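
Every hunk below relies on the same compatibility shim added to apex.parallel: alias ReduceOp to torch.distributed.ReduceOp when it exists, and fall back to the deprecated torch.distributed.reduce_op on older PyTorch, so call sites can always pass ReduceOp.SUM to all_reduce. A minimal standalone sketch of that shim:

```python
import torch
import torch.distributed as dist

# Prefer the new ReduceOp enum; older PyTorch only provides the
# deprecated reduce_op object, which triggers a deprecation warning.
if hasattr(dist, 'ReduceOp'):
    ReduceOp = dist.ReduceOp
else:
    ReduceOp = dist.reduce_op

# Either way, ReduceOp.SUM is a valid `op` argument for dist.all_reduce.
```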
@@ -7,6 +7,11 @@ if hasattr(torch.distributed, 'get_default_group'):
 else:
     group_creator = torch.distributed.new_group
+if hasattr(torch.distributed, 'ReduceOp'):
+    ReduceOp = torch.distributed.ReduceOp
+else:
+    ReduceOp = torch.distributed.reduce_op
 from .distributed import DistributedDataParallel, Reducer
 try:
     import syncbn
...
@@ -2,6 +2,7 @@ import torch
 from torch.autograd.function import Function
 import syncbn
+from apex.parallel import group_creator, ReduceOp
 class SyncBatchnormFunction(Function):
@@ -68,10 +69,10 @@ class SyncBatchnormFunction(Function):
             if torch.distributed.is_initialized():
                 torch.distributed.all_reduce(
-                    mean_dy, torch.distributed.reduce_op.SUM, process_group)
+                    mean_dy, ReduceOp.SUM, process_group)
                 mean_dy = mean_dy / world_size
                 torch.distributed.all_reduce(
-                    mean_dy_xmu, torch.distributed.reduce_op.SUM, process_group)
+                    mean_dy_xmu, ReduceOp.SUM, process_group)
                 mean_dy_xmu = mean_dy_xmu / world_size
             grad_input = syncbn.batchnorm_backward(grad_output, saved_input, running_mean, running_variance, weight, mean_dy, mean_dy_xmu, eps)
...
@@ -3,6 +3,7 @@ from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn import functional as F
 from .sync_batchnorm_kernel import SyncBatchnormFunction
+from apex.parallel import group_creator, ReduceOp
 class SyncBatchNorm(_BatchNorm):
@@ -80,10 +81,10 @@ class SyncBatchNorm(_BatchNorm):
                 squashed_input_tensor_view, 2).mean(1)
             if torch.distributed.is_initialized():
                 torch.distributed.all_reduce(
-                    local_mean, torch.distributed.reduce_op.SUM, process_group)
+                    local_mean, ReduceOp.SUM, process_group)
                 mean = local_mean / world_size
                 torch.distributed.all_reduce(
-                    local_sqr_mean, torch.distributed.reduce_op.SUM, process_group)
+                    local_sqr_mean, ReduceOp.SUM, process_group)
                 sqr_mean = local_sqr_mean / world_size
                 m = local_m * world_size
             else:
...
 import torch
 from torch.autograd.function import Function
+from apex.parallel import group_creator, ReduceOp
 class SyncBatchnormFunction(Function):
@@ -57,10 +59,10 @@ class SyncBatchnormFunction(Function):
             running_mean)).view(-1, num_features).mean(0)
         if torch.distributed.is_initialized():
             torch.distributed.all_reduce(
-                mean_dy, torch.distributed.reduce_op.SUM, process_group)
+                mean_dy, ReduceOp.SUM, process_group)
             mean_dy = mean_dy / world_size
             torch.distributed.all_reduce(
-                mean_dy_xmu, torch.distributed.reduce_op.SUM, process_group)
+                mean_dy_xmu, ReduceOp.SUM, process_group)
             mean_dy_xmu = mean_dy_xmu / world_size
         c_last_grad_input = (c_last_grad - mean_dy - (c_last_input - running_mean) / (
             running_variance + eps) * mean_dy_xmu) / torch.sqrt(running_variance + eps)
...
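
All of the updated call sites follow the same reduce-then-average pattern: all-reduce a local statistic with SUM, then divide by the world size to obtain the global mean. A self-contained sketch of that pattern, assuming the ReduceOp alias shown earlier; the allreduce_mean helper name is illustrative and not part of this commit:

```python
import torch
import torch.distributed as dist

# Same fallback alias as apex.parallel (see the sketch above).
ReduceOp = dist.ReduceOp if hasattr(dist, 'ReduceOp') else dist.reduce_op

def allreduce_mean(tensor, process_group=None):
    # Sum the local statistic across ranks, then divide by the world size to
    # get the global mean -- the pattern applied to mean_dy, mean_dy_xmu,
    # local_mean and local_sqr_mean in the hunks above.
    if dist.is_initialized():
        dist.all_reduce(tensor, ReduceOp.SUM, process_group)
        tensor = tensor / dist.get_world_size(process_group)
    return tensor
```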