Commit 35891b28 authored by Michael Carilli's avatar Michael Carilli
Browse files

Compatibility with new_group() API

parent 48343d94
import torch import torch
# Backward compatibility hack around
# https://github.com/pytorch/pytorch/pull/14767
if hasattr(torch.distributed, 'get_default_group'):
group_creator = torch.distributed.get_default_group
else:
group_creator = torch.distributed.new_group
from .distributed import DistributedDataParallel, Reducer from .distributed import DistributedDataParallel, Reducer
try: try:
import syncbn import syncbn
......
...@@ -18,7 +18,7 @@ class SyncBatchnormFunction(Function): ...@@ -18,7 +18,7 @@ class SyncBatchnormFunction(Function):
if process_group: if process_group:
world_size = torch.distributed.get_world_size(process_group) world_size = torch.distributed.get_world_size(process_group)
else: else:
process_group = torch.distributed.get_default_group() process_group = group_creator()
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device) mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
var_all = torch.empty(world_size, var.size(0), dtype=var.dtype, device=var.device) var_all = torch.empty(world_size, var.size(0), dtype=var.dtype, device=var.device)
......
...@@ -65,7 +65,7 @@ class SyncBatchNorm(_BatchNorm): ...@@ -65,7 +65,7 @@ class SyncBatchNorm(_BatchNorm):
if self.process_group: if self.process_group:
world_size = torch.distributed.get_world_size(process_group) world_size = torch.distributed.get_world_size(process_group)
else: else:
process_group = torch.distributed.get_default_group() process_group = group_creator()
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
self.num_batches_tracked += 1 self.num_batches_tracked += 1
with torch.no_grad(): with torch.no_grad():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment