Unverified Commit ae1cdd64 authored by Thor Johnsen's avatar Thor Johnsen Committed by GitHub
Browse files

Merge pull request #1161 from NVIDIA/optional_caller_supplied_communicator

Optional NCCL communicator argument to init method
parents 9b880665 e777bddb
...@@ -393,7 +393,7 @@ class SpatialBottleneck(torch.nn.Module): ...@@ -393,7 +393,7 @@ class SpatialBottleneck(torch.nn.Module):
def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1, def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1,
dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False, dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False,
spatial_group_size=1): spatial_group_size=1, communicator=None):
super(SpatialBottleneck, self).__init__() super(SpatialBottleneck, self).__init__()
if groups != 1: if groups != 1:
raise RuntimeError('Only support groups == 1') raise RuntimeError('Only support groups == 1')
...@@ -454,11 +454,14 @@ class SpatialBottleneck(torch.nn.Module): ...@@ -454,11 +454,14 @@ class SpatialBottleneck(torch.nn.Module):
assert(num_groups*spatial_group_size == world_size), "torch.distributed.get_world_size() must be multiple of group_size" assert(num_groups*spatial_group_size == world_size), "torch.distributed.get_world_size() must be multiple of group_size"
rank = dist.get_rank() rank = dist.get_rank()
self.local_rank = rank % spatial_group_size self.local_rank = rank % spatial_group_size
for group in range(num_groups): if communicator is None:
ranks = list(range(group*spatial_group_size,(group+1)*spatial_group_size)) for group in range(num_groups):
comm = torch.distributed.new_group(ranks=ranks) ranks = list(range(group*spatial_group_size,(group+1)*spatial_group_size))
if rank in ranks: comm = torch.distributed.new_group(ranks=ranks)
self.communicator = comm if rank in ranks:
self.communicator = comm
else:
self.communicator = communicator
self.stream1 = torch.cuda.Stream() self.stream1 = torch.cuda.Stream()
self.spatial_args = self.spatial_group_size, self.local_rank, self.communicator, self.stream1 self.spatial_args = self.spatial_group_size, self.local_rank, self.communicator, self.stream1
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment