Unverified Commit f6fe41c9 authored by Benjamin Minixhofer's avatar Benjamin Minixhofer Committed by GitHub
Browse files

Reset loss to zero on logging in Trainer to avoid bfloat16 issues (#8561)

* make tr_loss regular float

* Revert "make tr_loss regular float"

This reverts commit c9d7ccfaf0c4387187b0841694f01ec0ffd5f4ba.

* reset loss at each logging step

* keep track of total loss with _total_loss_scalar

* add remaining tr_loss at the end
parent b592728e
...@@ -696,8 +696,10 @@ class Trainer: ...@@ -696,8 +696,10 @@ class Trainer:
self.state.is_local_process_zero = self.is_local_process_zero() self.state.is_local_process_zero = self.is_local_process_zero()
self.state.is_world_process_zero = self.is_world_process_zero() self.state.is_world_process_zero = self.is_world_process_zero()
# tr_loss is a tensor to avoid synchronization of TPUs through .item()
tr_loss = torch.tensor(0.0).to(self.args.device) tr_loss = torch.tensor(0.0).to(self.args.device)
self._logging_loss_scalar = 0 # _total_loss_scalar is updated every time .item() has to be called on tr_loss and stores the sum of all losses
self._total_loss_scalar = 0.0
self._globalstep_last_logged = 0 self._globalstep_last_logged = 0
self._total_flos = self.state.total_flos self._total_flos = self.state.total_flos
model.zero_grad() model.zero_grad()
...@@ -812,23 +814,26 @@ class Trainer: ...@@ -812,23 +814,26 @@ class Trainer:
self.log({"total_flos": self.state.total_flos}) self.log({"total_flos": self.state.total_flos})
self.control = self.callback_handler.on_train_end(self.args, self.state, self.control) self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
# add remaining tr_loss
self._total_loss_scalar += tr_loss.item()
return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step) return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step)
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch): def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
if self.control.should_log: if self.control.should_log:
logs: Dict[str, float] = {} logs: Dict[str, float] = {}
tr_loss_scalar = tr_loss.item() tr_loss_scalar = tr_loss.item()
logs["loss"] = (tr_loss_scalar - self._logging_loss_scalar) / ( # reset tr_loss to zero
self.state.global_step - self._globalstep_last_logged tr_loss -= tr_loss
)
logs["loss"] = tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged)
# backward compatibility for pytorch schedulers # backward compatibility for pytorch schedulers
logs["learning_rate"] = ( logs["learning_rate"] = (
self.lr_scheduler.get_last_lr()[0] self.lr_scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4") if version.parse(torch.__version__) >= version.parse("1.4")
else self.lr_scheduler.get_lr()[0] else self.lr_scheduler.get_lr()[0]
) )
self._logging_loss_scalar = tr_loss_scalar self._total_loss_scalar += tr_loss_scalar
self._globalstep_last_logged = self.state.global_step self._globalstep_last_logged = self.state.global_step
self.log(logs) self.log(logs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment