Commit 73b62dde authored by Christian Sarofeen, committed by mcarilli

Add reducer class in parallel/distributed. (#37)

parent dc41c5ce
-from .distributed import DistributedDataParallel
+from .distributed import DistributedDataParallel, Reducer
@@ -34,6 +34,55 @@ def flat_dist_call(tensors, call, extra_args=None):
        for buf, synced in zip(bucket, _unflatten_dense_tensors(coalesced, bucket)):
            buf.copy_(synced)
+
+def extract_tensors(maybe_tensor, tensor_list):
+    if torch.is_tensor(maybe_tensor):
+        tensor_list.append(maybe_tensor)
+    else:
+        try:
+            for item in maybe_tensor:
+                extract_tensors(item, tensor_list)
+        except TypeError:
+            return
+
+class Reducer(object):
+    """
+    :class:`apex.parallel.Reducer` is a simple class that helps reduce a module's parameters.
+    This class does not automatically reduce a module's parameters for the user; instead, it
+    allows the user to call ``Reducer(module).reduce()``, which immediately all-reduces the
+    gradients of the module's parameters (or the supplied list of gradients).
+    :class:`apex.parallel.Reducer` is designed to work with the launch utility script
+    ``apex.parallel.multiproc.py`` or the launch utility script ``torch.distributed.launch``
+    with --nproc_per_node <= the number of GPUs per node.
+    When used with these launchers, :class:`apex.parallel.Reducer`
+    assumes a 1:1 mapping of processes to GPUs.
+
+    Args:
+        module_or_grads_list: Either a network definition (module) being run in
+            multi-gpu/distributed mode, or an iterable of gradients to be reduced.
+            If a list of gradients is passed in, the user must manually synchronize
+            parameters with broadcast or another means. If a module is passed in,
+            its parameters will be broadcast from rank 0.
+    """
+
+    def __init__(self, module_or_grads_list):
+        if isinstance(module_or_grads_list, Module):
+            self.module = module_or_grads_list
+            flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,))
+        else:
+            self.module = None
+            self.grads = []
+            extract_tensors(module_or_grads_list, self.grads)
+
+    def reduce(self):
+        if self.module:
+            grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
+            flat_dist_call(grads, dist.all_reduce)
+        else:
+            flat_dist_call(self.grads, dist.all_reduce)
+
class DistributedDataParallel(Module):
    """
    :class:`apex.parallel.DistributedDataParallel` is a module wrapper that enables
...
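For context (not part of this commit), a minimal sketch of how the new Reducer might be used when a module is passed in. It assumes one process per GPU with the process group initialized from the environment (e.g. by torch.distributed.launch); the tiny linear model, SGD optimizer, and random data are placeholders:

import torch
import torch.nn as nn
import torch.distributed as dist
from apex.parallel import Reducer

# Assumes RANK/WORLD_SIZE etc. are set in the environment by the launcher.
dist.init_process_group(backend="nccl", init_method="env://")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.Linear(10, 10).cuda()                           # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)    # placeholder optimizer

# Passing the module broadcasts its parameters from rank 0 at construction.
reducer = Reducer(model)

data, target = torch.randn(4, 10).cuda(), torch.randn(4, 10).cuda()
optimizer.zero_grad()
loss = nn.functional.mse_loss(model(data), target)
loss.backward()
reducer.reduce()   # explicit all-reduce of the gradients across ranks
optimizer.step()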
@@ -19,7 +19,7 @@ import torchvision.models as models
import numpy as np

try:
-    from apex.parallel import DistributedDataParallel as DDP
+    from apex.parallel import Reducer as DDP
    from apex.fp16_utils import *
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
@@ -122,7 +122,8 @@ def main():
        model = network_to_half(model)
    if args.distributed:
        # shared param turns off bucketing in DDP, for lower latency runs this can improve perf
-        model = DDP(model, shared_param=True)
+        global reducer
+        reducer = DDP(model)

    global model_params, master_params
    if args.fp16:
@@ -307,6 +308,7 @@ def train(train_loader, model, criterion, optimizer, epoch):
        if args.fp16:
            model.zero_grad()
            loss.backward()
+            reducer.reduce()
            model_grads_to_master_grads(model_params, master_params)
            if args.static_loss_scale != 1:
                for param in master_params:
@@ -316,6 +318,7 @@ def train(train_loader, model, criterion, optimizer, epoch):
        else:
            optimizer.zero_grad()
            loss.backward()
+            reducer.reduce()
            optimizer.step()

        torch.cuda.synchronize()
...
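A hedged sketch (not part of this commit) of the other construction path, where a list of gradients rather than a module is handed to Reducer. It assumes the process group is already initialized, and relies on the fact that extract_tensors stores references to exactly the tensors passed in, so later reduce() calls operate on those same buffers; the linear model is a placeholder:

import torch
import torch.nn as nn
from apex.parallel import Reducer

# Assumes torch.distributed is already initialized (one process per GPU).
model = nn.Linear(10, 10).cuda()   # placeholder model

# Run one backward pass so that .grad tensors are allocated.
model(torch.randn(4, 10).cuda()).sum().backward()

# No rank-0 broadcast happens in this path, so parameter synchronization
# (e.g. an explicit broadcast) is the caller's responsibility.
grads = [p.grad.data for p in model.parameters() if p.grad is not None]
reducer = Reducer(grads)

# On each subsequent iteration, after loss.backward():
reducer.reduce()   # all-reduces the same underlying gradient buffers in place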