Unverified commit 378ce1e1, authored by mcarilli, committed by GitHub

Merge pull request #9 from NVIDIA/imagenet_fix

DDP fix, imagenet speedup
parents fb075b86 06ee98c2
@@ -79,17 +79,19 @@ class DistributedDataParallel(Module):
     def create_hooks(self):
         #all reduce gradient hook
         def allreduce_params():
-            if(self.needs_reduction):
-                self.needs_reduction = False
-                self.needs_refresh = False
-            else:
+            if not self.needs_reduction:
                 return
-            grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
-            flat_dist_call(grads, dist.all_reduce)
-            t_record = torch.cuda.IntTensor(self.record)
-            dist.broadcast(t_record, 0)
-            self.record = [int(entry) for entry in t_record]
+            self.needs_reduction = False
+
+            #parameter ordering refresh
+            if self.needs_refresh and not self.shared_param:
+                t_record = torch.cuda.IntTensor(self.record)
+                dist.broadcast(t_record, 0)
+                self.record = [int(entry) for entry in t_record]
+                self.needs_refresh = False
+
+            grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
+            flat_dist_call(grads, dist.all_reduce)
 
         def flush_buckets():
             if not self.needs_reduction:
@@ -184,10 +186,10 @@ class DistributedDataParallel(Module):
         #Force needs_refresh to True if there are shared params
         #this will force it to always, only call flush_buckets which is safe
         #for shared parameters in the model.
-        if self.shared_param:
-            self.param_refs = []
-        self.needs_refresh = True if not self.param_refs else any(
-            [param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)]
-        )
+        if not self.param_refs or self.shared_param:
+            self.needs_refresh = True
+        else:
+            self.needs_refresh = any(
+                [param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)]
+            )
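A note for context on the allreduce_params hook above: flat_dist_call is defined elsewhere in apex and is not part of this diff, but the pattern it implements is to flatten every gradient into one contiguous buffer and issue a single all_reduce, instead of one collective per parameter. A rough sketch of that pattern using torch.distributed directly, as an illustration only (the function name and the averaging by world size are assumptions, not something this commit shows):

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def allreduce_gradients_sketch(module):
    # Gather the gradients that exist after backward().
    grads = [p.grad.data for p in module.parameters() if p.grad is not None]
    if not grads:
        return
    # One flat buffer -> one all_reduce across all ranks.
    flat = _flatten_dense_tensors(grads)
    dist.all_reduce(flat)
    flat /= dist.get_world_size()   # average; assumption for illustration
    # Copy the reduced values back into the per-parameter gradient tensors.
    for grad, reduced in zip(grads, _unflatten_dense_tensors(flat, grads)):
        grad.copy_(reduced)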
@@ -131,7 +131,8 @@ def main():
     if args.fp16:
         model = network_to_half(model)
     if args.distributed:
-        model = DDP(model)
+        #shared param turns off bucketing in DDP, for lower latency runs this can improve perf
+        model = DDP(model, shared_param=True)
 
     global model_params, master_params
     if args.fp16:
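The shared_param=True argument added here turns off gradient bucketing inside apex's DDP, so gradients are reduced in one flat call at the end of backward rather than bucket by bucket during it; as the new comment says, this can improve performance for latency-sensitive runs. A minimal usage sketch of the wrapping order implied by the diff (the resnet50 model, the init_process_group arguments, and the apex.fp16_utils import path are illustrative assumptions):

import torch
import torch.distributed as dist
import torchvision.models as models
from apex.parallel import DistributedDataParallel as DDP
from apex.fp16_utils import network_to_half   # import path assumed; the diff only uses the name

# Assumes RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT are set in the environment.
dist.init_process_group(backend='nccl', init_method='env://')
torch.cuda.set_device(0)   # normally args.local_rank, one GPU per process

model = models.resnet50().cuda()
model = network_to_half(model)            # fp16 conversion, as in the diff
model = DDP(model, shared_param=True)     # shared_param turns off bucketing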
@@ -189,19 +190,14 @@ def main():
         train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
         num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)
 
-    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                     std=[0.229, 0.224, 0.225])
-
     val_loader = torch.utils.data.DataLoader(
         datasets.ImageFolder(valdir, transforms.Compose([
             transforms.Resize(val_size),
             transforms.CenterCrop(crop_size),
-            transforms.ToTensor(),
-            normalize,
         ])),
         batch_size=args.batch_size, shuffle=False,
-        num_workers=args.workers, pin_memory=True)
+        num_workers=args.workers, pin_memory=True,
+        collate_fn=fast_collate)
 
     if args.evaluate:
         validate(val_loader, model, criterion)
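With fast_collate as the validation loader's collate_fn, batches come back as raw uint8 image tensors, which is why transforms.ToTensor and the CPU-side Normalize disappear from the transform list; the float conversion and mean/std normalization are presumably done on the GPU, as for the training loader above, which already uses fast_collate. A hypothetical sketch of what a fast_collate-style function looks like, since the real fast_collate is defined elsewhere in main.py and does not appear in this diff:

import numpy as np
import torch

def fast_collate_sketch(batch):
    # batch is a list of (PIL image, target) pairs of identical size,
    # e.g. after Resize + CenterCrop.
    imgs = [sample[0] for sample in batch]
    targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
    w, h = imgs[0].size
    tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8)
    for i, img in enumerate(imgs):
        arr = np.asarray(img, dtype=np.uint8)
        if arr.ndim < 3:
            arr = np.expand_dims(arr, axis=-1)   # grayscale -> HxWx1
        arr = np.rollaxis(arr, 2)                # HWC -> CHW
        tensor[i] += torch.from_numpy(arr)
    return tensor, targets

Keeping the batch in uint8 and normalizing on the device avoids a per-sample float conversion in the dataloader workers, which is presumably part of the "imagenet speedup" in the title.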