Unverified Commit fa183ee8 authored by mcarilli, committed by GitHub

Efficient bucketing (#49)

* beautiful

* IT'S WORKING

* Hopefully fix race condition for fallback hook

* Updating test

* shared_param -> delayed_allreduce

* Adding a safety check

* One more check

* syntax...
parent 53e1b61a
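
In short, this commit retires the old `shared_param` flag, adds `delay_allreduce`, and makes bucketed gradient allreduce that overlaps with backward computation the default. A minimal usage sketch of the resulting API (not part of the diff), assuming the script is launched with `torch.distributed.launch` as in the updated test below; the `Linear` module is just a stand-in for a real network:

```python
import argparse
import torch
import torch.distributed as dist
from apex.parallel import DistributedDataParallel as DDP

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=0, type=int)
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')

model = torch.nn.Linear(1024, 1024).cuda()

# Default: gradients are bucketed and allreduced while backward is still running.
model = DDP(model, message_size=10000000)

# Alternative: delay all communication to a single flat allreduce at the end of backward.
# model = DDP(model, delay_allreduce=True)

# The old flag is gone; passing shared_param now raises a ValueError.
```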
@@ -4,8 +4,26 @@ import torch.distributed as dist
 from torch.nn.modules import Module
 from torch.autograd import Variable
 from collections import OrderedDict
+from itertools import chain
 import copy
 
+# apply_flat_dist_call requires that tensors in 'bucket' are all the same type.
+def apply_flat_dist_call(bucket, call, extra_args=None):
+    coalesced = _flatten_dense_tensors(bucket)
+
+    if extra_args is not None:
+        call(coalesced, *extra_args)
+    else:
+        call(coalesced)
+
+    if call is dist.all_reduce:
+        coalesced /= dist.get_world_size()
+
+    for buf, synced in zip(bucket, _unflatten_dense_tensors(coalesced, bucket)):
+        buf.copy_(synced)
+
+# flat_dist_call organizes 'tensors' by type.
 def flat_dist_call(tensors, call, extra_args=None):
     flat_dist_call.warn_on_half = True
     buckets = OrderedDict()
@@ -15,27 +33,11 @@ def flat_dist_call(tensors, call, extra_args=None):
             buckets[tp] = []
         buckets[tp].append(tensor)
 
-    if flat_dist_call.warn_on_half:
-        if torch.cuda.HalfTensor in buckets:
-            print("WARNING: gloo dist backend for half parameters may be extremely slow." +
-                  " It is recommended to use the NCCL backend in this case.")
-            flat_dist_call.warn_on_half = False
-
     for tp in buckets:
         bucket = buckets[tp]
-        coalesced = _flatten_dense_tensors(bucket)
-        if extra_args is not None:
-            call(coalesced, *extra_args)
-        else:
-            call(coalesced)
-        if call is dist.all_reduce:
-            coalesced /= dist.get_world_size()
-        for buf, synced in zip(bucket, _unflatten_dense_tensors(coalesced, bucket)):
-            buf.copy_(synced)
+        apply_flat_dist_call(bucket, call, extra_args)
 
 def extract_tensors(maybe_tensor, tensor_list):
     if torch.is_tensor(maybe_tensor):
         tensor_list.append(maybe_tensor)
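
For reference, the flatten → single call → unflatten-and-copy-back pattern that `apply_flat_dist_call` implements can be exercised on one process. The sketch below is not part of the diff; `local_apply_flat_call` and the `mul_(10)` stand-in for `dist.all_reduce` are illustrative only:

```python
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def local_apply_flat_call(bucket, call):
    # Mirror of apply_flat_dist_call, but with an arbitrary in-place callable
    # instead of a torch.distributed collective, so it runs on a single process.
    coalesced = _flatten_dense_tensors(bucket)   # one contiguous buffer for the whole bucket
    call(coalesced)                              # one "communication" call instead of N
    for buf, synced in zip(bucket, _unflatten_dense_tensors(coalesced, bucket)):
        buf.copy_(synced)                        # scatter the results back into the originals

grads = [torch.ones(3), torch.full((2, 2), 2.0)]    # same dtype, as the comment above requires
local_apply_flat_call(grads, lambda t: t.mul_(10))  # pretend this is dist.all_reduce
print(grads[0], grads[1])  # tensor([10., 10., 10.]) and a 2x2 tensor of 20s
```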
@@ -123,11 +125,11 @@ class DistributedDataParallel(Module):
 
     Args:
         module: Network definition to be run in multi-gpu/distributed mode.
         message_size (Default = 1e7): Minimum number of elements in a communication bucket.
-        shared_param (Default = False): If your model uses shared parameters this must be True. It will disable bucketing of parameters to avoid race conditions.
+        delay_allreduce (Default = False): Delay all communication to the end of the backward pass. This disables overlapping communication with computation.
 
     """
 
-    def __init__(self, module, message_size=10000000, shared_param=False):
+    def __init__(self, module, message_size=10000000, delay_allreduce=False, shared_param=None):
         super(DistributedDataParallel, self).__init__()
 
         # Backward/forward compatibility around
@@ -140,14 +142,13 @@ class DistributedDataParallel(Module):
         self.backend_enum_holder = dist.dist_backend
 
         self.warn_on_half = True if self._backend == self.backend_enum_holder.GLOO else False
 
-        self.shared_param = shared_param
+        if shared_param is not None:
+            raise ValueError("shared_param is no longer supported as an option. It was misleadingly named from the start. It turns out overlapping communication with computation should work fine with shared parameters. If you still wish to delay communication to the end of the backward pass, use delay_allreduce=True|False instead.")
+
+        self.delay_allreduce = delay_allreduce
         self.message_size = message_size
 
-        # Will hold [param for param in self.module.parameters() if param.requires_grad]
-        # aka, the active paramters this iteration. The ordering of this list will be
-        # the same across all processes.
-        self.active_params = []
-
         self.reduction_stream = torch.cuda.Stream()
 
         self.module = module
@@ -155,8 +156,13 @@ class DistributedDataParallel(Module):
         if self._backend == self.backend_enum_holder.NCCL:
             for param in self.module.parameters():
                 assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU."
 
-        self.record = []
+        self.active_params = []
+
+        self.param_type_to_tmp_i = {"torch.cuda.HalfTensor" : 0,
+                                    "torch.cuda.FloatTensor" : 1,
+                                    "torch.cuda.DoubleTensor" : 2}
 
         self.create_hooks()
 
         flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) )
@@ -172,141 +178,180 @@ class DistributedDataParallel(Module):
         if self._backend != self.backend_enum_holder.NCCL:
             del attrs['self.reduction_stream']
         return attrs
 
+    # Broadcast rank 0's bucket structure across all processes, and have all processes
+    # regenerate their bucket structures to match.
+    def sync_bucket_structure(self):
+        # Append leftover buckets
+        for tmp_bucket in self.tmp_buckets:
+            if len(tmp_bucket) > 0:
+                self.buckets.append(tmp_bucket)
+
+        self.num_buckets = len(self.buckets)
+        self.bucket_sizes = [len(bucket) for bucket in self.buckets]
+
+        info_tensor = torch.cuda.IntTensor([self.num_buckets] +
+                                           self.bucket_sizes +
+                                           list(chain(*self.buckets)))
+
+        dist.broadcast(info_tensor, 0)
+
+        info = [int(entry) for entry in info_tensor]
+
+        self.num_buckets = info[0]
+        self.bucket_sizes = info[1:self.num_buckets + 1]
+        self.buckets = [[None for _ in range(self.bucket_sizes[i])] for i in range(self.num_buckets)]
+
+        flattened_buckets = info[self.num_buckets + 1:]
+        flat_i = 0
+        for bucket_idx in range(self.num_buckets):
+            for bucket_loc in range(self.bucket_sizes[bucket_idx]):
+                param_i = flattened_buckets[flat_i]
+                self.param_id_to_bucket[id(self.active_params[param_i])] = (bucket_idx, bucket_loc)
+                flat_i += 1
+
     def create_hooks(self):
-        # all reduce gradient hook
+        # Fallback hook that's only called at the end of backward.
+        # Used if you deliberately want to delay allreduces to the end, or to refresh the
+        # bucket structure that will be used to overlap communication with computation in later
+        # iterations.
         def allreduce_params():
-            if not self.needs_reduction:
-                return
-            self.needs_reduction = False
-
-            # parameter ordering refresh
-            if self.needs_refresh and not self.shared_param:
-                t_record = torch.cuda.IntTensor(self.record)
-                dist.broadcast(t_record, 0)
-                self.record = [int(entry) for entry in t_record]
-                # As before, self.record stores a list of indexes into self.active_params.
-                # param_id_to_record_i is a map from each active param's id to its slot in
-                # self.record.
-                self.param_id_to_record_i = {id(self.active_params[a]) : i
-                                             for i, a in enumerate(self.record)}
-                self.needs_refresh = False
+            # Bucket record refresh
+            if not self.delay_allreduce:
+                if self.needs_refresh:
+                    self.sync_bucket_structure()
+
+                    self.needs_refresh = False
 
             grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
 
             flat_dist_call(grads, dist.all_reduce)
 
-        def flush_buckets():
-            if not self.needs_reduction:
-                return
-            self.needs_reduction = False
-
-            grads = []
-            for i in range(self.ready_end, len(self.param_state)):
-                param = self.active_params[self.record[i]]
-                if param.grad is not None:
-                    grads.append(param.grad.data)
-            grads = [param.grad.data for param in self.ready_params] + grads
-
-            if(len(grads)>0):
-                orig_stream = torch.cuda.current_stream()
-                with torch.cuda.stream(self.reduction_stream):
-                    self.reduction_stream.wait_stream(orig_stream)
-                    flat_dist_call(grads, dist.all_reduce)
-
-                torch.cuda.current_stream().wait_stream(self.reduction_stream)
+        def overlapping_backward_epilogue():
+            torch.cuda.current_stream().wait_stream(self.reduction_stream)
+
+            # Sanity checks that all the buckets were kicked off
+            if self.next_bucket != self.num_buckets:
+                raise RuntimeError("In epilogue, next_bucket != num_buckets. "
+                                   "This probably indicates some buckets were not allreduced.")
+
+            for actual, expected in zip(self.buckets_ready_size, self.bucket_sizes):
+                if actual != expected:
+                    raise RuntimeError("Some param buckets were not allreduced.")
 
+        self.grad_accs = []
         for param in self.module.parameters():
             if param.requires_grad:
                 def wrapper(param):
+                    param_tmp = param.expand_as(param)
+                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
+
                     def allreduce_hook(*unused):
-                        if self.needs_refresh:
-                            self.record.append(self.param_id_to_active_i[id(param)])
-                            Variable._execution_engine.queue_callback(allreduce_params)
+                        if self.delay_allreduce or self.needs_refresh:
+                            # TODO: How do we want to handle multiple backward passes between
+                            # each forward, e.g., backward passes with retain_graph=True?
+                            # needs_refresh and callback_queued are both vulnerable states.
+                            if not self.delay_allreduce and self.needs_refresh:
+                                # Use the backward pass to build the bucket structure on the fly.
+                                active_i = self.param_id_to_active_i[id(param)]
+
+                                # Float, half, and double tensors are grouped into buckets separately.
+                                current_type = self.param_type_to_tmp_i[param.type()]
+
+                                self.tmp_buckets[current_type].append(active_i)
+                                self.tmp_numels[current_type] += param.numel()
+                                if self.tmp_numels[current_type] >= self.message_size:
+                                    self.buckets.append(self.tmp_buckets[current_type])
+                                    self.tmp_buckets[current_type] = []
+                                    self.tmp_numels[current_type] = 0
+
+                            if not self.callback_queued:
+                                Variable._execution_engine.queue_callback(allreduce_params)
+                                self.callback_queued = True
                         else:
-                            Variable._execution_engine.queue_callback(flush_buckets)
-                            # param_id_to_record_i handily enables us to replace the
-                            # O(N) self.record.index(param_i) call with an O(1) dict lookup.
-                            self.comm_ready_buckets(self.param_id_to_record_i[id(param)])
+                            if not self.callback_queued:
+                                Variable._execution_engine.queue_callback(overlapping_backward_epilogue)
+                                self.callback_queued = True
+
+                            self.comm_ready_buckets(param)
 
-                    param.register_hook(allreduce_hook)
+                    grad_acc.register_hook(allreduce_hook)
+                    self.grad_accs.append(grad_acc)
 
                 wrapper(param)
 
-    def comm_ready_buckets(self, record_i):
-        if self.param_state[record_i] != 0:
-            raise RuntimeError("Error: Your model uses shared parameters, DDP flag shared_params must be set to True in initialization.")
-
-        if self.param_state[self.ready_end] == 0:
-            self.param_state[record_i] = 1
-            return
-
-        while self.ready_end < len(self.param_state) and self.param_state[self.ready_end] == 1:
-            self.ready_params.append(self.active_params[self.record[self.ready_end]])
-            self.ready_numel += self.ready_params[-1].numel()
-            self.ready_end += 1
-
-        if self.ready_numel < self.message_size:
-            self.param_state[record_i] = 1
-            return
-
-        grads = [param.grad.data for param in self.ready_params]
-
-        bucket = []
-        bucket_inds = []
-        while grads:
-            bucket.append(grads.pop(0))
-
-            cumm_size = 0
-            for ten in bucket:
-                cumm_size += ten.numel()
-
-            if cumm_size < self.message_size:
-                continue
-
-            evt = torch.cuda.Event()
-            evt.record(torch.cuda.current_stream())
-            evt.wait(stream=self.reduction_stream)
-
-            with torch.cuda.stream(self.reduction_stream):
-                flat_dist_call(bucket, dist.all_reduce)
-
-            for i in range(self.ready_start, self.ready_start+len(bucket)):
-                self.param_state[i] = 2
-                self.ready_params.pop(0)
-
-        self.param_state[record_i] = 1
+    def comm_ready_buckets(self, param):
+        # Need to do this in every hook for compatibility with Ruberry's streaming backward PR.
+        # self.reduction_stream.wait_stream(torch.cuda.current_stream())
+
+        bucket_idx, bucket_loc = self.param_id_to_bucket[id(param)]
+
+        if self.buckets[bucket_idx][bucket_loc] is not None:
+            raise RuntimeError("The backward pass is attempting to replace an already-filled "
+                               "bucket slot. This is almost certainly an error.")
+
+        self.buckets[bucket_idx][bucket_loc] = param.grad.data
+        self.buckets_ready_size[bucket_idx] += 1
+
+        if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]:
+            if bucket_idx == self.next_bucket:
+                self.reduction_stream.wait_stream(torch.cuda.current_stream())
+                with torch.cuda.stream(self.reduction_stream):
+                    apply_flat_dist_call(self.buckets[bucket_idx], dist.all_reduce)
+
+                self.next_bucket += 1
+
+                # Reversing upstream's logic here, because we constructed our buckets based on
+                # the order things were received during backward.
+                if len(self.ready_buckets_not_reduced) > 0:
+                    sorted_todo = sorted(self.ready_buckets_not_reduced)
+                    for i in sorted_todo:
+                        # Nothing can be reduced now
+                        if i > self.next_bucket:
+                            break
+                        elif i == self.next_bucket:
+                            apply_flat_dist_call(self.buckets[i], dist.all_reduce)
+                            self.ready_buckets_not_reduced.remove(i)
+                            self.next_bucket += 1
+                        else:
+                            raise ValueError("i should always be <= next_bucket")
+            else:
+                self.ready_buckets_not_reduced.add(bucket_idx)
 
     def forward(self, *inputs, **kwargs):
-        param_list = [param for param in self.module.parameters() if param.requires_grad]
-
-        # Conditions under which to refresh self.record
-        # Forward has the authority to set needs_refresh to True, but only allreduce_params
-        # in backward has the authority to set needs_refresh to False.
-        # Parentheses are not necessary for correct order of operations, but make the intent clearer.
-        if ( (not self.active_params) or
-             self.shared_param or
-             (len(param_list) != len(self.active_params)) or
-             any([param1 is not param2 for param1, param2 in zip(param_list, self.active_params)]) ):
-            self.needs_refresh = True
-
-        if self.needs_refresh:
-            self.record = []
-            # Map from each param's id to its index in the list of active parameters.
-            self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}
-
-        self.param_state = [0 for i in range(len(param_list))]
-        self.active_params = param_list
-
-        self.needs_reduction = True
-
-        self.ready_start = 0
-        self.ready_end = 0
-        self.ready_params = []
-        self.ready_numel = 0
+        if not self.delay_allreduce:
+            param_list = [param for param in self.module.parameters() if param.requires_grad]
+
+            # Conditions under which to refresh the bucket structure.
+            # Forward has the authority to set needs_refresh to True, but only allreduce_params
+            # in backward has the authority to set needs_refresh to False.
+            # Parentheses are not necessary for correct order of operations, but make the intent clearer.
+            if ((not self.active_params) or
+                (len(param_list) != len(self.active_params)) or
+                any([param1 is not param2 for param1, param2 in zip(param_list, self.active_params)])):
+                self.needs_refresh = True
+
+            if self.needs_refresh:
+                self.buckets = []
+                self.tmp_buckets = [[], [], []]  # [running half, float, double buckets]
+                self.tmp_numels = [0, 0, 0]
+                self.bucket_sizes = []
+                self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}
+                self.param_id_to_bucket = {}
+            else:
+                self.buckets = [[None for _ in range(self.bucket_sizes[i])]
+                                for i in range(self.num_buckets)]
+                self.buckets_ready_size = [0 for i in range(self.num_buckets)]
+                self.next_bucket = 0
+                self.ready_buckets_not_reduced = set()
+
+            self.active_params = param_list
+
+        self.callback_queued = False
 
         return self.module(*inputs, **kwargs)
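
Reading `comm_ready_buckets` and `overlapping_backward_epilogue` together: allreduces are issued strictly in bucket-index order, so a bucket whose gradients finish early is parked in `ready_buckets_not_reduced` until `next_bucket` catches up to it. Below is a standalone sketch of just that scheduling logic (not part of the diff): `reduce()` stands in for the stream-side `apply_flat_dist_call(..., dist.all_reduce)`, and `completion_order` plays the role of the order in which buckets fill during backward.

```python
def schedule(completion_order, num_buckets):
    next_bucket = 0
    ready_not_reduced = set()
    launched = []

    def reduce(idx):
        launched.append(idx)   # pretend this kicks off an allreduce for bucket idx

    for bucket_idx in completion_order:
        if bucket_idx == next_bucket:
            reduce(bucket_idx)
            next_bucket += 1
            # Drain any earlier-completed buckets whose turn has now arrived.
            for i in sorted(ready_not_reduced):
                if i > next_bucket:
                    break
                elif i == next_bucket:
                    reduce(i)
                    ready_not_reduced.remove(i)
                    next_bucket += 1
        else:
            ready_not_reduced.add(bucket_idx)

    assert next_bucket == num_buckets   # mirrors the epilogue sanity check
    return launched

print(schedule([1, 0, 3, 2], 4))  # [0, 1, 2, 3]: launches stay in index order
```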
@@ -89,6 +89,7 @@ def fast_collate(batch):
 
 best_prec1 = 0
 args = parser.parse_args()
 
 def main():
     global best_prec1, args
@@ -121,8 +122,11 @@ def main():
     if args.fp16:
         model = network_to_half(model)
     if args.distributed:
-        # shared param turns off bucketing in DDP, for lower latency runs this can improve perf
-        model = DDP(model, shared_param=True)
+        # By default, apex.parallel.DistributedDataParallel overlaps communication with
+        # computation in the backward pass.
+        # model = DDP(model)
+        # delay_allreduce delays all communication to the end of the backward pass.
+        model = DDP(model, delay_allreduce=True)
 
     global model_params, master_params
     if args.fp16:
......
@@ -4,32 +4,26 @@ from torch.nn import Parameter
 from torch.nn import Module
 from apex.parallel import DistributedDataParallel as DDP
 import argparse
+import os
 
 parser = argparse.ArgumentParser(description='allreduce hook example')
-parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
-                    help='url used to set up distributed training')
-parser.add_argument('--dist-backend', default='nccl', type=str,
-                    help='distributed backend')
-parser.add_argument('--world-size', default=1, type=int,
-                    help='Number of GPUs to use. Can either be manually set ' +
-                         'or automatically set by using \'python -m multiproc\'.')
-parser.add_argument('--rank', default=0, type=int,
-                    help='Used for multi-process training. Can either be manually set ' +
-                         'or automatically set by using \'python -m multiproc\'.')
+parser.add_argument("--local_rank", default=0, type=int)
 args = parser.parse_args()
 
-args.distributed = args.world_size > 1
+args.distributed = False
+if 'WORLD_SIZE' in os.environ:
+    args.distributed = int(os.environ['WORLD_SIZE']) > 1
 
 if args.distributed:
-    torch.cuda.set_device(args.rank % torch.cuda.device_count())
-    dist.init_process_group(args.dist_backend,
-                            init_method=args.dist_url,
-                            world_size=args.world_size,
-                            rank=args.rank)
+    args.gpu = args.local_rank % torch.cuda.device_count()
+    torch.cuda.set_device(args.gpu)
+    torch.distributed.init_process_group(backend='nccl',
+                                         init_method='env://')
+    args.world_size = torch.distributed.get_world_size()
 
 torch.set_printoptions(precision=10)
+torch.manual_seed(args.local_rank)
 
 class Model(Module):
     def __init__(self):
@@ -44,7 +38,7 @@ model = DDP(Model(), message_size=1)
 
 x = torch.cuda.FloatTensor(4096*4096)
 
 for i in range(10):
-    x.fill_(i + args.rank)        # fill x with new values every iteration for sanity
+    x.fill_(i + args.local_rank)  # fill x with new values every iteration for sanity
     model.zero_grad()
     out = model(x)
     loss = out.sum()
@@ -53,7 +47,7 @@ for i in range(10):
     torch.cuda.nvtx.range_pop()
 
     torch.cuda.nvtx.range_push("synchronize() + info")
-    torch.cuda.synchronize()
+    # torch.cuda.synchronize()
     print("i = {}".format(i))
 
     def info(name, param, val):
         print(name+": grad.data_ptr() = {}, expected sum {}, got {}".format(
......
 #!/bin/bash
-CUDA_VISIBLE_DEVICES=0,1 python -m apex.parallel.multiproc ddp_race_condition.py
+CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 ddp_race_condition_test.py
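
Both the old `flush_buckets` and the new `overlapping_backward_epilogue` are queued from inside a gradient hook via `Variable._execution_engine.queue_callback`, which runs the callback once the backward pass finishes; that ordering is what this race-condition test exercises. A single-process sketch of the mechanism (not part of the diff):

```python
import torch
from torch.autograd import Variable

events = []

x = torch.ones(3, requires_grad=True)
y = (x * 2).sum()

def epilogue():
    events.append("epilogue")   # runs after backward has finished

def hook(grad):
    events.append("hook")       # runs while backward is still executing
    Variable._execution_engine.queue_callback(epilogue)
    return grad

x.register_hook(hook)
y.backward()
print(events)  # ['hook', 'epilogue']
```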