Commit 42f8e96e authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Change self.global_step.numpy() != expected_step as logging.warning.

PiperOrigin-RevId: 342276315
parent db4acd91
......@@ -189,7 +189,6 @@ class Controller:
tf.summary.experimental.set_step(self.global_step)
# Restores the model if needed.
# TODO(momernick): We probably only want to do this on certain occasions?
if self.checkpoint_manager is not None:
restored_path = self.restore_checkpoint()
if restored_path:
......@@ -417,12 +416,14 @@ class Controller:
# Verify that global_step was updated properly, then update current_step.
expected_step = current_step + num_steps
if self.global_step.numpy() != expected_step:
raise RuntimeError(
message = (
f"`trainer.train({num_steps})` did not update `global_step` by "
f"{num_steps}. Old value was {current_step}, expected updated value "
f"to be {expected_step}, but it was {self.global_step.numpy()}.")
current_step = expected_step
logging.warning(message)
return
current_step = expected_step
steps_per_second = self.step_timer.steps_per_second()
_log(f"train | step: {current_step: 6d} | "
f"steps/sec: {steps_per_second: 6.1f} | "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment