Unverified Commit 8f1c960e authored by Abi See, committed by GitHub

Fix two bugs with --logging_first_step (#8193)

* make sure that logging_first_step evaluates

* fix bug with incorrect loss on logging_first_step

* fix style

* logging_first_step only logs, not evals
parent 689ff74f
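
For context on the loss bug: with `logging_first_step=True` the first log fires at `global_step == 1`, but the old code divided the accumulated loss by the fixed `logging_steps`, so the first reported loss was too small by roughly that factor. A minimal sketch with hypothetical numbers (a constant per-step loss of 2.0 is assumed; this is not the Trainer code itself):

```python
# Hypothetical numbers: accumulated loss after one optimizer step,
# with the default logging_steps of 500.
tr_loss_scalar, logging_loss_scalar = 2.0, 0.0
logging_steps = 500

old = (tr_loss_scalar - logging_loss_scalar) / logging_steps  # 0.004, wrong
new = (tr_loss_scalar - logging_loss_scalar) / (1 - 0)        # 2.0: divide by
print(old, new)                                               # steps since last log
```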
src/transformers/trainer.py

```diff
@@ -729,6 +729,7 @@ class Trainer:
         tr_loss = torch.tensor(0.0).to(self.args.device)
         self._logging_loss_scalar = 0
+        self._globalstep_last_logged = 0
         self._total_flos = self.state.total_flos
         model.zero_grad()
@@ -849,7 +850,9 @@ class Trainer:
             if self.control.should_log:
                 logs: Dict[str, float] = {}
                 tr_loss_scalar = tr_loss.item()
-                logs["loss"] = (tr_loss_scalar - self._logging_loss_scalar) / self.args.logging_steps
+                logs["loss"] = (tr_loss_scalar - self._logging_loss_scalar) / (
+                    self.state.global_step - self._globalstep_last_logged
+                )
                 # backward compatibility for pytorch schedulers
                 logs["learning_rate"] = (
                     self.lr_scheduler.get_last_lr()[0]
@@ -857,6 +860,7 @@ class Trainer:
                     else self.lr_scheduler.get_lr()[0]
                 )
                 self._logging_loss_scalar = tr_loss_scalar
+                self._globalstep_last_logged = self.state.global_step
                 self.log(logs)
```
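
The pattern the fix adopts generalizes beyond the first step: every log divides by the number of optimizer steps since the previous log, which the new `_globalstep_last_logged` attribute tracks. A hedged sketch of that bookkeeping (the trigger condition paraphrases what the default flow callback does and may not match it line for line; the constant per-step loss is an assumption):

```python
def should_log(step, logging_steps=500, logging_first_step=True):
    # Paraphrased trigger: log at step 1 when logging_first_step is set,
    # and every logging_steps thereafter.
    return (logging_first_step and step == 1) or step % logging_steps == 0

tr_loss, logging_loss, last_logged = 0.0, 0.0, 0
for step in range(1, 1001):
    tr_loss += 2.0  # assumed constant per-step loss
    if should_log(step):
        # Divide by elapsed steps (1, then 499, then 500), not a fixed 500.
        print(step, (tr_loss - logging_loss) / (step - last_logged))  # always 2.0
        logging_loss, last_logged = tr_loss, step
```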
src/transformers/training_args.py

```diff
@@ -250,7 +250,7 @@ class TrainingArguments:
     warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
     logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
-    logging_first_step: bool = field(default=False, metadata={"help": "Log and eval the first global_step"})
+    logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
     logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
     save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
     save_total_limit: Optional[int] = field(
```
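
After this change, enabling the flag costs one extra log entry and nothing else, since it no longer triggers evaluation. A minimal usage sketch (`out` is a hypothetical output directory):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",         # hypothetical path
    logging_first_step=True,  # log at global_step 1; no eval is triggered
    logging_steps=500,        # then log every 500 update steps
)
```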