Unverified Commit 2ce3ddab authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Small fixes to NotebookProgressCallback (#7813)

parent 6f45dd2f
...@@ -153,7 +153,7 @@ try: ...@@ -153,7 +153,7 @@ try:
import IPython # noqa: F401 import IPython # noqa: F401
_in_notebook = True _in_notebook = True
except: # noqa: E722 except (AttributeError, ImportError, KeyError):
_in_notebook = False _in_notebook = False
......
...@@ -19,6 +19,7 @@ from typing import Optional ...@@ -19,6 +19,7 @@ from typing import Optional
import IPython.display as disp import IPython.display as disp
from ..trainer_callback import TrainerCallback from ..trainer_callback import TrainerCallback
from ..trainer_utils import EvaluationStrategy
def format_time(t): def format_time(t):
...@@ -146,7 +147,7 @@ class NotebookProgressBar: ...@@ -146,7 +147,7 @@ class NotebookProgressBar:
self.first_calls = self.warmup self.first_calls = self.warmup
self.wait_for = 1 self.wait_for = 1
self.update_bar(value) self.update_bar(value)
elif value <= self.last_value: elif value <= self.last_value and not force_update:
return return
elif force_update or self.first_calls > 0 or value >= min(self.last_value + self.wait_for, self.total): elif force_update or self.first_calls > 0 or value >= min(self.last_value + self.wait_for, self.total):
if self.first_calls > 0: if self.first_calls > 0:
...@@ -272,17 +273,25 @@ class NotebookProgressCallback(TrainerCallback): ...@@ -272,17 +273,25 @@ class NotebookProgressCallback(TrainerCallback):
def __init__(self): def __init__(self):
self.training_tracker = None self.training_tracker = None
self.prediction_bar = None self.prediction_bar = None
self._force_next_update = False
def on_train_begin(self, args, state, control, **kwargs): def on_train_begin(self, args, state, control, **kwargs):
self.first_column = "Epoch" if args.max_steps <= 0 else "Step" self.first_column = "Epoch" if args.evaluation_strategy == EvaluationStrategy.EPOCH else "Step"
self.training_loss = 0 self.training_loss = 0
self.last_log = 0 self.last_log = 0
column_names = [self.first_column] + ["Training Loss", "Validation Loss"] column_names = [self.first_column] + ["Training Loss"]
if args.evaluation_strategy != EvaluationStrategy.NO:
column_names.append("Validation Loss")
self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names) self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names)
def on_step_end(self, args, state, control, **kwargs): def on_step_end(self, args, state, control, **kwargs):
epoch = int(state.epoch) if int(state.epoch) == state.epoch else f"{state.epoch:.2f}" epoch = int(state.epoch) if int(state.epoch) == state.epoch else f"{state.epoch:.2f}"
self.training_tracker.update(state.global_step + 1, comment=f"Epoch {epoch}/{state.num_train_epochs}") self.training_tracker.update(
state.global_step + 1,
comment=f"Epoch {epoch}/{state.num_train_epochs}",
force_update=self._force_next_update,
)
self._force_next_update = False
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs): def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
if self.prediction_bar is None: if self.prediction_bar is None:
...@@ -294,6 +303,14 @@ class NotebookProgressCallback(TrainerCallback): ...@@ -294,6 +303,14 @@ class NotebookProgressCallback(TrainerCallback):
else: else:
self.prediction_bar.update(self.prediction_bar.value + 1) self.prediction_bar.update(self.prediction_bar.value + 1)
def on_log(self, args, state, control, logs=None, **kwargs):
# Only for when there is no evaluation
if args.evaluation_strategy == EvaluationStrategy.NO and "loss" in logs:
values = {"Training Loss": logs["loss"]}
# First column is necessarily Step sine we're not in epoch eval strategy
values["Step"] = state.global_step
self.training_tracker.write_line(values)
def on_evaluate(self, args, state, control, metrics=None, **kwargs): def on_evaluate(self, args, state, control, metrics=None, **kwargs):
if self.training_tracker is not None: if self.training_tracker is not None:
values = {"Training Loss": "No log"} values = {"Training Loss": "No log"}
...@@ -319,6 +336,8 @@ class NotebookProgressCallback(TrainerCallback): ...@@ -319,6 +336,8 @@ class NotebookProgressCallback(TrainerCallback):
self.training_tracker.write_line(values) self.training_tracker.write_line(values)
self.training_tracker.remove_child() self.training_tracker.remove_child()
self.prediction_bar = None self.prediction_bar = None
# Evaluation takes a long time so we should force the next update.
self._force_next_update = True
def on_train_end(self, args, state, control, **kwargs): def on_train_end(self, args, state, control, **kwargs):
self.training_tracker.update( self.training_tracker.update(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment