Commit d7d82715 authored by Myle Ott

Fix all-reduce for new versions of PyTorch

We previously assumed that once a model parameter's gradient buffer was allocated, it stayed fixed for the rest of
training. Recent versions of PyTorch violate this assumption: the gradient buffer may be reallocated during training,
so it is no longer safe to rely on it staying in place.

This matters primarily for the all-reduce, since we all-reduce a flattened (i.e., contiguous) copy of the gradients.
We make this more robust by copying the result of the all-reduce back into the model parameters' gradient buffers
after each update. Intra-device copies are cheap, so this doesn't affect performance.
parent 83053f97
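
To make the change easier to follow, here is a minimal, self-contained sketch of the flatten -> all-reduce -> copy-back pattern the commit message describes. It is illustrative only: the helper name all_reduce_grads and the grad_denom argument are invented for this sketch, and it assumes an already-initialized torch.distributed process group, whereas the trainer in the diff below calls nccl.all_reduce from its multiprocessing event loop.

import torch
import torch.distributed as dist


def all_reduce_grads(model, grad_denom=1.0):
    # Re-read p.grad on every call instead of caching the tensors, so this
    # stays correct even if PyTorch reallocates a gradient buffer.
    grads = [p.grad for p in model.parameters() if p.grad is not None]

    # 1) flatten all gradients into one contiguous buffer
    flat = torch.empty(sum(g.numel() for g in grads),
                       dtype=grads[0].dtype, device=grads[0].device)
    offset = 0
    for g in grads:
        flat[offset:offset + g.numel()].copy_(g.reshape(-1))
        offset += g.numel()

    # 2) sum the flattened gradients across workers, then normalize
    dist.all_reduce(flat)
    if grad_denom != 0:
        flat.div_(grad_denom)

    # 3) copy the reduced values back into the (possibly reallocated) buffers
    offset = 0
    for g in grads:
        g.copy_(flat[offset:offset + g.numel()].view_as(g))
        offset += g.numel()

The key point, mirrored in the diff below, is that the gradients are copied back into p.grad on every step rather than aliased into the flat buffer, so a reallocated gradient buffer is picked up automatically.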
@@ -69,7 +69,6 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
 
         # initialize optimizer
         self.optimizer = self._build_optimizer()
-        self.flat_grads = None
         self.loss = None
 
         # initialize LR scheduler
@@ -200,19 +199,21 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # backward pass
         self.loss.backward()
 
-        # flatten grads into a contiguous block of memory
-        if self.flat_grads is None:
-            self.flat_grads = self._flatten_grads_(self.model)
+        # get model parameters as a flattened (contiguous) tensor
+        flat_grads = self._flat_model_grads()
 
         # all-reduce grads
-        nccl.all_reduce(self.flat_grads)
+        nccl.all_reduce(flat_grads)
 
         # normalize grads
         if grad_denom != 0:
-            self.flat_grads.div_(grad_denom)
+            flat_grads.div_(grad_denom)
 
         # clip grads
-        grad_norm = self._clip_grads_(self.flat_grads, self.args.clip_norm)
+        grad_norm = self._clip_grads_(flat_grads, self.args.clip_norm)
 
+        # copy reduced grads back
+        self._set_model_grads_(flat_grads)
+
         # take an optimization step
         self.optimizer.step()
@@ -222,20 +223,34 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         return grad_norm
 
-    def _flatten_grads_(self, model):
-        num_params = sum(p.data.numel() for p in model.parameters())
-        flat_grads = next(model.parameters()).data.new(num_params)
+    def _model_grads(self):
+        return [p.grad for p in self.model.parameters() if p.requires_grad]
+
+    def _flat_model_grads(self):
+        grads = self._model_grads()
+        if not hasattr(self, '_flat_grads'):
+            num_params = sum(g.data.numel() for g in grads)
+            self._flat_grads = grads[0].data.new(num_params)
+        offset = 0
+        for grad in grads:
+            grad = grad.data.view(-1)
+            numel = grad.numel()
+            self._flat_grads[offset:offset+numel].copy_(grad)
+            offset += numel
+        return self._flat_grads
+
+    def _set_model_grads_(self, flat_grads):
+        grads = self._model_grads()
         offset = 0
-        for p in model.parameters():
-            grad = p.grad.data
-            numel, sz = grad.numel(), grad.size()
-            flat_grads[offset:offset+numel] = grad.view(-1)
-            grad.set_(flat_grads[offset:offset+numel])
-            grad.resize_(sz)  # preserve original shape
+        for grad in grads:
+            grad = grad.data.view(-1)
+            numel = grad.numel()
+            grad.copy_(flat_grads[offset:offset+numel])
             offset += numel
-        return flat_grads
+        assert offset == flat_grads.numel()
 
     def _clip_grads_(self, flat_grads, clipv):
         """nn.utils.clip_grad_norm for flattened (contiguous) tensors."""
         norm = flat_grads.norm()
         if clipv > 0 and norm > clipv:
             coef = max(norm, 1e-6) / clipv
......