Commit e54fcee2 authored by Ruoxin Sang, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 323098007
parent a78b05b9
......@@ -30,14 +30,6 @@ def _log_info(message: Text):
print(message)
def _validate_interval(interval: Optional[int], steps_per_loop: Optional[int],
interval_name: str):
if interval and steps_per_loop and (interval % steps_per_loop != 0):
raise ValueError("The {} interval ({}) must be a multiple "
"of the steps_per_loop ({})".format(
interval_name, interval, steps_per_loop))
class Controller:
"""Class that facilitates training and evaluation of models."""
......@@ -103,8 +95,10 @@ class Controller:
if summary_interval is not None:
if summary_interval <= 0:
raise ValueError("`summary_interval` should be larger than 0")
_validate_interval(
summary_interval, steps_per_loop, interval_name="summary")
if summary_interval % steps_per_loop != 0:
raise ValueError("The summary interval ({}) must be a multiple "
"of the steps_per_loop ({})".format(
summary_interval, steps_per_loop))
self.trainer = trainer
self.evaluator = evaluator
......@@ -142,9 +136,6 @@ class Controller:
# TODO(momernick): We probably only want to do this on certain occasions?
if self.checkpoint_manager is not None:
checkpoint_interval = self.checkpoint_manager.checkpoint_interval
_validate_interval(
checkpoint_interval, steps_per_loop, interval_name="checkpoint")
model_restored = self.restore_checkpoint()
if not model_restored and (checkpoint_interval and
self.trainer is not None):
......@@ -271,15 +262,15 @@ class Controller:
train_steps: The global step count to train up to.
eval_steps: The number of steps to run during an evaluation. If None,
this method will evaluate over the entire evaluation dataset.
eval_interval: The number of training steps to run between evalutions.
Must be a multiple of the controller's `steps_per_loop` init arg. If
None, evaluation will only be performed after training is complete.
eval_interval: The number of training steps to run between evaluations.
If set, training will always stop every `eval_interval` steps, even if
this results in a shorter inner loop than specified by `steps_per_loop`
setting. If None, evaluation will only be performed after training is
complete.
Raises:
ValueError: If eval_interval is not a multiple of self.steps_per_loop.
"""
_validate_interval(eval_interval, self.steps_per_loop, interval_name="eval")
current_step = self.global_step.numpy() # This is an expensive access.
eval_interval = eval_interval or (train_steps - current_step)
while current_step < train_steps:
......
......@@ -33,19 +33,15 @@ def create_model():
def summaries_with_matching_keyword(keyword, summary_dir):
"""Yields summary protos matching given keyword from event file."""
"""Returns summary protos matching given keyword from event file."""
matches = []
event_paths = tf.io.gfile.glob(os.path.join(summary_dir, "events*"))
for event in tf.compat.v1.train.summary_iterator(event_paths[-1]):
if event.summary is not None:
for value in event.summary.value:
if keyword in value.tag:
logging.info(event)
yield event.summary
def check_eventfile_for_keyword(keyword, summary_dir):
"""Checks event files for the keyword."""
return any(summaries_with_matching_keyword(keyword, summary_dir))
matches.append(event.summary)
return matches
def dataset_fn(ctx):
......@@ -219,13 +215,13 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
# Loss and accuracy values should be written into summaries.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertTrue(
check_eventfile_for_keyword(
self.assertNotEmpty(
summaries_with_matching_keyword(
"loss", os.path.join(self.model_dir, "summaries/train")))
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertTrue(
check_eventfile_for_keyword(
self.assertNotEmpty(
summaries_with_matching_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
# No checkpoint, so global step starts from 0.
test_runner.global_step.assign(0)
......@@ -275,13 +271,13 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
# Loss and accuracy values should be written into summaries.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertTrue(
check_eventfile_for_keyword(
self.assertNotEmpty(
summaries_with_matching_keyword(
"loss", os.path.join(self.model_dir, "summaries/train")))
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertTrue(
check_eventfile_for_keyword(
self.assertNotEmpty(
summaries_with_matching_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
def test_train_only(self):
......@@ -311,8 +307,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
# Only train summaries are written.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertTrue(
check_eventfile_for_keyword(
self.assertNotEmpty(
summaries_with_matching_keyword(
"loss", os.path.join(self.model_dir, "summaries/train")))
self.assertFalse(
tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))
......@@ -340,8 +336,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/train")))
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertTrue(
check_eventfile_for_keyword(
self.assertNotEmpty(
summaries_with_matching_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
# Tests continuous eval with timeout and timeout_fn.
......@@ -423,8 +419,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
# Only train summaries are written.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
self.assertTrue(
check_eventfile_for_keyword(
self.assertNotEmpty(
summaries_with_matching_keyword(
"loss", os.path.join(self.model_dir, "summaries/train")))
self.assertFalse(
tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))
......@@ -453,12 +449,12 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
# Loss and accuracy values should be written into summaries.
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries")))
self.assertTrue(
check_eventfile_for_keyword("loss",
os.path.join(self.model_dir, "summaries")))
self.assertTrue(
check_eventfile_for_keyword("eval_loss",
os.path.join(self.model_dir, "summaries")))
self.assertNotEmpty(
summaries_with_matching_keyword(
"loss", os.path.join(self.model_dir, "summaries")))
self.assertNotEmpty(
summaries_with_matching_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries")))
def test_early_stop_on_eval_loss(self):
test_runner = TestRunner()
......@@ -518,8 +514,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
# Only eval summaries are written
self.assertNotEmpty(
tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
self.assertTrue(
check_eventfile_for_keyword(
self.assertNotEmpty(
summaries_with_matching_keyword(
"eval_loss", os.path.join(self.model_dir, "summaries/eval")))
def test_train_and_evaluate_reset_datasets(self):
......@@ -546,5 +542,33 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=6)
def test_eval_and_checkpoint_interval(self):
  """Runs with eval/checkpoint intervals of 5 and a `steps_per_loop` of 10.

  Both intervals are smaller than `steps_per_loop`; the test asserts on the
  number of checkpoint data files and of `eval_loss` summaries produced.
  """
  runner = TestRunner()
  manager = tf.train.CheckpointManager(
      tf.train.Checkpoint(model=runner.model, optimizer=runner.optimizer),
      self.model_dir,
      max_to_keep=None,
      step_counter=runner.global_step,
      checkpoint_interval=5)
  ctrl = controller.Controller(
      trainer=runner,
      evaluator=runner,
      global_step=runner.global_step,
      steps_per_loop=10,
      checkpoint_manager=manager)
  ctrl.train_and_evaluate(train_steps=10, eval_steps=2, eval_interval=5)

  # Checkpoints are expected at steps 0, 5 and 10, i.e. three in total.
  data_file_pattern = os.path.join(self.model_dir, "ckpt-*.data*")
  self.assertLen(tf.io.gfile.glob(data_file_pattern), 3)

  # Evaluation is expected at steps 5 and 10, i.e. twice.
  eval_summaries = summaries_with_matching_keyword("eval_loss", self.model_dir)
  self.assertLen(eval_summaries, 2)
if __name__ == "__main__":
tf.test.main()
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment