Commit 09aeecd6 authored by Jiayu Ye's avatar Jiayu Ye Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 480378058
parent 7475840b
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import pprint import pprint
import time import time
from typing import Any, Callable, Iterable, Optional, Union from typing import Callable, Iterable, Optional, Union
from absl import logging from absl import logging
...@@ -101,8 +101,8 @@ class Controller: ...@@ -101,8 +101,8 @@ class Controller:
summary_dir: Optional[str] = None, summary_dir: Optional[str] = None,
# Evaluation related # Evaluation related
eval_summary_dir: Optional[str] = None, eval_summary_dir: Optional[str] = None,
summary_manager: Optional[Any] = None, summary_manager: Optional[utils.SummaryManagerInterface] = None,
eval_summary_manager: Optional[Any] = None): eval_summary_manager: Optional[utils.SummaryManagerInterface] = None):
"""Initializes a `Controller` instance. """Initializes a `Controller` instance.
Note that if `checkpoint_manager` is provided and there are checkpoints in Note that if `checkpoint_manager` is provided and there are checkpoints in
......
...@@ -24,6 +24,7 @@ import numpy as np ...@@ -24,6 +24,7 @@ import numpy as np
from orbit import controller from orbit import controller
from orbit import runner from orbit import runner
from orbit import standard_runner from orbit import standard_runner
import orbit.utils
import tensorflow as tf import tensorflow as tf
...@@ -698,12 +699,22 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -698,12 +699,22 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
self.assertLen( self.assertLen(
summaries_with_matching_keyword("eval_loss", self.model_dir), 2) summaries_with_matching_keyword("eval_loss", self.model_dir), 2)
def test_evaluate_with_nested_summaries(self): @parameterized.named_parameters(("DefaultSummary", False),
("InjectSummary", True))
def test_evaluate_with_nested_summaries(self, inject_summary_manager):
test_evaluator = TestEvaluatorWithNestedSummary() test_evaluator = TestEvaluatorWithNestedSummary()
if inject_summary_manager:
summary_manager = orbit.utils.SummaryManager(
self.model_dir,
tf.summary.scalar,
global_step=tf.Variable(0, dtype=tf.int64))
else:
summary_manager = None
test_controller = controller.Controller( test_controller = controller.Controller(
evaluator=test_evaluator, evaluator=test_evaluator,
global_step=tf.Variable(0, dtype=tf.int64), global_step=tf.Variable(0, dtype=tf.int64),
eval_summary_dir=self.model_dir) eval_summary_dir=self.model_dir,
summary_manager=summary_manager)
test_controller.evaluate(steps=5) test_controller.evaluate(steps=5)
self.assertNotEmpty( self.assertNotEmpty(
......
...@@ -25,5 +25,6 @@ from orbit.utils.loop_fns import create_tf_while_loop_fn ...@@ -25,5 +25,6 @@ from orbit.utils.loop_fns import create_tf_while_loop_fn
from orbit.utils.loop_fns import LoopFnWithSummaries from orbit.utils.loop_fns import LoopFnWithSummaries
from orbit.utils.summary_manager import SummaryManager from orbit.utils.summary_manager import SummaryManager
from orbit.utils.summary_manager_interface import SummaryManagerInterface
from orbit.utils.tpu_summaries import OptionalSummariesFunction from orbit.utils.tpu_summaries import OptionalSummariesFunction
...@@ -16,10 +16,12 @@ ...@@ -16,10 +16,12 @@
import os import os
from orbit.utils.summary_manager_interface import SummaryManagerInterface
import tensorflow as tf import tensorflow as tf
class SummaryManager: class SummaryManager(SummaryManagerInterface):
"""A utility class for managing summary writing.""" """A utility class for managing summary writing."""
def __init__(self, summary_dir, summary_fn, global_step=None): def __init__(self, summary_dir, summary_fn, global_step=None):
......
# Copyright 2022 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides a utility class for managing summary writing."""
import abc
class SummaryManagerInterface(abc.ABC):
"""A utility interface for managing summary writing."""
@abc.abstractmethod
def flush(self):
"""Flushes the the recorded summaries."""
raise NotImplementedError
@abc.abstractmethod
def summary_writer(self, relative_path=""):
"""Returns the underlying summary writer for scoped writers."""
raise NotImplementedError
@abc.abstractmethod
def write_summaries(self, summary_dict):
"""Writes summaries for the given dictionary of values.
The summary_dict can be any nested dict. The SummaryManager should
recursively creates summaries, yielding a hierarchy of summaries which will
then be reflected in the corresponding UIs.
For example, users may evaluate on multiple datasets and return
`summary_dict` as a nested dictionary:
{
"dataset1": {
"loss": loss1,
"accuracy": accuracy1
},
"dataset2": {
"loss": loss2,
"accuracy": accuracy2
},
}
This will create two set of summaries, "dataset1" and "dataset2". Each
summary dict will contain summaries including both "loss" and "accuracy".
Args:
summary_dict: A dictionary of values. If any value in `summary_dict` is
itself a dictionary, then the function will create a new summary_dict
with name given by the corresponding key. This is performed recursively.
Leaf values are then summarized using the parent relative path.
"""
raise NotImplementedError
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment