"...resnet50_tensorflow.git" did not exist on "1d5234950217bebbe374e28c62465c08f71874e3"
Unverified Commit f6fe41c9 authored by Benjamin Minixhofer, committed by GitHub

Reset loss to zero on logging in Trainer to avoid bfloat16 issues (#8561)

* make tr_loss regular float

* Revert "make tr_loss regular float"

This reverts commit c9d7ccfaf0c4387187b0841694f01ec0ffd5f4ba.

* reset loss at each logging step

* keep track of total loss with _total_loss_scalar

* add remaining tr_loss at the end
parent b592728e
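For context on the "bfloat16 issues" in the title: bfloat16 keeps only 8 bits of significand, so once a running sum grows large enough, the spacing between representable values exceeds a typical per-step loss and further additions are rounded away entirely. The following is a minimal sketch (not part of this commit) reproducing the failure mode that the patch avoids by draining the accumulator at every logging step:

import torch

# Hedged sketch, not Trainer code: accumulate many small losses into one
# bfloat16 scalar. Once the gap between adjacent representable values
# around the running sum exceeds the addend, rounding absorbs every
# update and the sum stops moving.
acc = torch.tensor(0.0, dtype=torch.bfloat16)
for _ in range(10_000):
    acc = acc + torch.tensor(0.1, dtype=torch.bfloat16)
print(acc.item())  # stalls far below the true sum of ~1000 (around 32)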
@@ -696,8 +696,10 @@ class Trainer:
         self.state.is_local_process_zero = self.is_local_process_zero()
         self.state.is_world_process_zero = self.is_world_process_zero()
 
         # tr_loss is a tensor to avoid synchronization of TPUs through .item()
         tr_loss = torch.tensor(0.0).to(self.args.device)
-        self._logging_loss_scalar = 0
+        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
+        self._total_loss_scalar = 0.0
         self._globalstep_last_logged = 0
         self._total_flos = self.state.total_flos
         model.zero_grad()
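The hunk above preserves the existing comment on why tr_loss stays a tensor. As a hedged aside on that design choice: .item() copies the scalar back to the host and, on lazily executed backends such as XLA/TPU, forces pending device work to flush, so the Trainer defers it to logging steps only. A small illustration with plain CPU tensors standing in for device tensors:

import torch

# Sketch of the sync trade-off (on XLA/TPU the .item() call is the
# expensive device-to-host synchronization point this comment refers to).
tr_loss = torch.tensor(0.0)
for _ in range(100):
    tr_loss += torch.tensor(0.01)  # pure tensor ops: no host sync needed
window = tr_loss.item()            # one sync, performed only when logging
print(window)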
@@ -812,23 +814,26 @@
             self.log({"total_flos": self.state.total_flos})
 
         self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
+        # add remaining tr_loss
+        self._total_loss_scalar += tr_loss.item()
 
-        return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
+        return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step)
 
     def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
         if self.control.should_log:
             logs: Dict[str, float] = {}
             tr_loss_scalar = tr_loss.item()
-            logs["loss"] = (tr_loss_scalar - self._logging_loss_scalar) / (
-                self.state.global_step - self._globalstep_last_logged
-            )
+            # reset tr_loss to zero
+            tr_loss -= tr_loss
+
+            logs["loss"] = tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged)
             # backward compatibility for pytorch schedulers
             logs["learning_rate"] = (
                 self.lr_scheduler.get_last_lr()[0]
                 if version.parse(torch.__version__) >= version.parse("1.4")
                 else self.lr_scheduler.get_lr()[0]
             )
-            self._logging_loss_scalar = tr_loss_scalar
+            self._total_loss_scalar += tr_loss_scalar
             self._globalstep_last_logged = self.state.global_step
 
             self.log(logs)
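Putting both hunks together, here is a self-contained simulation (hypothetical step count and loss values, not Trainer code; the names mirror the diff) of the new two-level bookkeeping: tr_loss only ever holds the losses accumulated since the last logging step, the in-place tr_loss -= tr_loss reset mutates the caller's tensor rather than rebinding a local name, and _total_loss_scalar collects the drained windows as an ordinary Python float for the final average:

import torch

logging_steps = 10
tr_loss = torch.tensor(0.0)   # per-window accumulator, reset on every log
total_loss_scalar = 0.0       # full-run sum, kept at Python-float precision
last_logged = 0

for step in range(1, 101):
    loss = torch.tensor(0.1)              # stand-in for one training step
    tr_loss += loss
    if step % logging_steps == 0:
        window = tr_loss.item()           # the only .item() per window
        total_loss_scalar += window
        tr_loss -= tr_loss                # in-place reset, same tensor object
        print(f"step {step}: loss {window / (step - last_logged):.3f}")
        last_logged = step

total_loss_scalar += tr_loss.item()       # "add remaining tr_loss"
print("TrainOutput training_loss:", total_loss_scalar / 100)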