Unverified commit 4f09d0fd, authored by Shijie Wu and committed by GitHub

storing & logging gradient norm in trainer (#27326)

* report grad_norm during training

* support getting grad_norm from deepspeed
parent a4851d94
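
With this change, the logs dictionary built in _maybe_log_save_evaluate carries a "grad_norm" entry (whenever gradient clipping was applied) alongside "loss" and "learning_rate", so every logging integration and callback receives it. Below is a minimal, hedged sketch of reading it from a callback; GradNormPrinter is an illustrative name and not part of this PR.

from transformers import TrainerCallback

class GradNormPrinter(TrainerCallback):
    """Illustrative callback: prints the gradient norm reported at each logging step."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        # `logs` is the dict assembled in Trainer._maybe_log_save_evaluate;
        # after this PR it contains "grad_norm" whenever clipping ran.
        if logs and "grad_norm" in logs:
            print(f"step {state.global_step}: grad_norm = {logs['grad_norm']}")

# Usage (sketch): trainer = Trainer(..., callbacks=[GradNormPrinter()])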
@@ -198,6 +198,7 @@ if is_accelerate_available():
     from accelerate import __version__ as accelerate_version
     from accelerate.utils import (
         DistributedDataParallelKwargs,
+        DistributedType,
         GradientAccumulationPlugin,
         load_fsdp_model,
         load_fsdp_optimizer,
@@ -1856,6 +1857,7 @@ class Trainer:
         self._total_loss_scalar = 0.0
         self._globalstep_last_logged = self.state.global_step
         model.zero_grad()
+        grad_norm: Optional[float] = None
 
         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
@@ -1973,19 +1975,27 @@ class Trainer:
                         # deepspeed does its own clipping
                         if is_sagemaker_mp_enabled() and args.fp16:
-                            self.optimizer.clip_master_grads(args.max_grad_norm)
+                            _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
                         elif self.use_apex:
                             # Revert to normal clipping otherwise, handling Apex or full precision
-                            nn.utils.clip_grad_norm_(
+                            _grad_norm = nn.utils.clip_grad_norm_(
                                 amp.master_params(self.optimizer),
                                 args.max_grad_norm,
                             )
                         else:
-                            self.accelerator.clip_grad_norm_(
+                            _grad_norm = self.accelerator.clip_grad_norm_(
                                 model.parameters(),
                                 args.max_grad_norm,
                             )
 
+                        if (
+                            is_accelerate_available()
+                            and self.accelerator.distributed_type == DistributedType.DEEPSPEED
+                        ):
+                            grad_norm = model.get_global_grad_norm()
+                        else:
+                            grad_norm = _grad_norm.item() if _grad_norm is not None else None
+
                     # Optimizer step
                     self.optimizer.step()
                     optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
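
Note on the hunk above: torch.nn.utils.clip_grad_norm_ (and the general path of accelerate's Accelerator.clip_grad_norm_, which wraps it) returns the total gradient norm as a tensor, which is why the new code converts _grad_norm with .item(); under DeepSpeed, which does its own clipping, the norm is instead read from the engine via get_global_grad_norm(). A standalone sketch of that return value, independent of the Trainer:

import torch
import torch.nn as nn

# Toy model and a single backward pass, just to populate .grad.
model = nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).sum()
loss.backward()

# clip_grad_norm_ clips gradients in place and returns the total
# (pre-clipping) gradient norm as a 0-dim tensor.
total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(total_norm.item())  # plain Python float, ready for a logs dict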
@@ -1999,7 +2009,7 @@ class Trainer:
                     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
 
-                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
+                    self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
                 else:
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
@@ -2019,7 +2029,7 @@ class Trainer:
                 self.control.should_training_stop = True
 
             self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
+            self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
 
             if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                 if is_torch_tpu_available():
@@ -2356,7 +2366,7 @@ class Trainer:
                     f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
                 )
 
-    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
+    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
         if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
             if is_torch_tpu_available():
                 xm.mark_step()
@@ -2370,6 +2380,8 @@ class Trainer:
             tr_loss -= tr_loss
 
             logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
+            if grad_norm is not None:
+                logs["grad_norm"] = grad_norm
             logs["learning_rate"] = self._get_learning_rate()
 
             self._total_loss_scalar += tr_loss_scalar
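
Since Trainer.log appends each logs dict to state.log_history, the reported norms can also be inspected after training. A sketch, assuming trainer is a transformers.Trainer instance that has already run train() with max_grad_norm set:

# Assumes `trainer` has finished training; each logged step that applied
# clipping contributes an entry with a "grad_norm" key.
for entry in trainer.state.log_history:
    if "grad_norm" in entry:
        print(f"step {entry['step']}: grad_norm = {entry['grad_norm']}")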