Unverified Commit 4ba203d9 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[Trainer] add train loss and flops metrics reports (#11980)

* add train loss and flops metrics reports

* consistency

* add train_loss to skip keys

* restore on_train_end call timing
parent 7ec596ec
...@@ -1362,20 +1362,24 @@ class Trainer: ...@@ -1362,20 +1362,24 @@ class Trainer:
self.state.best_model_checkpoint, load_optimizer_states=False, load_lr_scheduler_states=False self.state.best_model_checkpoint, load_optimizer_states=False, load_lr_scheduler_states=False
) )
# add remaining tr_loss
self._total_loss_scalar += tr_loss.item()
train_loss = self._total_loss_scalar / self.state.global_step
metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps) metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
self.store_flos() self.store_flos()
metrics["total_flos"] = self.state.total_flos metrics["total_flos"] = self.state.total_flos
self.log(metrics) metrics["train_loss"] = train_loss
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
# add remaining tr_loss
self._total_loss_scalar += tr_loss.item()
self.is_in_train = False self.is_in_train = False
self._memory_tracker.stop_and_update_metrics(metrics) self._memory_tracker.stop_and_update_metrics(metrics)
return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step, metrics) self.log(metrics)
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
return TrainOutput(self.state.global_step, train_loss, metrics)
def _load_state_dict_in_model(self, state_dict): def _load_state_dict_in_model(self, state_dict):
load_result = self.model.load_state_dict(state_dict, strict=False) load_result = self.model.load_state_dict(state_dict, strict=False)
......
...@@ -311,13 +311,11 @@ class TrainerIntegrationCommon: ...@@ -311,13 +311,11 @@ class TrainerIntegrationCommon:
log_history = state.pop("log_history", None) log_history = state.pop("log_history", None)
log_history1 = state1.pop("log_history", None) log_history1 = state1.pop("log_history", None)
self.assertEqual(state, state1) self.assertEqual(state, state1)
skip_log_keys = ["train_runtime", "train_samples_per_second", "train_steps_per_second", "train_loss"]
for log, log1 in zip(log_history, log_history1): for log, log1 in zip(log_history, log_history1):
_ = log.pop("train_runtime", None) for key in skip_log_keys:
_ = log1.pop("train_runtime", None) _ = log.pop(key, None)
_ = log.pop("train_samples_per_second", None) _ = log1.pop(key, None)
_ = log1.pop("train_samples_per_second", None)
_ = log.pop("train_steps_per_second", None)
_ = log1.pop("train_steps_per_second", None)
self.assertEqual(log, log1) self.assertEqual(log, log1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment