Commit 3a08d827 authored by Christian Sarofeen

Distributed refactor.

parent dd05dbdb
@@ -4,6 +4,7 @@ import torch.distributed as dist
from torch.nn.modules import Module
from torch.autograd import Variable
def flat_dist_call(tensors, call, extra_args=None):
    flat_dist_call.warn_on_half = True
    buckets = {}
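
The hunk above only shows the start of flat_dist_call. For readability, here is a minimal, self-contained sketch of the type-bucketing pattern a helper like this typically implements; the flatten/unflatten utilities from torch._utils and the world-size averaging step are assumptions, not necessarily what this file does.

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def flat_dist_call_sketch(tensors, call, extra_args=None):
    # Group tensors by type so each flattened buffer is homogeneous.
    buckets = {}
    for tensor in tensors:
        buckets.setdefault(tensor.type(), []).append(tensor)

    for bucket in buckets.values():
        # Flatten the bucket into one contiguous buffer, run the collective
        # once, then scatter the result back into the original tensors.
        coalesced = _flatten_dense_tensors(bucket)
        if extra_args is not None:
            call(coalesced, *extra_args)
        else:
            call(coalesced)
        if call is dist.all_reduce:
            coalesced /= dist.get_world_size()  # averaging is an assumption
        for buf, synced in zip(bucket, _unflatten_dense_tensors(coalesced, bucket)):
            buf.copy_(synced)
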
@@ -41,17 +42,17 @@ class DistributedDataParallel(Module):
    and that torch.set_device has been used to set the device. Parameters are broadcast
    to the other processes on initialization of DistributedDataParallel, and will be
    allreduced in buckets during the backward pass.
    See https://github.com/csarofeen/examples/tree/apex/distributed for detailed usage.
    Args:
        module: Network definition to be run in multi-gpu/distributed mode.
        message_size (Default = 10000000): Minimum number of elements in a communication bucket.
        message_size (Default = 100e6): Minimum number of elements in a communication bucket.
    """
    def __init__(self, module, message_size=10000000):
    def __init__(self, module, message_size=100000000):
        super(DistributedDataParallel, self).__init__()
        self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
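
For context, a minimal usage sketch consistent with the docstring above; the NCCL backend, env:// rendezvous, local_rank value, and import path are illustrative assumptions, with one process launched per GPU.

import torch
import torch.distributed as dist
from torch import nn
from distributed import DistributedDataParallel  # import path is an assumption

dist.init_process_group(backend='nccl', init_method='env://')
local_rank = 0  # normally provided by the launcher; hardcoded for illustration
torch.cuda.set_device(local_rank)

model = nn.Linear(1024, 1024).cuda()
model = DistributedDataParallel(model, message_size=100000000)
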
@@ -93,19 +94,19 @@ class DistributedDataParallel(Module):
        if not self.needs_reduction:
            return
        self.needs_reduction = False
        ready = []
        for i in range(len(self.param_state)):
            if self.param_state[i] == 1:
                param = self.param_list[self.record[i]]
                if param.grad is not None:
                    ready.append(param.grad.data)
        if(len(ready)>0):
        grads = []
        for i in range(self.ready_end, len(self.param_state)):
            param = self.param_refs[self.record[i]]
            if param.grad is not None:
                grads.append(param.grad.data)
        grads = [param.grad.data for param in self.ready_params] + grads
        if(len(grads)>0):
            orig_stream = torch.cuda.current_stream()
            with torch.cuda.stream(self.reduction_stream):
                self.reduction_stream.wait_stream(orig_stream)
                flat_dist_call(ready, dist.all_reduce)
                flat_dist_call(grads, dist.all_reduce)
        torch.cuda.current_stream().wait_stream(self.reduction_stream)
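
The block above pushes the all-reduce onto self.reduction_stream and then makes the main stream wait for it. A standalone sketch of that wait_stream pattern, with illustrative names that are not part of the commit:

import torch

def run_on_side_stream(side_stream, work, *args):
    # Let the side stream wait for work already queued on the current stream,
    # run the given callable on the side stream, then make the current stream
    # wait before any later kernels consume the results.
    main_stream = torch.cuda.current_stream()
    with torch.cuda.stream(side_stream):
        side_stream.wait_stream(main_stream)
        work(*args)
    main_stream.wait_stream(side_stream)

# e.g. run_on_side_stream(torch.cuda.Stream(), flat_dist_call, grads, dist.all_reduce)
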
@@ -129,30 +130,25 @@ class DistributedDataParallel(Module):
    def comm_ready_buckets(self):
        ready = []
        counter = 0
        if self.param_state[self.ready_end] == 0:
            return
        while counter < len(self.param_state) and self.param_state[counter] == 2:
            counter += 1
        while counter < len(self.param_state) and self.param_state[counter] == 1:
            ready.append(counter)
            counter += 1
        while self.ready_end < len(self.param_state) and self.param_state[self.ready_end] == 1:
            self.ready_params.append(self.param_refs[self.record[self.ready_end]])
            self.ready_numel += self.ready_params[-1].numel()
            self.ready_end += 1
        if not ready:
            return
        grads = []
        for ind in ready:
            param_ind = self.record[ind]
            if self.param_list[param_ind].grad is not None:
                grads.append(self.param_list[param_ind].grad.data)
        if self.ready_numel < self.message_size:
            return
        grads = [param.grad.data for param in self.ready_params]
        bucket = []
        bucket_inds = []
        while grads:
            bucket.append(grads.pop(0))
            bucket_inds.append(ready.pop(0))
            cumm_size = 0
            for ten in bucket:
@@ -168,17 +164,11 @@ class DistributedDataParallel(Module):
            with torch.cuda.stream(self.reduction_stream):
                flat_dist_call(bucket, dist.all_reduce)
            for ind in bucket_inds:
                self.param_state[ind] = 2
            for i in range(self.ready_start, self.ready_start+len(bucket)):
                self.param_state[i] = 2
                self.ready_params.pop(0)
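
The refactor above accumulates ready parameters into self.ready_params and only launches an all-reduce once self.ready_numel reaches message_size. A standalone sketch of that size-thresholded bucketing idea, with hypothetical names:

class BucketedReducer(object):
    """Collect tensors until their total element count crosses a threshold."""
    def __init__(self, message_size, reduce_fn):
        self.message_size = message_size
        self.reduce_fn = reduce_fn  # e.g. lambda ts: flat_dist_call(ts, dist.all_reduce)
        self.pending = []
        self.pending_numel = 0

    def add(self, tensor):
        # Queue a tensor; fire the reduction as soon as the bucket is big enough.
        self.pending.append(tensor)
        self.pending_numel += tensor.numel()
        if self.pending_numel >= self.message_size:
            self.flush()

    def flush(self):
        # Reduce whatever is queued (also useful for the final partial bucket).
        if self.pending:
            self.reduce_fn(self.pending)
            self.pending = []
            self.pending_numel = 0
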
    def forward(self, *inputs, **kwargs):
        """
        Forward function for DDP.
        Args:
            inputs: inputs that match the module passed in for initialization.
            kwargs: kwargs that match the module passed in for initialization.
        """
        param_list = [param for param in list(self.module.parameters()) if param.requires_grad]
@@ -195,5 +185,10 @@ class DistributedDataParallel(Module):
        self.param_state = [0 for i in range(len(param_list))]
        self.param_refs = param_list
        self.needs_reduction = True
        self.ready_start = 0
        self.ready_end = 0
        self.ready_params = []
        self.ready_numel = 0
        return self.module(*inputs, **kwargs)
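
Since forward() resets needs_reduction and the ready_* bookkeeping every iteration, a typical training step only needs backward() to trigger the bucketed reductions. A sketch with assumed data, loss, and optimizer, reusing the wrapped model from the earlier usage sketch:

import torch
from torch import nn

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for step in range(10):
    inputs = torch.randn(32, 1024).cuda()
    target = torch.randn(32, 1024).cuda()
    optimizer.zero_grad()
    loss = criterion(model(inputs), target)  # forward() resets the ready_* state
    loss.backward()                          # backward hooks launch the bucketed all-reduces
    optimizer.step()                         # assumes gradients were averaged in flat_dist_call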