Commit 3a08d827 authored by Christian Sarofeen

Distributed refactor.

parent dd05dbdb
@@ -4,6 +4,7 @@ import torch.distributed as dist
 from torch.nn.modules import Module
 from torch.autograd import Variable
 def flat_dist_call(tensors, call, extra_args=None):
     flat_dist_call.warn_on_half = True
     buckets = {}
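For readers unfamiliar with flat_dist_call, it implements the usual flatten-reduce-unflatten pattern around a torch.distributed collective. The sketch below only illustrates that pattern under the assumption that the function buckets tensors by type; it is not the exact body from this file, and flat_dist_call_sketch is a made-up name.

import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def flat_dist_call_sketch(tensors, call, extra_args=None):
    # Group tensors by type so each coalesced buffer is homogeneous.
    buckets = {}
    for tensor in tensors:
        buckets.setdefault(tensor.type(), []).append(tensor)

    for bucket in buckets.values():
        # Flatten, run the collective once per bucket, then scatter results back.
        coalesced = _flatten_dense_tensors(bucket)
        if extra_args is not None:
            call(coalesced, *extra_args)
        else:
            call(coalesced)
        for buf, synced in zip(bucket, _unflatten_dense_tensors(coalesced, bucket)):
            buf.copy_(synced)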
@@ -46,12 +47,12 @@ class DistributedDataParallel(Module):
     Args:
         module: Network definition to be run in multi-gpu/distributed mode.
-        message_size (Default = 10000000): Minimum number of elements in a communication bucket.
+        message_size (Default = 100e6): Minimum number of elements in a communication bucket.
     """

-    def __init__(self, module, message_size=10000000):
+    def __init__(self, module, message_size=100000000):
         super(DistributedDataParallel, self).__init__()
         self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
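The docstring's new 100e6 matches the new code default of 100000000. Note that message_size counts gradient elements, not bytes, so the memory footprint of a bucket depends on the gradient dtype; a quick back-of-the-envelope check (illustrative values only):

message_size = 100000000        # new default: 100e6 elements per communication bucket
fp32_bytes = message_size * 4   # ~400 MB of float32 gradients before a bucket is flushed
fp16_bytes = message_size * 2   # ~200 MB if gradients are half precision
print(fp32_bytes, fp16_bytes)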
@@ -94,18 +95,18 @@ class DistributedDataParallel(Module):
                 return
             self.needs_reduction = False
-            ready = []
-            for i in range(len(self.param_state)):
-                if self.param_state[i] == 1:
-                    param = self.param_list[self.record[i]]
-                    if param.grad is not None:
-                        ready.append(param.grad.data)
-            if(len(ready)>0):
+            grads = []
+            for i in range(self.ready_end, len(self.param_state)):
+                param = self.param_refs[self.record[i]]
+                if param.grad is not None:
+                    grads.append(param.grad.data)
+            grads = [param.grad.data for param in self.ready_params] + grads
+            if(len(grads)>0):
                 orig_stream = torch.cuda.current_stream()
                 with torch.cuda.stream(self.reduction_stream):
                     self.reduction_stream.wait_stream(orig_stream)
-                    flat_dist_call(ready, dist.all_reduce)
+                    flat_dist_call(grads, dist.all_reduce)
             torch.cuda.current_stream().wait_stream(self.reduction_stream)
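This hunk keeps the same synchronization pattern as before: the leftover gradients are all-reduced on a dedicated reduction stream that first waits on the stream that produced them, and the default stream then waits on the reduction stream before continuing. A stand-alone sketch of that pattern, assuming a process group is already initialized (the gradient tensor here is just a placeholder):

import torch
import torch.distributed as dist

reduction_stream = torch.cuda.Stream()
grads = [torch.randn(1 << 20, device='cuda')]   # placeholder for ready gradient tensors

orig_stream = torch.cuda.current_stream()
with torch.cuda.stream(reduction_stream):
    # Don't start reducing until the producing stream has finished writing the grads.
    reduction_stream.wait_stream(orig_stream)
    flat_dist_call(grads, dist.all_reduce)

# Anything queued on the default stream after this point sees the reduced values.
torch.cuda.current_stream().wait_stream(reduction_stream)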
@@ -129,30 +130,25 @@ class DistributedDataParallel(Module):
     def comm_ready_buckets(self):
-        ready = []
-        counter = 0
-        while counter < len(self.param_state) and self.param_state[counter] == 2:
-            counter += 1
-        while counter < len(self.param_state) and self.param_state[counter] == 1:
-            ready.append(counter)
-            counter += 1
-        if not ready:
-            return
-        grads = []
-        for ind in ready:
-            param_ind = self.record[ind]
-            if self.param_list[param_ind].grad is not None:
-                grads.append(self.param_list[param_ind].grad.data)
+        if self.param_state[self.ready_end] == 0:
+            return
+        while self.ready_end < len(self.param_state) and self.param_state[self.ready_end] == 1:
+            self.ready_params.append(self.param_refs[self.record[self.ready_end]])
+            self.ready_numel += self.ready_params[-1].numel()
+            self.ready_end += 1
+        if self.ready_numel < self.message_size:
+            return
+        grads = [param.grad.data for param in self.ready_params]
         bucket = []
         bucket_inds = []
         while grads:
             bucket.append(grads.pop(0))
-            bucket_inds.append(ready.pop(0))
             cumm_size = 0
             for ten in bucket:
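The refactored comm_ready_buckets replaces the repeated scans over param_state with a running window: ready_end walks forward over parameters whose state is 1, their references are appended to ready_params, and ready_numel accumulates their element counts; communication is skipped until ready_numel reaches message_size. The helper below restates just that gating step outside the class to make the control flow easier to follow; it is an illustration, not code from this commit.

def accumulate_ready_window(param_state, record, param_refs,
                            ready_params, ready_end, ready_numel, message_size):
    # Nothing to do until the next parameter in bucket order has produced its gradient.
    if param_state[ready_end] == 0:
        return ready_end, ready_numel, False

    # Extend the ready window over every consecutive parameter marked ready (state 1).
    while ready_end < len(param_state) and param_state[ready_end] == 1:
        ready_params.append(param_refs[record[ready_end]])
        ready_numel += ready_params[-1].numel()
        ready_end += 1

    # Only signal a flush once the window holds at least message_size elements.
    return ready_end, ready_numel, ready_numel >= message_size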
@@ -168,17 +164,11 @@ class DistributedDataParallel(Module):
             with torch.cuda.stream(self.reduction_stream):
                 flat_dist_call(bucket, dist.all_reduce)
-        for ind in bucket_inds:
-            self.param_state[ind] = 2
+        for i in range(self.ready_start, self.ready_start+len(bucket)):
+            self.param_state[i] = 2
+            self.ready_params.pop(0)
     def forward(self, *inputs, **kwargs):
-        """
-        Forward function for DDP.
-        Args:
-        inputs: inputs that match the module's passed in for initialization.
-        kwargs: kwargs that match the module's passed in for initialization.
-        """
         param_list = [param for param in list(self.module.parameters()) if param.requires_grad]
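Taken together with the previous hunks, param_state appears to track each parameter through three stages; the constants below are just my reading of the diff, not names defined anywhere in the source.

# Apparent gradient states per parameter, inferred from this diff:
PARAM_PENDING = 0   # backward has not produced this gradient yet
PARAM_READY   = 1   # gradient exists and sits in the ready window
PARAM_REDUCED = 2   # gradient has been handed to an all_reduce bucket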
@@ -196,4 +186,9 @@ class DistributedDataParallel(Module):
         self.param_refs = param_list
         self.needs_reduction = True
+        self.ready_start = 0
+        self.ready_end = 0
+        self.ready_params = []
+        self.ready_numel = 0
         return self.module(*inputs, **kwargs)
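For completeness, a hedged end-to-end sketch of how the refactored wrapper is used in a training step. The import path, model, and data are illustrative assumptions; only the DistributedDataParallel constructor and its message_size argument come from this file, and the process-group setup must match your launcher.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from apex.parallel import DistributedDataParallel  # assumed import path; adjust to this repo's layout

dist.init_process_group(backend='nccl', init_method='env://')  # requires RANK/WORLD_SIZE env vars

net = nn.Linear(1024, 10).cuda()
model = DistributedDataParallel(net, message_size=100000000)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for step in range(10):
    data = torch.randn(32, 1024, device='cuda')
    target = torch.randint(0, 10, (32,), device='cuda')
    optimizer.zero_grad()
    loss = F.cross_entropy(model(data), target)
    loss.backward()        # gradients are all-reduced in buckets as they become ready
    optimizer.step()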