Commit 2361a646 authored by Michael Carilli

Updating imagenet FP16_Optimizer example for new syntax

parent e4c97f32
...
@@ -128,8 +128,11 @@ def main():
     if args.fp16:
         model = network_to_half(model)
     if args.distributed:
-        # shared param turns off bucketing in DDP, for lower latency runs this can improve perf
-        model = DDP(model, shared_param=True)
+        # By default, apex.parallel.DistributedDataParallel overlaps communication with
+        # computation in the backward pass.
+        # model = DDP(model)
+        # delay_allreduce delays all communication to the end of the backward pass.
+        model = DDP(model, delay_allreduce=True)
     # define loss function (criterion) and optimizer
     criterion = nn.CrossEntropyLoss().cuda()
...
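For readers updating their own scripts to the new syntax, here is a minimal sketch of how the revised wrapping order might look in a distributed fp16 run. The resnet18 model, NCCL backend, learning rate, and static loss scale are illustrative assumptions; only DDP(model, delay_allreduce=True) comes from this commit.

    import torch
    import torchvision.models as models
    from apex.parallel import DistributedDataParallel as DDP
    from apex.fp16_utils import network_to_half, FP16_Optimizer

    # Assumes launch via torch.distributed.launch with one process per GPU,
    # so MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE are set in the environment.
    torch.cuda.set_device(0)  # use the process's local_rank in a real launch
    torch.distributed.init_process_group(backend='nccl', init_method='env://')

    # Convert the model to half precision before wrapping it for data parallelism.
    model = network_to_half(models.resnet18().cuda())

    # delay_allreduce=True defers every gradient allreduce to the end of the
    # backward pass; the default overlaps allreduces with backward computation.
    model = DDP(model, delay_allreduce=True)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # FP16_Optimizer maintains fp32 master weights and applies loss scaling;
    # the training loop then calls optimizer.backward(loss) instead of loss.backward().
    optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)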