Commit 7572c1f4 authored by Jiayu Ye's avatar Jiayu Ye Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 480164858
parent bae94e6d
......@@ -15,6 +15,7 @@
"""Defines exported symbols for the `orbit` package."""
from orbit import actions
# Internal import orbit.
from orbit import utils
from orbit.controller import Action
......
......@@ -17,7 +17,7 @@
import pprint
import time
from typing import Callable, Iterable, Optional, Union
from typing import Any, Callable, Iterable, Optional, Union
from absl import logging
......@@ -101,7 +101,8 @@ class Controller:
summary_dir: Optional[str] = None,
# Evaluation related
eval_summary_dir: Optional[str] = None,
):
summary_manager: Optional[Any] = None,
eval_summary_manager: Optional[Any] = None):
"""Initializes a `Controller` instance.
Note that if `checkpoint_manager` is provided and there are checkpoints in
......@@ -152,6 +153,14 @@ class Controller:
eval_summary_dir: The directory to write eval summaries to. If `None`, it
will be set to `summary_dir`. If both `summary_dir` and
`eval_summary_dir` are `None`, no eval summaries will be written.
summary_manager: Instance of the summary manager. If set, the
`summary_dir` will be ignored. Otherwise the summary manager will be
created internally for TensorBoard summaries by default from the
`summary_dir`.
eval_summary_manager: Instance of the eval summary manager. If set, the
`eval_summary_dir` will be ignored. Otherwise the eval summary manager
will be created internally for TensorBoard summaries by default from the
`eval_summary_dir`.
Raises:
ValueError: If both `trainer` and `evaluator` are `None`.
......@@ -199,8 +208,11 @@ class Controller:
if self.trainer is not None:
self.step_timer = None
self.summary_interval = summary_interval
self.summary_manager = utils.SummaryManager(
summary_dir, tf.summary.scalar, global_step=self.global_step)
if summary_manager:
self.summary_manager = summary_manager
else:
self.summary_manager = utils.SummaryManager(
summary_dir, tf.summary.scalar, global_step=self.global_step)
self._steps_per_loop = steps_per_loop
if self.evaluator is not None:
......@@ -210,8 +222,11 @@ class Controller:
# are the same.
self.eval_summary_manager = self.summary_manager
else:
self.eval_summary_manager = utils.SummaryManager(
eval_summary_dir, tf.summary.scalar, global_step=self.global_step)
if eval_summary_manager:
self.eval_summary_manager = eval_summary_manager
else:
self.eval_summary_manager = utils.SummaryManager(
eval_summary_dir, tf.summary.scalar, global_step=self.global_step)
tf.summary.experimental.set_step(self.global_step)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment