Commit 7c2ae41e authored by Christian Sarofeen

Fix race condition in DDP.

parent 531cb5c3
@@ -48,14 +48,15 @@ class DistributedDataParallel(Module):
     Args:
         module: Network definition to be run in multi-gpu/distributed mode.
         message_size (Default = 100e6): Minimum number of elements in a communication bucket.
+        shared_param (Default = False): Set this to True if your model uses shared parameters;
+            it disables bucketing of parameters, which is necessary to avoid race conditions.
     """
-    def __init__(self, module, message_size=100000000):
+    def __init__(self, module, message_size=100000000, shared_param=False):
         super(DistributedDataParallel, self).__init__()
         self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
+        self.shared_param = shared_param
         self.message_size = message_size
 
         #reference to last iterations parameters to see if anything has changed
@@ -119,8 +120,7 @@ class DistributedDataParallel(Module):
                         Variable._execution_engine.queue_callback(allreduce_params)
                     else:
                         Variable._execution_engine.queue_callback(flush_buckets)
-                        self.param_state[self.record.index(param_i)] = 1
-                        self.comm_ready_buckets()
+                        self.comm_ready_buckets(self.record.index(param_i))
 
                 if param.requires_grad:
@@ -128,9 +128,14 @@ class DistributedDataParallel(Module):
             wrapper(param_i)
 
-    def comm_ready_buckets(self):
+    def comm_ready_buckets(self, param_ind):
+        if self.param_state[param_ind] != 0:
+            raise RuntimeError("Error: Your model uses shared parameters, DDP flag shared_param must be set to True in initialization.")
+
         if self.param_state[self.ready_end] == 0:
+            self.param_state[param_ind] = 1
             return
@@ -141,6 +146,7 @@ class DistributedDataParallel(Module):
         if self.ready_numel < self.message_size:
+            self.param_state[param_ind] = 1
             return
 
         grads = [param.grad.data for param in self.ready_params]
@@ -167,13 +173,20 @@ class DistributedDataParallel(Module):
             for i in range(self.ready_start, self.ready_start+len(bucket)):
                 self.param_state[i] = 2
                 self.ready_params.pop(0)
 
+        self.param_state[param_ind] = 1
+
     def forward(self, *inputs, **kwargs):
         param_list = [param for param in list(self.module.parameters()) if param.requires_grad]
+
+        # Force needs_refresh to True if there are shared params; this makes the
+        # hooks always and only call flush_buckets, which is safe for shared
+        # parameters in the model.
+        if self.shared_param:
+            self.param_refs = []
 
         self.needs_refresh = True if not self.param_refs else any(
             [param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)]
         )
...
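For reference, the new shared_param flag is what a user sets when the wrapped module reuses a Parameter in more than one place (for example, tied weights). Below is a minimal usage sketch, not part of the commit: the TiedModel class is purely illustrative, and it assumes the process group has already been initialized as in the test script that follows.

import torch
from torch.nn import Module, Parameter
from apex.parallel import DistributedDataParallel as DDP

class TiedModel(Module):
    def __init__(self):
        super(TiedModel, self).__init__()
        # A single parameter that the forward pass uses twice, i.e. it is shared.
        self.w = Parameter(torch.cuda.FloatTensor(1, 1024).fill_(1.0))

    def forward(self, input):
        return self.w * input + self.w * input

# Per the docstring added above, models with shared parameters must pass
# shared_param=True so that DDP disables gradient bucketing for them.
model = DDP(TiedModel(), message_size=1, shared_param=True)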
import torch
import torch.distributed as dist
from torch.nn import Parameter
from torch.nn import Module
from apex.parallel import DistributedDataParallel as DDP
import argparse

parser = argparse.ArgumentParser(description='allreduce hook example')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--world-size', default=1, type=int,
                    help='Number of GPUs to use. Can either be manually set ' +
                         'or automatically set by using \'python -m multiproc\'.')
parser.add_argument('--rank', default=0, type=int,
                    help='Used for multi-process training. Can either be manually set ' +
                         'or automatically set by using \'python -m multiproc\'.')

args = parser.parse_args()
args.distributed = args.world_size > 1

if args.distributed:
    torch.cuda.set_device(args.rank % torch.cuda.device_count())
    dist.init_process_group(args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size)

rank = torch.distributed.get_rank()

torch.set_printoptions(precision=10)

class Model(Module):
    def __init__(self):
        super(Model, self).__init__()
        self.x = Parameter(torch.cuda.FloatTensor(1, 4096*4096).fill_(1.0))

    def forward(self, input):
        return self.x * input

model = DDP(Model(), message_size=1)

z = torch.cuda.FloatTensor(4096*4096)

for i in range(10):
    z.fill_(i + rank)  # fill z with new values every iteration for sanity
    model.zero_grad()
    out = model(z)
    loss = out.sum()
    torch.cuda.nvtx.range_push("backward")
    loss.backward()
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_push("synchronize() + sum")
    torch.cuda.synchronize()
    for param in model.parameters():
        print("i = {},\n"
              "param.grad.data_ptr() = {}\n"
              "expected {},\n"
              "     got {}\n"
              .format(i,
                      param.grad.data_ptr(),
                      4096*4096*(2.*i+1)/2.,
                      param.grad.data.sum().item()))
    torch.cuda.nvtx.range_pop()
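A note on the expected value printed above: out = self.x * z and loss = out.sum(), so the gradient of loss with respect to x is just z. Each rank fills z with i + rank, and the expected formula 4096*4096*(2.*i+1)/2. implies the allreduced gradient is averaged over the two ranks that the launch script below starts on GPUs 0 and 1: ((i + 0) + (i + 1)) / 2 = (2*i + 1) / 2 per element, summed over 4096*4096 elements. A small sketch of that arithmetic, assuming world_size == 2 (the helper name expected_grad_sum is illustrative only):

def expected_grad_sum(i, n_elem=4096 * 4096, world_size=2):
    # d(loss)/d(x_j) = z_j on each rank, and z is filled with (i + rank);
    # averaging the allreduced gradients over ranks gives (2*i + 1) / 2 per element.
    per_element = sum(i + rank for rank in range(world_size)) / float(world_size)
    return n_elem * per_element

assert expected_grad_sum(3) == 4096 * 4096 * (2. * 3 + 1) / 2.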
#!/bin/bash
CUDA_VISIBLE_DEVICES=0,1 python -m apex.parallel.multiproc ddp_race_condition.py