Commit 73b62dde authored by Christian Sarofeen, committed by mcarilli

Add reducer class in parallel/distributed. (#37)

parent dc41c5ce
-from .distributed import DistributedDataParallel
+from .distributed import DistributedDataParallel, Reducer
@@ -34,6 +34,55 @@ def flat_dist_call(tensors, call, extra_args=None):
         for buf, synced in zip(bucket, _unflatten_dense_tensors(coalesced, bucket)):
             buf.copy_(synced)
+def extract_tensors(maybe_tensor, tensor_list):
+    # Recursively collect tensors from a (possibly nested) iterable into tensor_list.
+    if torch.is_tensor(maybe_tensor):
+        tensor_list.append(maybe_tensor)
+    else:
+        try:
+            for item in maybe_tensor:
+                extract_tensors(item, tensor_list)
+        except TypeError:
+            return
+
+class Reducer(object):
+    """
+    :class:`apex.parallel.Reducer` is a simple class that helps reduce a module's parameters.
+    This class will not automatically reduce parameters in a module for the user; instead, it
+    allows the user to call Reducer(module).reduce(), which immediately reduces all parameters.
+    :class:`apex.parallel.Reducer` is designed to work with the launch utility script
+    ``apex.parallel.multiproc.py`` or the launch utility script ``torch.distributed.launch``
+    with --nproc_per_node <= the number of GPUs per node.
+    When used with these launchers, :class:`apex.parallel.Reducer`
+    assumes a 1:1 mapping of processes to GPUs.
+
+    Args:
+        module_or_grads_list: Either a network definition being run in multi-GPU/distributed
+            mode, or an iterable of gradients to be reduced. If a list of gradients is passed
+            in, the user must manually sync parameters with broadcast or another means. If a
+            module is passed in, its parameters will be broadcast from rank 0.
+    """
+
+    def __init__(self, module_or_grads_list):
+        if isinstance(module_or_grads_list, Module):
+            self.module = module_or_grads_list
+            # Sync parameters across ranks by broadcasting from rank 0.
+            flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,))
+        else:
+            self.module = None
+            self.grads = []
+            extract_tensors(module_or_grads_list, self.grads)
+
+    def reduce(self):
+        if self.module:
+            grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
+            flat_dist_call(grads, dist.all_reduce)
+        else:
+            flat_dist_call(self.grads, dist.all_reduce)
+
 class DistributedDataParallel(Module):
     """
     :class:`apex.parallel.DistributedDataParallel` is a module wrapper that enables
......
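
For orientation, here is a minimal usage sketch of the new class in its module mode. Only Reducer(...) and .reduce() come from the diff above; the network, optimizer, and synthetic data are illustrative placeholders, and the process group is assumed to have been initialized by one of the supported launchers.

# Illustrative sketch (not part of this commit): module mode of apex.parallel.Reducer.
# Assumes one process per GPU and an already-initialized torch.distributed process group,
# e.g. set up by apex.parallel.multiproc.py or torch.distributed.launch.
import torch
from apex.parallel import Reducer

model = torch.nn.Linear(1024, 1024).cuda()    # placeholder network
reducer = Reducer(model)                      # broadcasts parameters from rank 0 at construction
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(10):                           # placeholder training loop with synthetic data
    inp = torch.randn(32, 1024).cuda()
    target = torch.randn(32, 1024).cuda()
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inp), target)
    loss.backward()
    reducer.reduce()                          # all-reduce gradients across ranks
    optimizer.step()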
@@ -19,7 +19,7 @@ import torchvision.models as models
 import numpy as np
 try:
-    from apex.parallel import DistributedDataParallel as DDP
+    from apex.parallel import Reducer as DDP
     from apex.fp16_utils import *
 except ImportError:
     raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
@@ -122,7 +122,8 @@ def main():
         model = network_to_half(model)
     if args.distributed:
         # shared param turns off bucketing in DDP, for lower latency runs this can improve perf
-        model = DDP(model, shared_param=True)
+        global reducer
+        reducer = DDP(model)
     global model_params, master_params
     if args.fp16:
@@ -307,6 +308,7 @@ def train(train_loader, model, criterion, optimizer, epoch):
         if args.fp16:
             model.zero_grad()
             loss.backward()
+            reducer.reduce()
             model_grads_to_master_grads(model_params, master_params)
             if args.static_loss_scale != 1:
                 for param in master_params:
@@ -316,6 +318,7 @@ def train(train_loader, model, criterion, optimizer, epoch):
         else:
             optimizer.zero_grad()
             loss.backward()
+            reducer.reduce()
             optimizer.step()
         torch.cuda.synchronize()
......
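
The constructor also accepts an iterable of gradients instead of a module; in that mode parameter synchronization is left to the user, as the docstring notes. A hedged sketch of that path (the network, broadcast step, loss, and optimizer are illustrative placeholders, not part of this commit):

# Illustrative sketch (not part of this commit): gradient-list mode of apex.parallel.Reducer.
# Assumes an initialized process group, as in the sketch above.
import torch
import torch.distributed as dist
from apex.parallel import Reducer

model = torch.nn.Linear(1024, 1024).cuda()    # placeholder network
params = list(model.parameters())
for p in params:
    dist.broadcast(p.data, 0)                 # manual parameter sync is the user's responsibility here

optimizer = torch.optim.SGD(params, lr=0.1)
loss = model(torch.randn(32, 1024).cuda()).sum()
loss.backward()                               # populates p.grad
grads = [p.grad.data for p in params if p.grad is not None]
Reducer(grads).reduce()                       # all-reduce exactly the listed gradient tensors
optimizer.step()

Because extract_tensors stores references to the tensors it is given, the gradient tensors must already exist (for example, after a first backward pass) when the Reducer is constructed if a single instance is to be reused across iterations.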