Commit a99e1875 authored by Michael Carilli

Fix deprecation warnings for ReduceOp

parent 35891b28
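The deprecation warning comes from the legacy torch.distributed.reduce_op enum; newer PyTorch builds expose the same values as torch.distributed.ReduceOp and warn when the old name is referenced. A minimal sketch of the guarded alias this commit adds, plus a hypothetical helper showing how call sites consume it (the helper name and arguments are illustrative, not part of the commit):

import torch
import torch.distributed

# Prefer the modern enum; fall back to the deprecated alias on older PyTorch.
if hasattr(torch.distributed, 'ReduceOp'):
    ReduceOp = torch.distributed.ReduceOp
else:
    ReduceOp = torch.distributed.reduce_op

def sum_across_ranks(tensor, process_group=None):
    # Illustrative helper: in-place SUM all_reduce over the given group.
    torch.distributed.all_reduce(tensor, ReduceOp.SUM, process_group)
    return tensor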
@@ -7,6 +7,11 @@ if hasattr(torch.distributed, 'get_default_group'):
 else:
     group_creator = torch.distributed.new_group
 
+if hasattr(torch.distributed, 'ReduceOp'):
+    ReduceOp = torch.distributed.ReduceOp
+else:
+    ReduceOp = torch.distributed.reduce_op
+
 from .distributed import DistributedDataParallel, Reducer
 try:
     import syncbn
...
@@ -2,6 +2,7 @@ import torch
 from torch.autograd.function import Function
 import syncbn
+from apex.parallel import group_creator, ReduceOp
 
 class SyncBatchnormFunction(Function):
...
@@ -68,10 +69,10 @@ class SyncBatchnormFunction(Function):
         if torch.distributed.is_initialized():
             torch.distributed.all_reduce(
-                mean_dy, torch.distributed.reduce_op.SUM, process_group)
+                mean_dy, ReduceOp.SUM, process_group)
             mean_dy = mean_dy / world_size
             torch.distributed.all_reduce(
-                mean_dy_xmu, torch.distributed.reduce_op.SUM, process_group)
+                mean_dy_xmu, ReduceOp.SUM, process_group)
             mean_dy_xmu = mean_dy_xmu / world_size
         grad_input = syncbn.batchnorm_backward(grad_output, saved_input, running_mean, running_variance, weight, mean_dy, mean_dy_xmu, eps)
...
@@ -3,6 +3,7 @@ from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn import functional as F
 from .sync_batchnorm_kernel import SyncBatchnormFunction
+from apex.parallel import group_creator, ReduceOp
 
 class SyncBatchNorm(_BatchNorm):
...
@@ -80,10 +81,10 @@ class SyncBatchNorm(_BatchNorm):
                     squashed_input_tensor_view, 2).mean(1)
                 if torch.distributed.is_initialized():
                     torch.distributed.all_reduce(
-                        local_mean, torch.distributed.reduce_op.SUM, process_group)
+                        local_mean, ReduceOp.SUM, process_group)
                     mean = local_mean / world_size
                     torch.distributed.all_reduce(
-                        local_sqr_mean, torch.distributed.reduce_op.SUM, process_group)
+                        local_sqr_mean, ReduceOp.SUM, process_group)
                     sqr_mean = local_sqr_mean / world_size
                     m = local_m * world_size
                 else:
...
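For context on the hunk above: each rank computes a per-feature local mean and local mean of squares, the SUM all_reduce followed by division by world_size turns those into global means, and the variance then follows as E[x^2] - E[x]^2. A small single-process sketch of that arithmetic, simulating two ranks with made-up values (it assumes equal per-rank batch sizes, which is also what m = local_m * world_size assumes):

import torch

x0 = torch.tensor([1., 2., 3.])   # "rank 0" slice of the global batch
x1 = torch.tensor([4., 5., 6.])   # "rank 1" slice
world_size = 2

local_mean = torch.stack([x0.mean(), x1.mean()])
local_sqr_mean = torch.stack([x0.pow(2).mean(), x1.pow(2).mean()])

# all_reduce(SUM) followed by division by world_size is a mean over ranks.
mean = local_mean.sum() / world_size          # E[x] over the global batch
sqr_mean = local_sqr_mean.sum() / world_size  # E[x^2] over the global batch
var = sqr_mean - mean.pow(2)                  # biased variance, as batchnorm uses

assert torch.allclose(mean, torch.cat([x0, x1]).mean())
assert torch.allclose(var, torch.cat([x0, x1]).var(unbiased=False))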
 import torch
 from torch.autograd.function import Function
+from apex.parallel import group_creator, ReduceOp
 
 class SyncBatchnormFunction(Function):
...
@@ -57,10 +59,10 @@ class SyncBatchnormFunction(Function):
             running_mean)).view(-1, num_features).mean(0)
         if torch.distributed.is_initialized():
             torch.distributed.all_reduce(
-                mean_dy, torch.distributed.reduce_op.SUM, process_group)
+                mean_dy, ReduceOp.SUM, process_group)
             mean_dy = mean_dy / world_size
             torch.distributed.all_reduce(
-                mean_dy_xmu, torch.distributed.reduce_op.SUM, process_group)
+                mean_dy_xmu, ReduceOp.SUM, process_group)
             mean_dy_xmu = mean_dy_xmu / world_size
         c_last_grad_input = (c_last_grad - mean_dy - (c_last_input - running_mean) / (
             running_variance + eps) * mean_dy_xmu) / torch.sqrt(running_variance + eps)
...
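All of the backward hunks implement the standard batchnorm input gradient, grad_input = (grad_out - mean_dy - (x - mean) / (var + eps) * mean_dy_xmu) / sqrt(var + eps), where mean_dy and mean_dy_xmu are the all-reduced means of grad_out and of grad_out * (x - mean). A quick single-feature sanity check of that closed form against autograd (a sketch with no process group involved; variable names are illustrative):

import torch

eps = 1e-5
x = torch.randn(8, requires_grad=True)
mean = x.mean()
var = x.var(unbiased=False)
y = (x - mean) / torch.sqrt(var + eps)   # training-mode normalization, no affine
grad_output = torch.randn(8)
y.backward(grad_output)

xmu = x.detach() - mean.detach()
mean_dy = grad_output.mean()              # the kernels all-reduce this mean of dy
mean_dy_xmu = (grad_output * xmu).mean()  # and this mean of dy * (x - mean)
grad_input = (grad_output - mean_dy
              - xmu / (var.detach() + eps) * mean_dy_xmu) / torch.sqrt(var.detach() + eps)

assert torch.allclose(x.grad, grad_input, atol=1e-4)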