Commit 37a4b221 authored by Michael Carilli

Cleaning up git weirdness + updating docs for Reducer

parent f09fb4f4
...@@ -49,20 +49,33 @@ def extract_tensors(maybe_tensor, tensor_list):
 class Reducer(object):
     """
-    :class:`apex.parallel.Reducer` is a simple class that helps reduce a module parameters.
-    This class will not automatically reduce parameters in a module for the user, but it will
-    allow the user to call Reducer(module).reduce() which will immediately reduce all parameters.
-    :class:`apex.parallel.Reducer` is designed to work with
-    the launch utility script ``apex.parallel.multiproc.py`` or the launch utility script
+    :class:`apex.parallel.Reducer` is a simple class that helps allreduce a module's parameters
+    across processes. :class:`Reducer` is intended to give the user additional control:
+    Unlike :class:`DistributedDataParallel`, :class:`Reducer` will not automatically allreduce
+    parameters during ``backward()``.
+    Instead, :class:`Reducer` waits for the user to call `<reducer_instance>.reduce()` manually.
+    This enables, for example, delaying the allreduce to be carried out every
+    several iterations instead of every single iteration.
+    Like :class:`DistributedDataParallel`, :class:`Reducer` averages any tensors it allreduces
+    over the number of participating processes.
+    :class:`Reducer` is designed to work with the launch utility script
+    ``apex.parallel.multiproc.py`` or the upstream launch utility script
     ``torch.distributed.launch`` with --nproc_per_node <= the number of gpus per node.
-    When used with these luanchers, :class:`apex.parallel.multiproc.py`
+    When used with these launchers, :class:`apex.parallel.multiproc.py`
     assumes 1:1 mapping of processes to GPUs.
-    Args:
-        module_or_grads_list: Either a network definition being run in multi-gpu/distributed mode.
-        Or an iterable of gradients to be reduced. If a list of gradients are passed in, user must
-        manually sync parameters with broadcast or another means. If module is passed in, this parameters
-        will be broadcasted from rank 0.
+    main_reducer.py in https://github.com/NVIDIA/apex/tree/master/examples/imagenet shows example usage.
+    Args:
+        module_or_grads_list: Either a network definition (module) being run in
+        multi-gpu/distributed mode, or an iterable of gradients to be reduced.
+        If a module is passed in, the Reducer constructor will sync the parameters across processes
+        (broadcasting from rank 0) to make sure they're all initialized with the same values.
+        If a list of gradients (that came from some module)
+        is passed in, the user is responsible for manually syncing that module's parameters
+        at the beginning of training.
     """
     def __init__(self, module_or_grads_list):
...
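As a rough companion to the updated docstring, here is a minimal usage sketch (not part of this commit). It assumes the script is started by `torch.distributed.launch` or `apex.parallel.multiproc.py` with one process per GPU, and uses a toy model and synthetic data purely for illustration:

```python
import torch
import torch.nn as nn
from apex.parallel import Reducer

# Assumes the launcher started one process per GPU and set the usual env vars.
# (A real script would also call torch.cuda.set_device(args.local_rank).)
torch.distributed.init_process_group(backend='nccl', init_method='env://')

model = nn.Linear(10, 10).cuda()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Passing a module: the constructor broadcasts parameters from rank 0 so every
# process starts from identical weights.
reducer = Reducer(model)

for step in range(100):
    inp = torch.randn(32, 10, device='cuda')
    target = torch.randn(32, 10, device='cuda')
    loss = criterion(model(inp), target)
    optimizer.zero_grad()
    loss.backward()
    # Nothing is allreduced during backward(); the user decides when to
    # average gradients across processes.
    reducer.reduce()
    optimizer.step()
```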
...@@ -12,3 +12,6 @@ apex.parallel
 .. autoclass:: DistributedDataParallel
     :members:
+.. autoclass:: Reducer
+    :members:
...@@ -7,6 +7,9 @@ It implements training of popular model architectures, such as ResNet, AlexNet,
 `main_fp16_optimizer.py` with `--fp16` demonstrates use of `apex.fp16_utils.FP16_Optimizer` to automatically manage master parameters and loss scaling.
+`apex.parallel.DistributedDataParallel` automatically allreduces and averages gradients during `backward()`. If you wish to control the allreduce manually instead (for example, to carry out the allreduce every few iterations instead of every iteration), [apex.parallel.Reducer](https://nvidia.github.io/apex/parallel.html#apex.parallel.Reducer) provides a convenient wrapper.
+`main_reducer.py` is identical to `main.py`, except that it shows the use of `Reducer` instead of `DistributedDataParallel`.
 ## Requirements
 - `pip install -r requirements.txt`
...
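The README paragraph added above describes delaying the allreduce. A hedged sketch of that pattern, reusing the `reducer`, `model`, `criterion`, and `optimizer` names from the sketch earlier on this page (the interval of 4 is an arbitrary illustration, not something `main_reducer.py` does):

```python
reduce_every = 4  # allreduce once every 4 iterations (illustrative choice)

for step in range(100):
    inp = torch.randn(32, 10, device='cuda')
    target = torch.randn(32, 10, device='cuda')
    loss = criterion(model(inp), target)
    loss.backward()                # gradients accumulate locally in .grad
    if (step + 1) % reduce_every == 0:
        reducer.reduce()           # average the accumulated gradients across processes
        optimizer.step()
        optimizer.zero_grad()
```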
...@@ -19,7 +19,7 @@ import torchvision.models as models
 import numpy as np
 try:
-    from apex.parallel import Reducer as DDP
+    from apex.parallel import DistributedDataParallel as DDP
     from apex.fp16_utils import *
 except ImportError:
     raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
...@@ -122,8 +122,7 @@ def main():
         model = network_to_half(model)
     if args.distributed:
         # shared param turns off bucketing in DDP, for lower latency runs this can improve perf
-        global reducer
-        reducer = DDP(model)
+        model = DDP(model)
     global model_params, master_params
     if args.fp16:
...@@ -308,7 +307,6 @@ def train(train_loader, model, criterion, optimizer, epoch):
         if args.fp16:
             model.zero_grad()
             loss.backward()
-            reducer.reduce()
             model_grads_to_master_grads(model_params, master_params)
             if args.static_loss_scale != 1:
                 for param in master_params:
...@@ -318,7 +316,6 @@ def train(train_loader, model, criterion, optimizer, epoch):
         else:
             optimizer.zero_grad()
             loss.backward()
-            reducer.reduce()
             optimizer.step()
         torch.cuda.synchronize()
...
...@@ -98,7 +98,7 @@ def main():
     args.gpu = 0
     args.world_size = 1
     if args.distributed:
         args.gpu = args.local_rank % torch.cuda.device_count()
         torch.cuda.set_device(args.gpu)
...@@ -133,6 +133,8 @@ def main():
     # define loss function (criterion) and optimizer
     criterion = nn.CrossEntropyLoss().cuda()
+    # Scale learning rate based on per-process batch size
+    args.lr = args.lr*float(args.batch_size)/256.
     optimizer = torch.optim.SGD(master_params, args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
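The two added lines above scale the base learning rate by the per-process batch size relative to a reference batch of 256. A quick numeric illustration, assuming the example's usual default of `--lr 0.1` (an assumption, not stated in this hunk):

```python
base_lr = 0.1
for batch_size in (64, 128, 256, 512):
    print(batch_size, base_lr * float(batch_size) / 256.)
# 64 -> 0.025, 128 -> 0.05, 256 -> 0.1, 512 -> 0.2
```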
...@@ -170,23 +172,26 @@ def main():
         # transforms.ToTensor(), Too slow
         # normalize,
     ]))
+    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
+        transforms.Resize(val_size),
+        transforms.CenterCrop(crop_size),
+    ]))
+    train_sampler = None
+    val_sampler = None
     if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
-    else:
-        train_sampler = None
+        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
     train_loader = torch.utils.data.DataLoader(
         train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
         num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)
     val_loader = torch.utils.data.DataLoader(
-        datasets.ImageFolder(valdir, transforms.Compose([
-            transforms.Resize(val_size),
-            transforms.CenterCrop(crop_size),
-        ])),
+        val_dataset,
         batch_size=args.batch_size, shuffle=False,
         num_workers=args.workers, pin_memory=True,
+        sampler=val_sampler,
         collate_fn=fast_collate)
     if args.evaluate:
...@@ -196,7 +201,6 @@ def main():
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        adjust_learning_rate(optimizer, epoch)
         # train for one epoch
         train(train_loader, model, criterion, optimizer, epoch)
...@@ -269,6 +273,8 @@ def train(train_loader, model, criterion, optimizer, epoch):
     while input is not None:
         i += 1
+        adjust_learning_rate(optimizer, epoch, i, len(train_loader))
         if args.prof:
             if i > 10:
                 break
...@@ -425,9 +431,22 @@ class AverageMeter(object):
         self.avg = self.sum / self.count
-def adjust_learning_rate(optimizer, epoch):
-    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
-    lr = args.lr * (0.1 ** (epoch // 30))
+def adjust_learning_rate(optimizer, epoch, step, len_epoch):
+    """LR schedule that should yield 76% converged accuracy with batch size 256"""
+    factor = epoch // 30
+    if epoch >= 80:
+        factor = factor + 1
+    lr = args.lr*(0.1**factor)
+    """Warmup"""
+    if epoch < 5:
+        lr = lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch)
+    # if(args.local_rank == 0):
+    #     print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr))
     for param_group in optimizer.param_groups:
         param_group['lr'] = lr
...
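For reference, the new schedule in standalone form: a 10x decay every 30 epochs, one extra decay once epoch 80 is reached, and linear warmup over the first 5 epochs. This is only a restatement of `adjust_learning_rate` above, with `base_lr` and `len_epoch` passed in instead of read from globals:

```python
def lr_at(base_lr, epoch, step, len_epoch):
    """Learning rate the new adjust_learning_rate would set at this point."""
    factor = epoch // 30
    if epoch >= 80:
        factor += 1                      # extra 10x decay from epoch 80 onward
    lr = base_lr * (0.1 ** factor)
    if epoch < 5:                        # linear warmup over the first 5 epochs
        lr = lr * float(1 + step + epoch * len_epoch) / (5. * len_epoch)
    return lr

# With base_lr = 0.1 and 5005 iterations per epoch:
print(lr_at(0.1, 0, 0, 5005))     # ~4e-6, start of warmup
print(lr_at(0.1, 4, 5004, 5005))  # 0.1, warmup complete
print(lr_at(0.1, 30, 0, 5005))    # 0.01
print(lr_at(0.1, 80, 0, 5005))    # 0.0001 (factor = 2 + 1)
```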