Commit d7d82715 authored by Myle Ott

Fix all-reduce for new versions of PyTorch

We previously assumed that once a model parameter's gradient buffer was allocated, it stayed fixed during training.
Recent versions of PyTorch may reallocate the gradient buffer during training, so this is no longer a safe assumption
to make.

This matters primarily for the all-reduce, since we all-reduce a flattened (i.e., contiguous) copy of the gradients.
We make this more robust by copying the result of the all-reduce back into the model parameters' gradient buffers
after each update. Intra-device copies are cheap, so this doesn't affect performance.
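
For illustration, here is a minimal sketch of that flatten -> all-reduce -> copy-back pattern. The helper names (flatten_grads, copy_back_grads) and the torch.distributed.all_reduce call are assumptions for the sketch only; the actual change goes through the trainer's _flat_model_grads() / _set_model_grads_() helpers and torch.cuda.nccl, as shown in the diff below.

# Sketch only: helper names are illustrative, not the trainer's actual methods.
import torch
import torch.nn as nn

def flatten_grads(model, buf=None):
    # Copy every parameter's .grad into one contiguous 1-D buffer,
    # re-reading the (possibly reallocated) .grad tensors on each call.
    grads = [p.grad for p in model.parameters() if p.requires_grad]
    total = sum(g.numel() for g in grads)
    if buf is None or buf.numel() != total:
        buf = grads[0].new(total)
    offset = 0
    for g in grads:
        buf[offset:offset + g.numel()].copy_(g.view(-1))
        offset += g.numel()
    return buf

def copy_back_grads(model, flat):
    # Write the reduced flat buffer back into each parameter's current
    # .grad tensor, so the optimizer sees the reduced gradients even if
    # the buffers were reallocated since the last step.
    offset = 0
    for p in model.parameters():
        if not p.requires_grad:
            continue
        n = p.grad.numel()
        p.grad.view(-1).copy_(flat[offset:offset + n])
        offset += n
    assert offset == flat.numel()

# Usage (single process, so the all-reduce is shown as a comment only):
model = nn.Linear(4, 2)
model(torch.randn(3, 4)).sum().backward()
flat = flatten_grads(model)
# torch.distributed.all_reduce(flat)   # sum gradients across workers
flat.div_(2)                           # e.g. normalize by number of workers
copy_back_grads(model, flat)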
parent 83053f97
@@ -69,7 +69,6 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # initialize optimizer
         self.optimizer = self._build_optimizer()
-        self.flat_grads = None
         self.loss = None
 
         # initialize LR scheduler
@@ -200,19 +199,21 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # backward pass
         self.loss.backward()
 
-        # flatten grads into a contiguous block of memory
-        if self.flat_grads is None:
-            self.flat_grads = self._flatten_grads_(self.model)
+        # get model gradients as a flattened (contiguous) tensor
+        flat_grads = self._flat_model_grads()
 
         # all-reduce grads
-        nccl.all_reduce(self.flat_grads)
+        nccl.all_reduce(flat_grads)
 
         # normalize grads
         if grad_denom != 0:
-            self.flat_grads.div_(grad_denom)
+            flat_grads.div_(grad_denom)
 
         # clip grads
-        grad_norm = self._clip_grads_(self.flat_grads, self.args.clip_norm)
+        grad_norm = self._clip_grads_(flat_grads, self.args.clip_norm)
+
+        # copy reduced grads back into the model's gradient buffers
+        self._set_model_grads_(flat_grads)
 
         # take an optimization step
         self.optimizer.step()
@@ -222,20 +223,34 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         return grad_norm
 
-    def _flatten_grads_(self, model):
-        num_params = sum(p.data.numel() for p in model.parameters())
-        flat_grads = next(model.parameters()).data.new(num_params)
-        offset = 0
-        for p in model.parameters():
-            grad = p.grad.data
-            numel, sz = grad.numel(), grad.size()
-            flat_grads[offset:offset+numel] = grad.view(-1)
-            grad.set_(flat_grads[offset:offset+numel])
-            grad.resize_(sz)  # preserve original shape
-            offset += numel
-        return flat_grads
+    def _model_grads(self):
+        return [p.grad for p in self.model.parameters() if p.requires_grad]
+
+    def _flat_model_grads(self):
+        grads = self._model_grads()
+        if not hasattr(self, '_flat_grads'):
+            num_params = sum(g.data.numel() for g in grads)
+            self._flat_grads = grads[0].data.new(num_params)
+        offset = 0
+        for grad in grads:
+            grad = grad.data.view(-1)
+            numel = grad.numel()
+            self._flat_grads[offset:offset+numel].copy_(grad)
+            offset += numel
+        return self._flat_grads
+
+    def _set_model_grads_(self, flat_grads):
+        grads = self._model_grads()
+        offset = 0
+        for grad in grads:
+            grad = grad.data.view(-1)
+            numel = grad.numel()
+            grad.copy_(flat_grads[offset:offset+numel])
+            offset += numel
+        assert offset == flat_grads.numel()
 
     def _clip_grads_(self, flat_grads, clipv):
+        """nn.utils.clip_grad_norm for flattened (contiguous) tensors."""
         norm = flat_grads.norm()
         if clipv > 0 and norm > clipv:
             coef = max(norm, 1e-6) / clipv
...