Commit 3a08d827 authored by Christian Sarofeen

Distributed refactor.

parent dd05dbdb
@@ -4,6 +4,7 @@ import torch.distributed as dist
 from torch.nn.modules import Module
 from torch.autograd import Variable
 def flat_dist_call(tensors, call, extra_args=None):
     flat_dist_call.warn_on_half = True
     buckets = {}
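
The body of flat_dist_call is not shown in this hunk, so the snippet below is only a sketch of the flatten-and-bucket pattern such a helper typically implements: group tensors by type, coalesce each group into one flat buffer, run a single collective per buffer, and copy the results back. The helper name flat_all_reduce and the direct use of torch._utils are illustrative assumptions, not code from this commit.

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def flat_all_reduce(tensors):
    # Bucket tensors by type so half and float gradients are coalesced separately.
    buckets = {}
    for t in tensors:
        buckets.setdefault(t.type(), []).append(t)
    for bucket in buckets.values():
        flat = _flatten_dense_tensors(bucket)    # one contiguous buffer per bucket
        dist.all_reduce(flat)                    # single collective call per bucket
        for t, reduced in zip(bucket, _unflatten_dense_tensors(flat, bucket)):
            t.copy_(reduced)                     # write the reduced values back in place
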
@@ -41,17 +42,17 @@ class DistributedDataParallel(Module):
     and that torch.set_device has been used to set the device. Parameters are broadcasted
     to the other processes on initialization of DistributedDataParallel, and will be
     allreduced in buckets during the backward pass.
     See https://github.com/csarofeen/examples/tree/apex/distributed for detailed usage.
     Args:
         module: Network definition to be run in multi-gpu/distributed mode.
-        message_size (Default = 10000000): Minimum number of elements in a communication bucket.
+        message_size (Default = 100e6): Minimum number of elements in a communication bucket.
     """
-    def __init__(self, module, message_size=10000000):
+    def __init__(self, module, message_size=100000000):
         super(DistributedDataParallel, self).__init__()
         self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
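
A usage sketch matching the docstring above (one process per GPU, device selected before wrapping the module). The import path, the gloo/env:// process-group setup, and the toy model are assumptions for illustration, not part of this commit.

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.autograd import Variable
from distributed import DistributedDataParallel  # assumed import path for the module in this diff

dist.init_process_group(backend='gloo', init_method='env://')  # assumed setup; one process per GPU
torch.cuda.set_device(0)                                       # pick this process's device first

model = nn.Linear(1024, 1024).cuda()
model = DistributedDataParallel(model, message_size=100000000)  # parameters broadcast here
opt = torch.optim.SGD(model.parameters(), lr=0.1)

out = model(Variable(torch.randn(32, 1024).cuda()))
out.sum().backward()   # gradients are bucketed and allreduced during the backward pass
opt.step()
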
@@ -93,19 +94,19 @@ class DistributedDataParallel(Module):
         if not self.needs_reduction:
             return
         self.needs_reduction = False
-        ready = []
-        for i in range(len(self.param_state)):
-            if self.param_state[i] == 1:
-                param = self.param_list[self.record[i]]
-                if param.grad is not None:
-                    ready.append(param.grad.data)
-        if(len(ready)>0):
+        grads = []
+        for i in range(self.ready_end, len(self.param_state)):
+            param = self.param_refs[self.record[i]]
+            if param.grad is not None:
+                grads.append(param.grad.data)
+        grads = [param.grad.data for param in self.ready_params] + grads
+        if(len(grads)>0):
             orig_stream = torch.cuda.current_stream()
             with torch.cuda.stream(self.reduction_stream):
                 self.reduction_stream.wait_stream(orig_stream)
-                flat_dist_call(ready, dist.all_reduce)
+                flat_dist_call(grads, dist.all_reduce)
             torch.cuda.current_stream().wait_stream(self.reduction_stream)
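
The refactored reduction above keeps the same stream handoff as before: work is queued on the separate reduction stream only after the compute stream's pending kernels, and the compute stream later waits on the reduction stream before touching the reduced gradients. A minimal standalone illustration of that wait_stream pattern, with a dummy op standing in for the all_reduce:

import torch

side_stream = torch.cuda.Stream()
x = torch.randn(1 << 20).cuda()

orig_stream = torch.cuda.current_stream()
with torch.cuda.stream(side_stream):
    side_stream.wait_stream(orig_stream)   # do not start until x has been produced
    y = x * 2                              # stands in for flat_dist_call(grads, dist.all_reduce)
torch.cuda.current_stream().wait_stream(side_stream)  # later work on the compute stream sees the finished y
print(y.sum())
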
@@ -129,30 +130,25 @@ class DistributedDataParallel(Module):
     def comm_ready_buckets(self):
-        ready = []
-        counter = 0
-        while counter < len(self.param_state) and self.param_state[counter] == 2:
-            counter += 1
-        while counter < len(self.param_state) and self.param_state[counter] == 1:
-            ready.append(counter)
-            counter += 1
-        if not ready:
-            return
-        grads = []
-        for ind in ready:
-            param_ind = self.record[ind]
-            if self.param_list[param_ind].grad is not None:
-                grads.append(self.param_list[param_ind].grad.data)
+        if self.param_state[self.ready_end] == 0:
+            return
+        while self.ready_end < len(self.param_state) and self.param_state[self.ready_end] == 1:
+            self.ready_params.append(self.param_refs[self.record[self.ready_end]])
+            self.ready_numel += self.ready_params[-1].numel()
+            self.ready_end += 1
+        if self.ready_numel < self.message_size:
+            return
+        grads = [param.grad.data for param in self.ready_params]
         bucket = []
         bucket_inds = []
         while grads:
             bucket.append(grads.pop(0))
-            bucket_inds.append(ready.pop(0))
             cumm_size = 0
             for ten in bucket:
@@ -168,17 +164,11 @@ class DistributedDataParallel(Module):
         with torch.cuda.stream(self.reduction_stream):
             flat_dist_call(bucket, dist.all_reduce)
-        for ind in bucket_inds:
-            self.param_state[ind] = 2
+        for i in range(self.ready_start, self.ready_start+len(bucket)):
+            self.param_state[i] = 2
+            self.ready_params.pop(0)
     def forward(self, *inputs, **kwargs):
-        """
-        Forward function for DDP.
-        Args:
-            inputs: inputs that match the module's passed in for initialization.
-            kwargs: kwargs that match the module's passed in for initialization.
-        """
         param_list = [param for param in list(self.module.parameters()) if param.requires_grad]
@@ -195,5 +185,10 @@ class DistributedDataParallel(Module):
         self.param_state = [0 for i in range(len(param_list))]
         self.param_refs = param_list
         self.needs_reduction = True
+        self.ready_start = 0
+        self.ready_end = 0
+        self.ready_params = []
+        self.ready_numel = 0
         return self.module(*inputs, **kwargs)
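
Not visible in this diff is what flips param_state to 1 and calls comm_ready_buckets. A common pattern, and purely an assumption here, is a per-parameter backward hook registered against the module's parameters; a sketch of that wiring follows (make_ready_hook and ddp are illustrative names, not code from this commit).

def make_ready_hook(ddp, i):
    def hook(grad):
        ddp.param_state[i] = 1       # gradient for parameter i is now available
        ddp.comm_ready_buckets()     # try to flush a bucket once message_size is reached
        return grad
    return hook

# Illustrative wiring only -- registration in the real module may differ:
# for i, param in enumerate(param_list):
#     param.register_hook(make_ready_hook(ddp, i))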