Commit 7c07e87c authored by Myle Ott

All-reduce in FP16

parent fc312d28
@@ -11,7 +11,7 @@ Train a network on multiple GPUs.
 import torch
-from fairseq import optim
+from fairseq import optim, utils
 from fairseq.meters import AverageMeter
 from fairseq.optim import lr_scheduler
 from fairseq.trainer import Trainer
@@ -105,8 +105,26 @@ class FP16Trainer(Trainer):
         # undo effect of dynamic loss scaling on gradients
         grad_denom *= self.scaler.loss_scale
-        # all-reduce and rescale gradients
-        grad_norm = super()._all_reduce_and_rescale(grad_denom)
+        if self.args.distributed_world_size > 1:
+            # flatten grads into a single buffer
+            flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)
+            # scale gradients to avoid overflow in all-reduce
+            flat_grads.div_(self.args.distributed_world_size)
+            grad_denom /= self.args.distributed_world_size
+            # all-reduce flat grads
+            torch.distributed.all_reduce(flat_grads)
+            # copy grads back to FP32
+            self.fp32_params.grad.data.copy_(flat_grads)
+        else:
+            # single worker: copy grads directly to FP32
+            self._get_flat_grads(out=self.fp32_params.grad.data)
+        # rescale and clip grads
+        self.fp32_params.grad.data.div_(grad_denom)
+        grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, self.args.clip_norm)
         # detect overflow and adjust loss scale
         overflow = DynamicLossScaler.has_overflow(grad_norm)
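Note: the overflow check above uses fairseq's DynamicLossScaler, which this commit does not modify. As a rough illustration of what such a check typically amounts to (an inf/NaN test on the gradient norm), a minimal sketch follows; this is an assumption about the behaviour, not the exact fairseq implementation.

import math

def has_overflow(grad_norm):
    # an inf or NaN gradient norm means the scaled FP16 gradients overflowed;
    # the trainer then skips the update and lowers the dynamic loss scale
    return math.isinf(grad_norm) or math.isnan(grad_norm)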
@@ -116,15 +134,6 @@ class FP16Trainer(Trainer):
         return grad_norm
-    def _get_flat_grads(self, out=None):
-        if out is None:
-            out = self.fp32_params.grad
-        return super()._get_flat_grads(out)
-    def _set_flat_grads(self, new_grads):
-        # no-op
-        assert new_grads.data_ptr() == self.fp32_params.grad.data.data_ptr()
     def _opt(self):
         # take an optimization step using the FP32 params and grads
         super()._opt()
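Taken together, the change moves the gradient all-reduce into FP16: gradients are flattened into a single half-precision buffer, pre-divided by the number of workers so the sum stays in FP16 range, all-reduced once, and only then copied into the FP32 master gradients for unscaling, clipping, and the optimizer step. A rough standalone sketch of that pattern is given below; the function and argument names are hypothetical, and it uses a manual norm-based clip rather than fairseq's utils.clip_grad_norm_.

import torch
import torch.distributed as dist

def all_reduce_fp16_grads(fp16_params, fp32_flat_grad, world_size, loss_scale, clip_norm):
    # flatten the FP16 grads so a single all-reduce covers the whole model
    flat_grads = torch.cat([p.grad.data.view(-1) for p in fp16_params])
    if world_size > 1:
        # divide before summing so the FP16 values stay in range
        flat_grads.div_(world_size)
        dist.all_reduce(flat_grads)
    # copy into the FP32 master gradient buffer and undo the loss scaling
    fp32_flat_grad.copy_(flat_grads)
    fp32_flat_grad.div_(loss_scale)
    # clip in FP32; the returned norm can feed an overflow check like the one sketched above
    grad_norm = fp32_flat_grad.norm().item()
    if clip_norm > 0 and grad_norm > clip_norm:
        fp32_flat_grad.mul_(clip_norm / (grad_norm + 1e-6))
    return grad_norm

Communicating the flattened gradients in FP16 halves the all-reduce volume compared with FP32, which is presumably the point of doing the all-reduce before the FP32 copy rather than after.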