Unverified Commit 77d6941e authored by Sylvain Gugger, committed by GitHub

Fix gradient clipping for Sharded DDP (#9168)

* Fix gradient clipping for Sharded DDP

* Fix typos in comments
parent 1aca3d6a
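
Background for the fix: under Sharded DDP, gradients are reduce-scattered so each rank ends up holding full gradients only for the parameter shard its optimizer owns. A locally computed `torch.nn.utils.clip_grad_norm_` over `model.parameters()` therefore measures only part of the gradient, and the norm it clips against is not the global one. A minimal sketch of that arithmetic, with hypothetical per-rank norms (not taken from the commit):

```python
import torch

# Hypothetical per-shard gradient norms, one per rank. Each rank can only
# measure the norm of the shard it holds, so clipping locally would compare
# max_grad_norm against 3.0 or 4.0 instead of the true global norm.
shard_norms = torch.tensor([3.0, 4.0])

# The global L2 norm combines the per-shard norms: sqrt(3**2 + 4**2) = 5.0.
global_norm = (shard_norms ** 2).sum().sqrt()
print(global_norm)  # tensor(5.)

# This cross-rank combination is what a sharded optimizer's own
# clip_grad_norm performs (via a distributed reduction) before scaling.
```
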
@@ -804,14 +804,23 @@ class Trainer:
                     steps_in_epoch <= self.args.gradient_accumulation_steps
                     and (step + 1) == steps_in_epoch
                 ):
                     # Gradient clipping
                     if self.args.max_grad_norm is not None and self.args.max_grad_norm > 0:
                         if self.use_amp:
                             # AMP: gradients need unscaling
                             self.scaler.unscale_(self.optimizer)
-                            torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
-                        elif self.use_apex:
-                            torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
+
+                        if hasattr(self.optimizer, "clip_grad_norm"):
+                            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
+                            self.optimizer.clip_grad_norm(self.args.max_grad_norm)
                         else:
-                            torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
+                            # Revert to normal clipping otherwise, handling Apex or full precision
+                            torch.nn.utils.clip_grad_norm_(
+                                amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
+                                self.args.max_grad_norm,
+                            )

                     # Optimizer step
                     if is_torch_tpu_available():
                         xm.optimizer_step(self.optimizer)
                     elif self.use_amp:
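
For reference, here is a sketch of the sharded code path the new `hasattr(self.optimizer, "clip_grad_norm")` check is aimed at. It assumes fairscale is installed and that its `OSS` optimizer (the one Trainer creates when sharded training is enabled) exposes a `clip_grad_norm` method; the single-process gloo group, the toy model, and the address/port are only for illustration:

```python
import torch
import torch.distributed as dist
from fairscale.optim import OSS  # assumes fairscale is installed

# Single-process group purely for illustration; real sharded training runs
# one process per GPU (address and port are arbitrary placeholders).
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

model = torch.nn.Linear(16, 16)
# OSS wraps a regular optimizer and shards its state across ranks.
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)

model(torch.randn(4, 16)).sum().backward()

# This is the branch Trainer now takes: the sharded optimizer computes the
# gradient norm across all ranks before clipping, which the plain
# torch.nn.utils.clip_grad_norm_ call cannot do on sharded gradients.
optimizer.clip_grad_norm(1.0)
optimizer.step()

dist.destroy_process_group()
```

The duck-typed check keeps the non-sharded paths unchanged: when the optimizer has no `clip_grad_norm` of its own, the Trainer falls back to the usual PyTorch utility, covering both Apex and full precision.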