summary_manager.py 4.18 KB
Newer Older
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
1
# Copyright 2022 The Orbit Authors. All Rights Reserved.
Dan Holtmann-Rice's avatar
Dan Holtmann-Rice committed
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Hongkun Yu's avatar
Hongkun Yu committed
14

Dan Holtmann-Rice's avatar
Dan Holtmann-Rice committed
15
16
17
18
"""Provides a utility class for managing summary writing."""

import os

Jiayu Ye's avatar
Jiayu Ye committed
19
20
from orbit.utils.summary_manager_interface import SummaryManagerInterface

Dan Holtmann-Rice's avatar
Dan Holtmann-Rice committed
21
22
23
import tensorflow as tf


Jiayu Ye's avatar
Jiayu Ye committed
24
class SummaryManager(SummaryManagerInterface):
  """Manages `tf.summary` writers rooted at a single summary directory.

  Writers are created lazily, one per relative subdirectory, and cached for
  the lifetime of the manager. When no summary directory is configured, the
  manager is disabled: writing and flushing do nothing, and `summary_writer`
  hands out no-op writers.
  """

  def __init__(self, summary_dir, summary_fn, global_step=None):
    """Initializes the `SummaryManager` instance.

    Args:
      summary_dir: The root directory under which summaries are written. If
        `None`, the manager is disabled and every summary writing operation
        it provides is a no-op.
      summary_fn: A callable taking `name`, `value`, and `step` parameters,
        which calls `tf.summary` functions to write the summary.
      global_step: A `tf.Variable` holding the global step value. If `None`,
        the value of `tf.summary.experimental.get_step()` at construction
        time is used instead.
    """
    self._summary_dir = summary_dir
    self._summary_fn = summary_fn
    self._enabled = summary_dir is not None
    # Cache of summary writers, keyed by path relative to `summary_dir`.
    self._summary_writers = {}
    # The fallback step is only looked up when the caller supplied none.
    self._global_step = (
        tf.summary.experimental.get_step()
        if global_step is None else global_step)

  def summary_writer(self, relative_path=""):
    """Returns the summary writer for the given subdirectory.

    The writer is created on first request and reused on later calls with
    the same `relative_path`.

    Args:
      relative_path: The subdirectory to write summaries into, relative to
        the summary directory. The default (empty string) corresponds to the
        summary root directory itself.

    Returns:
      The cached writer for `relative_path`: a file writer when the manager
      is enabled, otherwise a no-op writer.
    """
    writers = self._summary_writers
    if relative_path not in writers:
      if self._enabled:
        target_dir = os.path.join(self._summary_dir, relative_path)
        writer = tf.summary.create_file_writer(target_dir)
      else:
        writer = tf.summary.create_noop_writer()
      writers[relative_path] = writer
    return writers[relative_path]

  def flush(self):
    """Flushes every summary writer created so far (no-op if disabled)."""
    if not self._enabled:
      return
    tf.nest.map_structure(tf.summary.flush, self._summary_writers)

  def write_summaries(self, summary_dict):
    """Writes summaries for a (possibly nested) dictionary of values.

    Nested dictionaries are mapped onto nested subdirectories of the summary
    root, producing a directory hierarchy that the TensorBoard UI shows as
    differently colored curves.

    For instance, an evaluation over several datasets may return

        {
            "dataset1": {
                "loss": loss1,
                "accuracy": accuracy1
            },
            "dataset2": {
                "loss": loss2,
                "accuracy": accuracy2
            },
        }

    which yields two subdirectories, "dataset1" and "dataset2", under the
    summary root directory, each holding event files with both the "loss"
    and the "accuracy" summaries.

    Does nothing when the manager is disabled.

    Args:
      summary_dict: A dictionary of values. Any value that is itself a
        dictionary gets a subdirectory named after its key, recursively.
        Leaf values are written with the summary writer belonging to the
        relative path of their parent dictionary.
    """
    if self._enabled:
      self._write_summaries(summary_dict)

  def _write_summaries(self, summary_dict, relative_path=""):
    """Recursively writes `summary_dict` beneath `relative_path`.

    Args:
      summary_dict: A dictionary whose `dict` values are descended into and
        whose other values are passed to the summary function.
      relative_path: The path, relative to the summary directory, that the
        leaf values directly inside `summary_dict` are written to.
    """
    for key, entry in summary_dict.items():
      if not isinstance(entry, dict):
        # Leaf value: write it using the writer for the current directory.
        with self.summary_writer(relative_path).as_default():
          self._summary_fn(key, entry, step=self._global_step)
        continue
      # Nested dictionary: its key names the subdirectory to descend into.
      subdir = os.path.join(relative_path, key)
      self._write_summaries(entry, relative_path=subdir)