".github/vscode:/vscode.git/clone" did not exist on "4e74ec09a8a8ba55091fcc8c10ebcdbc37497d31"
Commit 7c2ae41e authored by Christian Sarofeen

Fix race condition in DDP.

parent 531cb5c3
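What the race was: each parameter's backward hook used to first mark its own slot in param_state and then call comm_ready_buckets() to scan for complete buckets. With shared parameters the same hook fires more than once per iteration, so those two steps can interleave between firings and a bucket can be flushed against a stale or already-reduced gradient. The patch makes comm_ready_buckets own the state transition (the parameter's index is passed in) and turns a repeated firing into a hard error unless the new shared_param flag is set. A minimal sketch of the per-parameter state machine the patch enforces; the names come from the diff below, the surrounding bookkeeping is an assumption:

# Sketch only, not apex's actual code.
# States: 0 = grad not ready, 1 = grad ready (hook fired), 2 = bucket allreduced.
NOT_READY, READY, REDUCED = 0, 1, 2

def on_grad_ready(param_state, param_ind):
    # A hook firing twice for the same index within one iteration means
    # the parameter is shared and bucketing order cannot be trusted.
    if param_state[param_ind] != NOT_READY:
        raise RuntimeError("shared parameter detected; construct DDP with shared_param=True")
    param_state[param_ind] = READY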
@@ -48,14 +48,15 @@ class DistributedDataParallel(Module):
     Args:
         module: Network definition to be run in multi-gpu/distributed mode.
         message_size (Default = 100e6): Minimum number of elements in a communication bucket.
+        shared_param (Default = False): If your model uses shared parameters, this must be True;
+            it disables bucketing of parameters, which is necessary to avoid race conditions.
     """
-    def __init__(self, module, message_size=100000000):
+    def __init__(self, module, message_size=100000000, shared_param=False):
         super(DistributedDataParallel, self).__init__()
         self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
+        self.shared_param = shared_param
         self.message_size = message_size
         #reference to last iterations parameters to see if anything has changed
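A hedged usage sketch of the new flag (TiedWeightModel is hypothetical; any module that reuses a Parameter in more than one place qualifies):

# Hypothetical module that reuses one weight matrix in two layers;
# shared_param=True disables bucketing so its hooks cannot race.
model = DistributedDataParallel(TiedWeightModel().cuda(), shared_param=True)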
@@ -119,8 +120,7 @@ class DistributedDataParallel(Module):
                     Variable._execution_engine.queue_callback(allreduce_params)
                 else:
                     Variable._execution_engine.queue_callback(flush_buckets)
-                self.param_state[self.record.index(param_i)] = 1
-                self.comm_ready_buckets()
+                self.comm_ready_buckets(self.record.index(param_i))
         if param.requires_grad:
@@ -128,9 +128,14 @@ class DistributedDataParallel(Module):
                 wrapper(param_i)

-    def comm_ready_buckets(self):
+    def comm_ready_buckets(self, param_ind):
+        if self.param_state[param_ind] != 0:
+            raise RuntimeError("Error: Your model uses shared parameters, DDP flag shared_param must be set to True in initialization.")
+        if self.param_state[self.ready_end] == 0:
+            self.param_state[param_ind] = 1
+            return
@@ -141,6 +146,7 @@ class DistributedDataParallel(Module):
         if self.ready_numel < self.message_size:
+            self.param_state[param_ind] = 1
             return
         grads = [param.grad.data for param in self.ready_params]
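For orientation, a generic sketch of what a bucket flush does with that grads list: flatten the ready gradients into one contiguous buffer, allreduce it, average by world size, and copy the results back in place. This is an assumption-level illustration, not apex's actual flush_buckets/allreduce_params implementation:

import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def flush_bucket(grads, world_size):
    coalesced = _flatten_dense_tensors(grads)   # one contiguous send buffer
    dist.all_reduce(coalesced)                  # sum across all ranks
    coalesced /= world_size                     # report the averaged gradient
    for grad, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        grad.copy_(synced)                      # scatter results back in place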
@@ -168,11 +174,18 @@ class DistributedDataParallel(Module):
             self.param_state[i] = 2
             self.ready_params.pop(0)
+        self.param_state[param_ind] = 1

     def forward(self, *inputs, **kwargs):
         param_list = [param for param in list(self.module.parameters()) if param.requires_grad]
+        # Force needs_refresh to True if there are shared params;
+        # this forces it to always and only call flush_buckets, which is safe
+        # for shared parameters in the model.
+        if self.shared_param:
+            self.param_refs = []
         self.needs_refresh = True if not self.param_refs else any(
             [param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)]
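The commit also adds a regression test, shown below with its launcher script. It constructs DDP with message_size=1 to force the smallest possible communication buckets, and each rank fills its input with a different value every iteration, so a dropped, duplicated, or mis-ordered allreduce shows up immediately in the printed gradient sums.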
import torch
import torch.distributed as dist
from torch.nn import Parameter
from torch.nn import Module
from apex.parallel import DistributedDataParallel as DDP
import argparse

parser = argparse.ArgumentParser(description='allreduce hook example')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--world-size', default=1, type=int,
                    help='Number of GPUs to use. Can either be manually set ' +
                         'or automatically set by using \'python -m multiproc\'.')
parser.add_argument('--rank', default=0, type=int,
                    help='Used for multi-process training. Can either be manually set ' +
                         'or automatically set by using \'python -m multiproc\'.')

args = parser.parse_args()
args.distributed = args.world_size > 1

if args.distributed:
    torch.cuda.set_device(args.rank % torch.cuda.device_count())
    dist.init_process_group(args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size)

rank = torch.distributed.get_rank()
torch.set_printoptions(precision=10)

class Model(Module):
    def __init__(self):
        super(Model, self).__init__()
        self.x = Parameter(torch.cuda.FloatTensor(1, 4096*4096).fill_(1.0))

    def forward(self, input):
        return self.x * input

model = DDP(Model(), message_size=1)

z = torch.cuda.FloatTensor(4096*4096)

for i in range(10):
    z.fill_(i + rank)  # fill z with new values every iteration for sanity
    model.zero_grad()
    out = model(z)
    loss = out.sum()
    torch.cuda.nvtx.range_push("backward")
    loss.backward()
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_push("synchronize() + sum")
    torch.cuda.synchronize()
    for param in model.parameters():
        print("i = {},\n"
              "param.grad.data_ptr() = {}\n"
              "expected {},\n"
              "     got {}\n"
              .format(i,
                      param.grad.data_ptr(),
                      4096*4096*(2.*i+1)/2.,
                      param.grad.data.sum().item()))
    torch.cuda.nvtx.range_pop()
#!/bin/bash
CUDA_VISIBLE_DEVICES=0,1 python -m apex.parallel.multiproc ddp_race_condition.py
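A quick sanity check on the expected value the test prints: loss = out.sum() with out = x * z, so every element of x receives gradient z. Under the two-process launch above, rank 0 fills z with i and rank 1 with i + 1; averaging the gradients across ranks gives (i + (i + 1)) / 2 = (2i + 1) / 2 per element, and summing over the 4096*4096 elements yields the expected 4096*4096*(2*i+1)/2.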