Commit 09aeecd6 authored by Jiayu Ye's avatar Jiayu Ye Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 480378058
parent 7475840b
......@@ -17,7 +17,7 @@
import pprint
import time
from typing import Any, Callable, Iterable, Optional, Union
from typing import Callable, Iterable, Optional, Union
from absl import logging
......@@ -101,8 +101,8 @@ class Controller:
summary_dir: Optional[str] = None,
# Evaluation related
eval_summary_dir: Optional[str] = None,
summary_manager: Optional[Any] = None,
eval_summary_manager: Optional[Any] = None):
summary_manager: Optional[utils.SummaryManagerInterface] = None,
eval_summary_manager: Optional[utils.SummaryManagerInterface] = None):
"""Initializes a `Controller` instance.
Note that if `checkpoint_manager` is provided and there are checkpoints in
......
......@@ -24,6 +24,7 @@ import numpy as np
from orbit import controller
from orbit import runner
from orbit import standard_runner
import orbit.utils
import tensorflow as tf
......@@ -698,12 +699,22 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
self.assertLen(
summaries_with_matching_keyword("eval_loss", self.model_dir), 2)
def test_evaluate_with_nested_summaries(self):
@parameterized.named_parameters(("DefaultSummary", False),
("InjectSummary", True))
def test_evaluate_with_nested_summaries(self, inject_summary_manager):
test_evaluator = TestEvaluatorWithNestedSummary()
if inject_summary_manager:
summary_manager = orbit.utils.SummaryManager(
self.model_dir,
tf.summary.scalar,
global_step=tf.Variable(0, dtype=tf.int64))
else:
summary_manager = None
test_controller = controller.Controller(
evaluator=test_evaluator,
global_step=tf.Variable(0, dtype=tf.int64),
eval_summary_dir=self.model_dir)
eval_summary_dir=self.model_dir,
summary_manager=summary_manager)
test_controller.evaluate(steps=5)
self.assertNotEmpty(
......
......@@ -25,5 +25,6 @@ from orbit.utils.loop_fns import create_tf_while_loop_fn
from orbit.utils.loop_fns import LoopFnWithSummaries
from orbit.utils.summary_manager import SummaryManager
from orbit.utils.summary_manager_interface import SummaryManagerInterface
from orbit.utils.tpu_summaries import OptionalSummariesFunction
......@@ -16,10 +16,12 @@
import os
from orbit.utils.summary_manager_interface import SummaryManagerInterface
import tensorflow as tf
class SummaryManager:
class SummaryManager(SummaryManagerInterface):
"""A utility class for managing summary writing."""
def __init__(self, summary_dir, summary_fn, global_step=None):
......
# Copyright 2022 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides a utility class for managing summary writing."""
import abc
class SummaryManagerInterface(abc.ABC):
"""A utility interface for managing summary writing."""
@abc.abstractmethod
def flush(self):
"""Flushes the the recorded summaries."""
raise NotImplementedError
@abc.abstractmethod
def summary_writer(self, relative_path=""):
"""Returns the underlying summary writer for scoped writers."""
raise NotImplementedError
@abc.abstractmethod
def write_summaries(self, summary_dict):
"""Writes summaries for the given dictionary of values.
The summary_dict can be any nested dict. The SummaryManager should
recursively creates summaries, yielding a hierarchy of summaries which will
then be reflected in the corresponding UIs.
For example, users may evaluate on multiple datasets and return
`summary_dict` as a nested dictionary:
{
"dataset1": {
"loss": loss1,
"accuracy": accuracy1
},
"dataset2": {
"loss": loss2,
"accuracy": accuracy2
},
}
This will create two set of summaries, "dataset1" and "dataset2". Each
summary dict will contain summaries including both "loss" and "accuracy".
Args:
summary_dict: A dictionary of values. If any value in `summary_dict` is
itself a dictionary, then the function will create a new summary_dict
with name given by the corresponding key. This is performed recursively.
Leaf values are then summarized using the parent relative path.
"""
raise NotImplementedError
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment