Commit 02b874a1 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Update orbit.Controller: do not write training summary when input summary_dir is None.

PiperOrigin-RevId: 325115012
parent 9ed4356f
...@@ -167,6 +167,7 @@ def run(flags_obj): ...@@ -167,6 +167,7 @@ def run(flags_obj):
steps_per_loop=steps_per_loop, steps_per_loop=steps_per_loop,
checkpoint_manager=checkpoint_manager, checkpoint_manager=checkpoint_manager,
summary_interval=summary_interval, summary_interval=summary_interval,
summary_dir=flags_obj.model_dir,
eval_summary_dir=os.path.join(flags_obj.model_dir, 'eval')) eval_summary_dir=os.path.join(flags_obj.model_dir, 'eval'))
time_callback.on_train_begin() time_callback.on_train_begin()
......
...@@ -71,9 +71,11 @@ class Controller: ...@@ -71,9 +71,11 @@ class Controller:
`trainer.train` function will always be enabled. If set, the value `trainer.train` function will always be enabled. If set, the value
should be divisible by steps_per_loop. should be divisible by steps_per_loop.
summary_dir: The directory to restore and write checkpoints and summaries. summary_dir: The directory to restore and write checkpoints and summaries.
If None, it will be set to `checkpoint_manager.directory`. For example, you can set it to `checkpoint_manager.directory`.
If None, it will not write training summaries.
eval_summary_dir: The directory to write eval summaries. If None, it will eval_summary_dir: The directory to write eval summaries. If None, it will
be set to `summary_dir`. be set to `summary_dir`. If both `summary_dir` and `eval_summary_dir`
are None, it will not write evaluation summaries.
Raises: Raises:
ValueError: If both `trainer` and `evaluator` are None. ValueError: If both `trainer` and `evaluator` are None.
...@@ -108,9 +110,6 @@ class Controller: ...@@ -108,9 +110,6 @@ class Controller:
self.global_step = global_step self.global_step = global_step
self.checkpoint_manager = checkpoint_manager self.checkpoint_manager = checkpoint_manager
if summary_dir is None and checkpoint_manager:
summary_dir = checkpoint_manager.directory
if self.trainer is not None: if self.trainer is not None:
self.step_timer = None self.step_timer = None
self.steps_per_loop = steps_per_loop self.steps_per_loop = steps_per_loop
...@@ -118,7 +117,6 @@ class Controller: ...@@ -118,7 +117,6 @@ class Controller:
self.summary_manager = utils.SummaryManager( self.summary_manager = utils.SummaryManager(
summary_dir, tf.summary.scalar, global_step=self.global_step) summary_dir, tf.summary.scalar, global_step=self.global_step)
eval_summary_writer = None
if self.evaluator is not None: if self.evaluator is not None:
eval_summary_dir = eval_summary_dir or summary_dir eval_summary_dir = eval_summary_dir or summary_dir
if eval_summary_dir == summary_dir and self.trainer is not None: if eval_summary_dir == summary_dir and self.trainer is not None:
......
...@@ -294,6 +294,56 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -294,6 +294,56 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
train_steps=10, eval_steps=2, eval_interval=6) train_steps=10, eval_steps=2, eval_interval=6)
self.assertEqual(test_runner.global_step, 10) self.assertEqual(test_runner.global_step, 10)
def test_has_checkpoint_no_summaries(self):
    """Training with a checkpoint manager but no summary dirs writes no summaries."""
    runner = TestRunner()
    # Checkpointing is configured, but neither `summary_dir` nor
    # `eval_summary_dir` is passed to the Controller.
    ckpt = tf.train.Checkpoint(model=runner.model)
    manager = tf.train.CheckpointManager(
        ckpt,
        self.model_dir,
        max_to_keep=None,
        step_counter=runner.global_step)
    ctrl = controller.Controller(
        trainer=runner,
        evaluator=runner,
        global_step=runner.global_step,
        checkpoint_manager=manager,
        steps_per_loop=2)
    ctrl.train_and_evaluate(train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(runner.global_step, 10)
    # No event files should have been written to the checkpoint directory.
    self.assertEmpty(
        tf.io.gfile.glob(os.path.join(manager.directory, "events.*")))
def test_has_checkpoint_eval_summary_only(self):
    """Only eval summaries are written when just `eval_summary_dir` is set."""
    test_runner = TestRunner()
    # Has a checkpoint and an eval summary directory, but no training
    # summary directory (`summary_dir` is left as None), so the Controller
    # should skip training summaries entirely.
    checkpoint = tf.train.Checkpoint(model=test_runner.model)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
        steps_per_loop=2)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)
    # Training summaries are not saved.
    self.assertEmpty(tf.io.gfile.glob(
        os.path.join(checkpoint_manager.directory, "events.*")))
    # Evaluation summaries are saved.
    self.assertNotEmpty(tf.io.gfile.glob(
        os.path.join(self.model_dir, "summaries/eval/events.*")))
@parameterized.named_parameters(("return_numpy", True), @parameterized.named_parameters(("return_numpy", True),
("return_tensor", False)) ("return_tensor", False))
def test_train_and_evaluate(self, return_numpy): def test_train_and_evaluate(self, return_numpy):
...@@ -612,7 +662,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -612,7 +662,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
evaluator=test_runner, evaluator=test_runner,
global_step=test_runner.global_step, global_step=test_runner.global_step,
steps_per_loop=10, steps_per_loop=10,
checkpoint_manager=checkpoint_manager) checkpoint_manager=checkpoint_manager,
summary_dir=self.model_dir)
test_controller.train_and_evaluate( test_controller.train_and_evaluate(
train_steps=10, eval_steps=2, eval_interval=5) train_steps=10, eval_steps=2, eval_interval=5)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment