Commit f498fa10 authored by Ruoxin Sang's avatar Ruoxin Sang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 323098007
parent f28b2a15
...@@ -30,14 +30,6 @@ def _log_info(message: Text): ...@@ -30,14 +30,6 @@ def _log_info(message: Text):
print(message) print(message)
def _validate_interval(interval: Optional[int], steps_per_loop: Optional[int],
interval_name: str):
if interval and steps_per_loop and (interval % steps_per_loop != 0):
raise ValueError("The {} interval ({}) must be a multiple "
"of the steps_per_loop ({})".format(
interval_name, interval, steps_per_loop))
class Controller: class Controller:
"""Class that facilitates training and evaluation of models.""" """Class that facilitates training and evaluation of models."""
...@@ -103,8 +95,10 @@ class Controller: ...@@ -103,8 +95,10 @@ class Controller:
if summary_interval is not None: if summary_interval is not None:
if summary_interval <= 0: if summary_interval <= 0:
raise ValueError("`summary_interval` should be larger than 0") raise ValueError("`summary_interval` should be larger than 0")
_validate_interval( if summary_interval % steps_per_loop != 0:
summary_interval, steps_per_loop, interval_name="summary") raise ValueError("The summary interval ({}) must be a multiple "
"of the steps_per_loop ({})".format(
summary_interval, steps_per_loop))
self.trainer = trainer self.trainer = trainer
self.evaluator = evaluator self.evaluator = evaluator
...@@ -142,9 +136,6 @@ class Controller: ...@@ -142,9 +136,6 @@ class Controller:
# TODO(momernick): We probably only want to do this on certain occasions? # TODO(momernick): We probably only want to do this on certain occasions?
if self.checkpoint_manager is not None: if self.checkpoint_manager is not None:
checkpoint_interval = self.checkpoint_manager.checkpoint_interval checkpoint_interval = self.checkpoint_manager.checkpoint_interval
_validate_interval(
checkpoint_interval, steps_per_loop, interval_name="checkpoint")
model_restored = self.restore_checkpoint() model_restored = self.restore_checkpoint()
if not model_restored and (checkpoint_interval and if not model_restored and (checkpoint_interval and
self.trainer is not None): self.trainer is not None):
...@@ -271,15 +262,15 @@ class Controller: ...@@ -271,15 +262,15 @@ class Controller:
train_steps: The global step count to train up to. train_steps: The global step count to train up to.
eval_steps: The number of steps to run during an evaluation. If None, eval_steps: The number of steps to run during an evaluation. If None,
this method will evaluate over the entire evaluation dataset. this method will evaluate over the entire evaluation dataset.
eval_interval: The number of training steps to run between evalutions. eval_interval: The number of training steps to run between evaluations.
Must be a multiple of the controller's `steps_per_loop` init arg. If If set, training will always stop every `eval_interval` steps, even if
None, evaluation will only be performed after training is complete. this results in a shorter inner loop than specified by `steps_per_loop`
setting. If None, evaluation will only be performed after training is
complete.
Raises: Raises:
ValueError: If eval_interval is not a multiple of self.steps_per_loop. ValueError: If eval_interval is not a multiple of self.steps_per_loop.
""" """
_validate_interval(eval_interval, self.steps_per_loop, interval_name="eval")
current_step = self.global_step.numpy() # This is an expensive access. current_step = self.global_step.numpy() # This is an expensive access.
eval_interval = eval_interval or (train_steps - current_step) eval_interval = eval_interval or (train_steps - current_step)
while current_step < train_steps: while current_step < train_steps:
......
...@@ -33,19 +33,15 @@ def create_model(): ...@@ -33,19 +33,15 @@ def create_model():
def summaries_with_matching_keyword(keyword, summary_dir): def summaries_with_matching_keyword(keyword, summary_dir):
"""Yields summary protos matching given keyword from event file.""" """Returns summary protos matching given keyword from event file."""
matches = []
event_paths = tf.io.gfile.glob(os.path.join(summary_dir, "events*")) event_paths = tf.io.gfile.glob(os.path.join(summary_dir, "events*"))
for event in tf.compat.v1.train.summary_iterator(event_paths[-1]): for event in tf.compat.v1.train.summary_iterator(event_paths[-1]):
if event.summary is not None: if event.summary is not None:
for value in event.summary.value: for value in event.summary.value:
if keyword in value.tag: if keyword in value.tag:
logging.info(event) matches.append(event.summary)
yield event.summary return matches
def check_eventfile_for_keyword(keyword, summary_dir):
  """Returns True if any event summary in `summary_dir` matches `keyword`."""
  matching = summaries_with_matching_keyword(keyword, summary_dir)
  return any(matching)
def dataset_fn(ctx): def dataset_fn(ctx):
...@@ -219,13 +215,13 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -219,13 +215,13 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
# Loss and accuracy values should be written into summaries. # Loss and accuracy values should be written into summaries.
self.assertNotEmpty( self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train"))) tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertTrue( self.assertNotEmpty(
check_eventfile_for_keyword( summaries_with_matching_keyword(
"loss", os.path.join(self.model_dir, "summaries/train"))) "loss", os.path.join(self.model_dir, "summaries/train")))
self.assertNotEmpty( self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval"))) tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertTrue( self.assertNotEmpty(
check_eventfile_for_keyword( summaries_with_matching_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval"))) "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
# No checkpoint, so global step starts from 0. # No checkpoint, so global step starts from 0.
test_runner.global_step.assign(0) test_runner.global_step.assign(0)
...@@ -275,13 +271,13 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -275,13 +271,13 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
# Loss and accuracy values should be written into summaries. # Loss and accuracy values should be written into summaries.
self.assertNotEmpty( self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train"))) tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertTrue( self.assertNotEmpty(
check_eventfile_for_keyword( summaries_with_matching_keyword(
"loss", os.path.join(self.model_dir, "summaries/train"))) "loss", os.path.join(self.model_dir, "summaries/train")))
self.assertNotEmpty( self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval"))) tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertTrue( self.assertNotEmpty(
check_eventfile_for_keyword( summaries_with_matching_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval"))) "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
def test_train_only(self): def test_train_only(self):
...@@ -311,8 +307,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -311,8 +307,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
# Only train summaries are written. # Only train summaries are written.
self.assertNotEmpty( self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train"))) tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertTrue( self.assertNotEmpty(
check_eventfile_for_keyword( summaries_with_matching_keyword(
"loss", os.path.join(self.model_dir, "summaries/train"))) "loss", os.path.join(self.model_dir, "summaries/train")))
self.assertFalse( self.assertFalse(
tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval"))) tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))
...@@ -340,8 +336,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -340,8 +336,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/train"))) tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/train")))
self.assertNotEmpty( self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval"))) tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertTrue( self.assertNotEmpty(
check_eventfile_for_keyword( summaries_with_matching_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval"))) "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
# Tests continuous eval with timeout and timeout_fn. # Tests continuous eval with timeout and timeout_fn.
...@@ -423,8 +419,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -423,8 +419,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
# Only train summaries are written. # Only train summaries are written.
self.assertNotEmpty( self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train"))) tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertTrue( self.assertNotEmpty(
check_eventfile_for_keyword( summaries_with_matching_keyword(
"loss", os.path.join(self.model_dir, "summaries/train"))) "loss", os.path.join(self.model_dir, "summaries/train")))
self.assertFalse( self.assertFalse(
tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval"))) tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))
...@@ -453,12 +449,12 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -453,12 +449,12 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
# Loss and accuracy values should be written into summaries. # Loss and accuracy values should be written into summaries.
self.assertNotEmpty( self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries"))) tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries")))
self.assertTrue( self.assertNotEmpty(
check_eventfile_for_keyword("loss", summaries_with_matching_keyword(
os.path.join(self.model_dir, "summaries"))) "loss", os.path.join(self.model_dir, "summaries")))
self.assertTrue( self.assertNotEmpty(
check_eventfile_for_keyword("eval_loss", summaries_with_matching_keyword(
os.path.join(self.model_dir, "summaries"))) "eval_loss", os.path.join(self.model_dir, "summaries")))
def test_early_stop_on_eval_loss(self): def test_early_stop_on_eval_loss(self):
test_runner = TestRunner() test_runner = TestRunner()
...@@ -518,8 +514,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -518,8 +514,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
# Only eval summaries are written # Only eval summaries are written
self.assertNotEmpty( self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval"))) tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertTrue( self.assertNotEmpty(
check_eventfile_for_keyword( summaries_with_matching_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval"))) "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
def test_train_and_evaluate_reset_datasets(self): def test_train_and_evaluate_reset_datasets(self):
...@@ -546,5 +542,33 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -546,5 +542,33 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
test_controller.train_and_evaluate( test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6) train_steps=10, eval_steps=2, eval_interval=6)
def test_eval_and_checkpoint_interval(self):
  """Verifies eval/checkpoint intervals smaller than steps_per_loop fire."""
  runner = TestRunner()
  ckpt = tf.train.Checkpoint(
      model=runner.model, optimizer=runner.optimizer)
  manager = tf.train.CheckpointManager(
      ckpt,
      self.model_dir,
      max_to_keep=None,
      step_counter=runner.global_step,
      checkpoint_interval=5)
  test_controller = controller.Controller(
      trainer=runner,
      evaluator=runner,
      global_step=runner.global_step,
      steps_per_loop=10,
      checkpoint_manager=manager)
  test_controller.train_and_evaluate(
      train_steps=10, eval_steps=2, eval_interval=5)

  # Expect 3 checkpoints to be saved at step: 0, 5, 10.
  saved = tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt-*.data*"))
  self.assertLen(saved, 3)
  # Expect evaluation is performed 2 times at step: 5, 10.
  eval_summaries = summaries_with_matching_keyword("eval_loss", self.model_dir)
  self.assertLen(eval_summaries, 2)
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment