"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "87a09d66f357c4afb2e2ed7fdabfbf46b485b815"
Commit 7c07e87c authored by Myle Ott

All-reduce in FP16

parent fc312d28
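
This commit switches the distributed gradient all-reduce to run directly on the flattened FP16 gradients: the half-precision buffer is pre-divided by the world size so the reduced sum cannot overflow, all-reduced once, and only then copied into the FP32 master gradients that the optimizer actually updates (grad_denom is divided by the same factor, so the overall normalization is unchanged). A minimal standalone sketch of that pattern, using hypothetical helper names rather than fairseq's API and assuming torch.distributed.init_process_group() has already been called:

```python
import torch
import torch.distributed as dist


def all_reduce_fp16_grads(fp16_params, fp32_flat_grad):
    # hypothetical helper, not part of fairseq: sum FP16 grads across workers
    world_size = dist.get_world_size()

    # flatten every FP16 gradient into one contiguous half-precision buffer,
    # so a single collective moves half the bytes of an FP32 all-reduce
    flat = torch.cat([p.grad.data.reshape(-1) for p in fp16_params])

    # pre-divide so the sum over workers stays inside FP16 range; the caller
    # compensates by shrinking its gradient denominator by the same factor
    flat.div_(world_size)
    dist.all_reduce(flat)

    # copy the reduced FP16 gradients into the FP32 buffer the optimizer sees
    fp32_flat_grad.copy_(flat)
    return fp32_flat_grad
```

Flattening keeps the communication to a single collective call per step, and reducing in FP16 halves the traffic compared with all-reducing the FP32 copies.
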
@@ -11,7 +11,7 @@ Train a network on multiple GPUs.
 
 import torch
 
-from fairseq import optim
+from fairseq import optim, utils
 from fairseq.meters import AverageMeter
 from fairseq.optim import lr_scheduler
 from fairseq.trainer import Trainer
@@ -105,8 +105,26 @@ class FP16Trainer(Trainer):
         # undo effect of dynamic loss scaling on gradients
         grad_denom *= self.scaler.loss_scale
 
-        # all-reduce and rescale gradients
-        grad_norm = super()._all_reduce_and_rescale(grad_denom)
+        if self.args.distributed_world_size > 1:
+            # flatten grads into a single buffer
+            flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)
+
+            # scale gradients to avoid overflow in all-reduce
+            flat_grads.div_(self.args.distributed_world_size)
+            grad_denom /= self.args.distributed_world_size
+
+            # all-reduce flat grads
+            torch.distributed.all_reduce(flat_grads)
+
+            # copy grads back to FP32
+            self.fp32_params.grad.data.copy_(flat_grads)
+        else:
+            # single worker: copy grads directly to FP32
+            self._get_flat_grads(out=self.fp32_params.grad.data)
+
+        # rescale and clip grads
+        self.fp32_params.grad.data.div_(grad_denom)
+        grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, self.args.clip_norm)
 
         # detect overflow and adjust loss scale
         overflow = DynamicLossScaler.has_overflow(grad_norm)
@@ -116,15 +134,6 @@ class FP16Trainer(Trainer):
         return grad_norm
 
-    def _get_flat_grads(self, out=None):
-        if out is None:
-            out = self.fp32_params.grad
-        return super()._get_flat_grads(out)
-
-    def _set_flat_grads(self, new_grads):
-        # no-op
-        assert new_grads.data_ptr() == self.fp32_params.grad.data.data_ptr()
-
     def _opt(self):
         # take an optimization step using the FP32 params and grads
         super()._opt()
 
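
After the reduce, the FP32 master gradients are rescaled by grad_denom (which already folds in the dynamic loss scale), clipped by global norm, and the resulting norm doubles as the overflow check: an inf or NaN norm means the loss scale was too aggressive and the step should be skipped. A rough sketch of that tail end of the step, with hypothetical names and a hand-rolled clip standing in for fairseq's utils.clip_grad_norm_:

```python
import math
import torch


def unscale_clip_and_check(fp32_grad, grad_denom, clip_norm):
    # hypothetical helper, not part of fairseq
    # grad_denom already includes the dynamic loss scale (see the hunk above),
    # so one division undoes both the loss scaling and the batch normalization
    fp32_grad.div_(grad_denom)

    # global L2 norm of the flat gradient buffer, clipped if too large
    grad_norm = fp32_grad.norm().item()
    if clip_norm > 0 and grad_norm > clip_norm:
        fp32_grad.mul_(clip_norm / grad_norm)

    # an inf or NaN norm signals FP16 overflow: skip the update, lower the scale
    overflow = math.isinf(grad_norm) or math.isnan(grad_norm)
    return grad_norm, overflow
```

With the reduced gradients now written explicitly into self.fp32_params.grad, the _get_flat_grads/_set_flat_grads overrides deleted in the last hunk have no remaining role, which is presumably why they were removed.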