Commit af431d28 authored by Michael Carilli

Merge changes from Christian's fork

parents d6db91a4 414dc119
@@ -4,6 +4,7 @@ import torch.distributed as dist
from torch.nn.modules import Module
from torch.autograd import Variable
def flat_dist_call(tensors, call, extra_args=None):
    flat_dist_call.warn_on_half = True
    buckets = {}
@@ -28,19 +29,17 @@ def flat_dist_call(tensors, call, extra_args=None):
            call(coalesced)
        if call is dist.all_reduce:
            coalesced /= dist.get_world_size()
        for buf, synced in zip(bucket, _unflatten_dense_tensors(coalesced, bucket)):
            buf.copy_(synced)
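For reference, a minimal sketch of how this helper is invoked elsewhere in this file (illustrative only; it assumes torch.distributed is already initialized and `model` is a module living on this process's GPU):

    # Average gradients across ranks: one coalesced all_reduce per tensor dtype.
    grads = [p.grad.data for p in model.parameters() if p.grad is not None]
    flat_dist_call(grads, dist.all_reduce)

    # Broadcast parameters from rank 0; the root rank is passed via extra_args.
    flat_dist_call([p.data for p in model.parameters()], dist.broadcast, (0,))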
class DistributedDataParallel(Module):
    """
    :class:`DistributedDataParallel` is a simpler version of upstream :class:`
-   DistributedDataParallel`. Its usage is designed to be used in conjunction with
-   apex.parallel.multiproc.py. It assumes that your run is using multiprocess with
-   1 GPU/process, that the model is on the correct device, and that
-   torch.set_device has been used to set the device. Parameters are broadcasted
+   DistributedDataParallel` that is optimized for use with NCCL. It is designed
+   to be used in conjunction with apex.parallel.multiproc.py. It assumes that your run
+   is using multiprocess with 1 GPU/process, that the model is on the correct device,
+   and that torch.set_device has been used to set the device. Parameters are broadcast
+   to the other processes on initialization of DistributedDataParallel, and will be
+   allreduced in buckets during the backward pass.
@@ -48,59 +47,161 @@ class DistributedDataParallel(Module):
    Args:
        module: Network definition to be run in multi-gpu/distributed mode.
-       message_size (Default = 10000000): Minimum number of elements in a communication bucket.
+       message_size (Default = 10e6): Minimum number of elements in a communication bucket.
+       shared_param (Default = False): If your model uses shared parameters, this must be True;
+           it disables bucketing of parameters, which is necessary to avoid race conditions.
    """
-   def __init__(self, module):
+   def __init__(self, module, message_size=10000000, shared_param=False):
        super(DistributedDataParallel, self).__init__()
        self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
+       self.shared_param = shared_param
+       self.message_size = message_size
+       # Reference to the previous iteration's parameters, to detect whether anything changed.
+       self.param_refs = []
+       self.reduction_stream = torch.cuda.Stream()
        self.module = module
-       param_list = [param for param in self.module.state_dict().values() if torch.is_tensor(param)]
+       self.param_list = list(self.module.parameters())
        if dist._backend == dist.dist_backend.NCCL:
-           for param in param_list:
+           for param in self.param_list:
                assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU."
        # broadcast parameters
-       flat_dist_call(param_list, dist.broadcast, (0,) )
+       self.record = []
+       self.create_hooks()
+       flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) )
    def create_hooks(self):
        # All-reduce gradient hook
        def allreduce_params():
            if self.needs_reduction:
                self.needs_reduction = False
                self.needs_refresh = False
            else:
                return
            grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
            flat_dist_call(grads, dist.all_reduce)
            # Share the bucketing order recorded on rank 0 with all other ranks.
            t_record = torch.cuda.IntTensor(self.record)
            dist.broadcast(t_record, 0)
            self.record = [int(entry) for entry in t_record]
        def flush_buckets():
            if not self.needs_reduction:
                return
            self.needs_reduction = False
            # All-reduce any gradients whose buckets never filled during backward.
            grads = []
            for i in range(self.ready_end, len(self.param_state)):
                param = self.param_refs[self.record[i]]
                if param.grad is not None:
                    grads.append(param.grad.data)
            grads = [param.grad.data for param in self.ready_params] + grads

            if len(grads) > 0:
                orig_stream = torch.cuda.current_stream()
                with torch.cuda.stream(self.reduction_stream):
                    self.reduction_stream.wait_stream(orig_stream)
                    flat_dist_call(grads, dist.all_reduce)
            torch.cuda.current_stream().wait_stream(self.reduction_stream)
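The stream handshake in flush_buckets is the standard two-way wait_stream pattern: communication kernels must not start before the compute that produced the gradients, and later compute must not start before communication finishes. A generic sketch (stream names are illustrative):

    main_stream = torch.cuda.current_stream()
    side_stream = torch.cuda.Stream()
    with torch.cuda.stream(side_stream):
        side_stream.wait_stream(main_stream)   # comms wait for pending compute
        # ... enqueue communication kernels here ...
    main_stream.wait_stream(side_stream)       # later compute waits for comms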
-       for param in list(self.module.parameters()):
        for param_i, param in enumerate(list(self.module.parameters())):
            def wrapper(param_i):
                def allreduce_hook(*unused):
                    if self.needs_refresh:
                        # First pass: record the order in which gradients become ready.
                        self.record.append(param_i)
                        Variable._execution_engine.queue_callback(allreduce_params)
                    else:
                        Variable._execution_engine.queue_callback(flush_buckets)
                        self.comm_ready_buckets(self.record.index(param_i))
                if param.requires_grad:
                    param.register_hook(allreduce_hook)
            wrapper(param_i)
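Two PyTorch facilities drive the hooks above: Tensor.register_hook fires per parameter as soon as that parameter's gradient is produced, and Variable._execution_engine.queue_callback defers a function until the entire backward pass has finished. A stripped-down sketch (`model` and the print body are placeholders; the callback is queued once per hook call, which is why the real code guards with self.needs_reduction):

    from torch.autograd import Variable

    def on_backward_finished():
        print("backward done; flush any unsent buckets")

    for p in model.parameters():
        if p.requires_grad:
            p.register_hook(
                lambda *unused: Variable._execution_engine.queue_callback(on_backward_finished))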
    def comm_ready_buckets(self, param_ind):
        if self.param_state[param_ind] != 0:
            raise RuntimeError("Error: your model uses shared parameters; the DDP flag shared_param must be set to True at initialization.")

        if self.param_state[self.ready_end] == 0:
            self.param_state[param_ind] = 1
            return

        while self.ready_end < len(self.param_state) and self.param_state[self.ready_end] == 1:
            self.ready_params.append(self.param_refs[self.record[self.ready_end]])
            self.ready_numel += self.ready_params[-1].numel()
            self.ready_end += 1

        if self.ready_numel < self.message_size:
            self.param_state[param_ind] = 1
            return
        grads = [param.grad.data for param in self.ready_params]

        bucket = []
        while grads:
            bucket.append(grads.pop(0))

            cumm_size = 0
            for ten in bucket:
                cumm_size += ten.numel()

            if cumm_size < self.message_size:
                continue

            # Bucket is full: let the reduction stream reduce it while compute continues.
            evt = torch.cuda.Event()
            evt.record(torch.cuda.current_stream())
            evt.wait(stream=self.reduction_stream)

            with torch.cuda.stream(self.reduction_stream):
                flat_dist_call(bucket, dist.all_reduce)

            for i in range(self.ready_start, self.ready_start + len(bucket)):
                self.param_state[i] = 2
                self.ready_params.pop(0)
            self.ready_start += len(bucket)  # advance past the reduced bucket
            bucket = []

        self.param_state[param_ind] = 1
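So gradients become ready in recorded hook order and are greedily packed into buckets of at least message_size elements, with each full bucket reduced immediately while backward continues. The grouping rule in isolation (a self-contained sketch; the sizes are made up):

    import torch

    def greedy_buckets(tensors, message_size):
        # Pack in arrival order; close a bucket once it reaches message_size elements.
        buckets, current, numel = [], [], 0
        for t in tensors:
            current.append(t)
            numel += t.numel()
            if numel >= message_size:
                buckets.append(current)
                current, numel = [], 0
        if current:
            buckets.append(current)  # leftovers; the real code flushes these after backward
        return buckets

    sizes = [4_000_000, 3_000_000, 5_000_000, 2_000_000]
    print([len(b) for b in greedy_buckets([torch.empty(n) for n in sizes], 10_000_000)])
    # -> [3, 1]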
    def forward(self, *inputs, **kwargs):
        param_list = [param for param in list(self.module.parameters()) if param.requires_grad]

        # Force needs_refresh to True if there are shared params: the hooks then
        # only ever call flush_buckets, which is safe for shared parameters.
        if self.shared_param:
            self.param_refs = []

        self.needs_refresh = True if not self.param_refs else any(
            [param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)]
        )

        if self.needs_refresh:
            self.record = []

        self.param_state = [0 for i in range(len(param_list))]
        self.param_refs = param_list
        self.needs_reduction = True
-       return self.module(*inputs, **kwargs)
        '''
        def _sync_buffers(self):
            buffers = list(self.module._all_buffers())
            if len(buffers) > 0:
                # cross-node buffer sync
                flat_buffers = _flatten_dense_tensors(buffers)
                dist.broadcast(flat_buffers, 0)
                for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)):
                    buf.copy_(synced)

        def train(self, mode=True):
            # Clear the NCCL communicator and CUDA event cache of the default group ID;
            # these caches will be recreated by a later call. This is currently a
            # workaround for a potential NCCL deadlock.
            if dist._backend == dist.dist_backend.NCCL:
                dist._clear_group_cache()
            super(DistributedDataParallel, self).train(mode)
            self.module.train(mode)
        '''
        self.ready_start = 0
        self.ready_end = 0
        self.ready_params = []
        self.ready_numel = 0

+       return self.module(*inputs, **kwargs)
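Putting the class to use looks roughly like this (a sketch, not part of the commit; MyModel and args are placeholders, and a launcher such as apex.parallel.multiproc is assumed to start one process per GPU):

    import torch
    import torch.distributed as dist
    from apex.parallel import DistributedDataParallel as DDP

    torch.cuda.set_device(args.gpu)    # one GPU per process
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)

    model = MyModel().cuda()           # must be on the right device before wrapping
    model = DDP(model)                 # parameters are broadcast from rank 0 here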
@@ -16,14 +16,14 @@ import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import numpy as np
try:
    from apex.parallel import DistributedDataParallel as DDP
    from apex.fp16_utils import *
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
import numpy as np
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))
@@ -61,8 +61,8 @@ parser.add_argument('--pretrained', dest='pretrained', action='store_true',
parser.add_argument('--fp16', action='store_true',
                    help='Run model fp16 mode.')
-parser.add_argument('--static-loss-scale', type=float, default=1,
-                    help='Static loss scale, positive power of 2 values can improve fp16 convergence.')
+parser.add_argument('--loss-scale', type=float, default=1,
+                    help='Loss scaling; positive power-of-2 values can improve fp16 convergence.')
parser.add_argument('--prof', dest='prof', action='store_true',
                    help='Only run 10 iterations for profiling.')
@@ -80,6 +80,26 @@ parser.add_argument('--rank', default=0, type=int,
cudnn.benchmark = True
import numpy as np
def fast_collate(batch):
    imgs = [img[0] for img in batch]
    targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
    w = imgs[0].size[0]
    h = imgs[0].size[1]
    tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8)
    for i, img in enumerate(imgs):
        nump_array = np.asarray(img, dtype=np.uint8)
        if nump_array.ndim < 3:
            nump_array = np.expand_dims(nump_array, axis=-1)  # grayscale -> HWC with C=1
        nump_array = np.rollaxis(nump_array, 2)               # HWC -> CHW
        tensor[i] += torch.from_numpy(nump_array)
    return tensor, targets
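fast_collate deliberately skips ToTensor and normalization: batches stay as uint8 CHW tensors on the CPU, and the float conversion plus mean/std normalization happen on the GPU inside data_prefetcher further down. Wiring it in is a single collate_fn argument, matching the loader change later in this diff:

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, collate_fn=fast_collate)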
best_prec1 = 0
args = parser.parse_args()
def main():
@@ -93,18 +113,12 @@ def main():
    if args.distributed:
        torch.cuda.set_device(args.gpu)
-       dist.init_process_group(backend=args.dist_backend,
-                               init_method=args.dist_url,
-                               world_size=args.world_size,
-                               rank=args.rank)
+       dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                               world_size=args.world_size, rank=args.rank)

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if args.loss_scale != 1.0:
        if not args.fp16:
            print("Warning: if --fp16 is not used, loss_scale will be ignored.")
    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
@@ -149,12 +163,10 @@ def main():
    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
-   normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                    std=[0.229, 0.224, 0.225])

    if args.arch == "inception_v3":
        crop_size = 299
-       val_size = 320 # Arbitrarily chosen, adjustable.
+       val_size = 320 # Chosen arbitrarily; adjust as needed.
    else:
        crop_size = 224
        val_size = 256
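The validation transforms that consume these sizes fall outside this hunk; conventionally the short edge is resized to val_size before center-cropping to crop_size, roughly as below (a sketch, not the elided code; older torchvision spells Resize as Scale):

    val_transforms = transforms.Compose([
        transforms.Resize(val_size),       # 320 for inception_v3, otherwise 256
        transforms.CenterCrop(crop_size),  # 299 for inception_v3, otherwise 224
    ])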
@@ -164,8 +176,8 @@ def main():
        transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip(),
-           transforms.ToTensor(),
-           normalize,
+           # transforms.ToTensor() and normalize are omitted here: they are too slow
+           # on the CPU; fast_collate and data_prefetcher handle both on the GPU.
        ]))
    if args.distributed:
@@ -175,7 +187,11 @@ def main():
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
-       num_workers=args.workers, pin_memory=True, sampler=train_sampler)
+       num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)

+   normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                    std=[0.229, 0.224, 0.225])

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
@@ -215,10 +231,22 @@ def main():
            'optimizer' : optimizer.state_dict(),
        }, is_best)
# item() is a recent addition, so this helps with backward compatibility.
def to_python_float(t):
    if hasattr(t, 'item'):
        return t.item()
    else:
        return t[0]
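A quick illustration: both the new 0-dim-tensor API and the old 1-element-tensor API come out as a plain Python number:

    prec1 = torch.tensor([57.3])
    top1_val = to_python_float(prec1)  # 57.3, via .item() when available, else t[0]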
class data_prefetcher():
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1,3,1,1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1,3,1,1)
        if args.fp16:
            self.mean = self.mean.half()
            self.std = self.std.half()
        self.preload()

    def preload(self):
@@ -231,6 +259,11 @@ class data_prefetcher():
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(async=True)
            self.next_target = self.next_target.cuda(async=True)
+           if args.fp16:
+               self.next_input = self.next_input.half()
+           else:
+               self.next_input = self.next_input.float()
+           self.next_input = self.next_input.sub_(self.mean).div_(self.std)
    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
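Consuming the prefetcher replaces the usual `for input, target in loader` loop: each next() hands back a batch that is already on the GPU, already float or half, and already normalized, while the following batch's host-to-device copy runs on the side stream. A sketch of the intended loop shape (illustrative; it assumes the elided part of preload() sets next_input to None once the loader is exhausted):

    prefetcher = data_prefetcher(train_loader)
    input, target = prefetcher.next()
    while input is not None:
        output = model(input)
        loss = criterion(output, target)
        # ... backward pass and optimizer step ...
        input, target = prefetcher.next()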
@@ -284,15 +317,15 @@ def train(train_loader, model, criterion, optimizer, epoch):
        top1.update(to_python_float(prec1), input.size(0))
        top5.update(to_python_float(prec5), input.size(0))

+       loss = loss*args.loss_scale
        # compute gradient and do SGD step
        if args.fp16:
-           loss = loss*args.static_loss_scale
            model.zero_grad()
            loss.backward()
            model_grads_to_master_grads(model_params, master_params)
-           if args.static_loss_scale != 1:
+           if args.loss_scale != 1:
                for param in master_params:
-                   param.grad.data = param.grad.data/args.static_loss_scale
+                   param.grad.data = param.grad.data/args.loss_scale
            optimizer.step()
            master_params_to_model_params(model_params, master_params)
        else:
......
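Taken together, the fp16 path above is static loss scaling: amplify the loss before backward so small gradients survive fp16's narrow range, accumulate them into fp32 master weights, unscale, and step. Condensed, using this file's own helpers:

    loss = loss * args.loss_scale                  # amplify to protect small grads
    model.zero_grad()
    loss.backward()                                # produces scaled fp16 grads
    model_grads_to_master_grads(model_params, master_params)
    if args.loss_scale != 1:
        for param in master_params:
            param.grad.data = param.grad.data / args.loss_scale   # unscale in fp32
    optimizer.step()                               # update fp32 master weights
    master_params_to_model_params(model_params, master_params)    # copy back to fp16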