"vscode:/vscode.git/clone" did not exist on "65d7b21b77fe0d13e6a96fa37f5a6a6c4d0a550f"
Unverified Commit 4ba203d9 authored by Stas Bekman, committed by GitHub

[Trainer] add train loss and flops metrics reports (#11980)

* add train loss and flops metrics reports

* consistency

* add train_loss to skip keys

* restore on_train_end call timing
parent 7ec596ec
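For context, here is a minimal usage sketch of how the reported metrics surface to callers. The `trainer` object and everything behind it are assumptions for illustration, not part of this commit; only the metric keys visible in the diff below are implied by the change itself.

```python
# Hypothetical usage sketch: `trainer` is assumed to be an already-configured
# transformers Trainer instance; this commit only touches how the final
# metrics dict is assembled, logged, and returned.
train_result = trainer.train()
metrics = train_result.metrics

print(metrics["train_loss"])    # accumulated loss / state.global_step, now logged and returned
print(metrics["total_flos"])    # floating-point operations tracked via store_flos()
print(metrics["train_runtime"]) # timing keys produced by speed_metrics("train", ...)
print(metrics["train_samples_per_second"])
```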
@@ -1362,20 +1362,24 @@ class Trainer:
                 self.state.best_model_checkpoint, load_optimizer_states=False, load_lr_scheduler_states=False
             )
 
+        # add remaining tr_loss
+        self._total_loss_scalar += tr_loss.item()
+        train_loss = self._total_loss_scalar / self.state.global_step
+
         metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
         self.store_flos()
         metrics["total_flos"] = self.state.total_flos
-        self.log(metrics)
-
-        self.control = self.callback_handler.on_train_end(args, self.state, self.control)
-        # add remaining tr_loss
-        self._total_loss_scalar += tr_loss.item()
+        metrics["train_loss"] = train_loss
 
         self.is_in_train = False
 
         self._memory_tracker.stop_and_update_metrics(metrics)
 
-        return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step, metrics)
+        self.log(metrics)
+
+        self.control = self.callback_handler.on_train_end(args, self.state, self.control)
+
+        return TrainOutput(self.state.global_step, train_loss, metrics)
 
     def _load_state_dict_in_model(self, state_dict):
         load_result = self.model.load_state_dict(state_dict, strict=False)
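The `speed_metrics("train", start_time, ...)` call kept in this hunk is what produces the run-dependent timing keys that the test change below now skips. As a rough sketch of what such a helper computes, based only on the key names that appear in this commit (not the actual transformers implementation):

```python
import time

def speed_metrics_sketch(split, start_time, num_samples=None, num_steps=None):
    """Approximate stand-in for a speed-metrics helper: derives runtime and
    throughput keys for a split prefix. The rounding and exact behavior are
    assumptions; only the key names match what this commit's test expects."""
    runtime = time.time() - start_time
    result = {f"{split}_runtime": round(runtime, 4)}
    if num_samples is not None:
        result[f"{split}_samples_per_second"] = round(num_samples / runtime, 3)
    if num_steps is not None:
        result[f"{split}_steps_per_second"] = round(num_steps / runtime, 3)
    return result

# e.g. speed_metrics_sketch("train", start, num_samples=1000, num_steps=125)
# -> {"train_runtime": ..., "train_samples_per_second": ..., "train_steps_per_second": ...}
```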
@@ -311,13 +311,11 @@ class TrainerIntegrationCommon:
         log_history = state.pop("log_history", None)
         log_history1 = state1.pop("log_history", None)
         self.assertEqual(state, state1)
+        skip_log_keys = ["train_runtime", "train_samples_per_second", "train_steps_per_second", "train_loss"]
         for log, log1 in zip(log_history, log_history1):
-            _ = log.pop("train_runtime", None)
-            _ = log1.pop("train_runtime", None)
-            _ = log.pop("train_samples_per_second", None)
-            _ = log1.pop("train_samples_per_second", None)
-            _ = log.pop("train_steps_per_second", None)
-            _ = log1.pop("train_steps_per_second", None)
+            for key in skip_log_keys:
+                _ = log.pop(key, None)
+                _ = log1.pop(key, None)
             self.assertEqual(log, log1)
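The new `train_loss` key is popped here for the same reason as the timing keys: its value can legitimately differ between the two runs being compared, so only the deterministic log fields are asserted equal. A toy illustration of the pattern, with invented dictionary contents:

```python
# Run-dependent entries are stripped from both log histories before comparison.
skip_log_keys = ["train_runtime", "train_samples_per_second", "train_steps_per_second", "train_loss"]

log = {"epoch": 1.0, "step": 24, "train_runtime": 12.3, "train_loss": 0.48}
log1 = {"epoch": 1.0, "step": 24, "train_runtime": 11.9, "train_loss": 0.51}

for key in skip_log_keys:
    log.pop(key, None)
    log1.pop(key, None)

# Only the deterministic fields remain, so the two logs now compare equal.
assert log == log1
```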