# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common modules for callbacks."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import os
from absl import logging

import tensorflow as tf
from typing import Any, List, MutableMapping, Optional

from official.utils.misc import keras_utils


def get_callbacks(model_checkpoint: bool = True,
                  include_tensorboard: bool = True,
                  time_history: bool = True,
                  track_lr: bool = True,
                  write_model_weights: bool = True,
                  initial_step: int = 0,
                  batch_size: int = 0,
                  log_steps: int = 0,
                  model_dir: Optional[str] = None
                 ) -> List[tf.keras.callbacks.Callback]:
  """Get all callbacks."""
  model_dir = model_dir or ''
  callbacks = []
  if model_checkpoint:
    ckpt_full_path = os.path.join(model_dir, 'model.ckpt-{epoch:04d}')
    callbacks.append(
        tf.keras.callbacks.ModelCheckpoint(
            ckpt_full_path, save_weights_only=True, verbose=1))
  if include_tensorboard:
    callbacks.append(
        CustomTensorBoard(
            log_dir=model_dir,
            track_lr=track_lr,
            initial_step=initial_step,
            write_images=write_model_weights))
  if time_history:
    callbacks.append(
        keras_utils.TimeHistory(
            batch_size,
            log_steps,
            logdir=model_dir if include_tensorboard else None))
  return callbacks

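# Example usage (an illustrative sketch, not part of this module; `model` and
# `train_ds` stand in for a compiled Keras model and a dataset, and the
# hyperparameter values below are placeholders):
#
#   callbacks = get_callbacks(
#       batch_size=128,
#       log_steps=100,
#       model_dir='/tmp/model_dir')
#   model.fit(train_ds, epochs=10, callbacks=callbacks)
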

def get_scalar_from_tensor(t: tf.Tensor) -> Any:
  """Utility function to convert a Tensor to a scalar."""
  t = tf.keras.backend.get_value(t)
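  # `get_value` may hand back a zero-argument callable rather than a concrete
  # scalar; invoke it to resolve the value.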
  if callable(t):
    return t()
  else:
    return t


class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
  """A customized TensorBoard callback that tracks additional datapoints.

  Metrics tracked:
  - Global learning rate

  Attributes:
    log_dir: the path of the directory where to save the log files to be parsed
      by TensorBoard.
    track_lr: `bool`, whether or not to track the global learning rate.
    initial_step: the initial step, used for preemption recovery.
    **kwargs: Additional arguments for backwards compatibility. Possible key is
      `period`.
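
  Example (an illustrative sketch; `model` and `dataset` are assumed to be
  defined and compiled elsewhere):

    tensorboard = CustomTensorBoard(
        log_dir='/tmp/logs', track_lr=True, initial_step=0)
    model.fit(dataset, epochs=10, callbacks=[tensorboard])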
  """

  # TODO(b/146499062): track params, flops, log lr, l2 loss,
  # classification loss

  def __init__(self,
               log_dir: str,
               track_lr: bool = False,
               initial_step: int = 0,
               **kwargs):
    super(CustomTensorBoard, self).__init__(log_dir=log_dir, **kwargs)
    self.step = initial_step
    self._track_lr = track_lr

  def on_batch_begin(self,
                     batch: int,
                     logs: Optional[MutableMapping[str, Any]] = None) -> None:
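    # Advance a global step counter, seeded with `initial_step` so the count
    # remains correct after preemption recovery.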
    self.step += 1
    if logs is None:
      logs = {}
    logs.update(self._calculate_metrics())
    super(CustomTensorBoard, self).on_batch_begin(batch, logs)

  def on_epoch_begin(self,
                     epoch: int,
                     logs: Optional[MutableMapping[str, Any]] = None) -> None:
    if logs is None:
      logs = {}
    metrics = self._calculate_metrics()
    logs.update(metrics)
    for k, v in metrics.items():
      logging.info('Current %s: %f', k, v)
    super(CustomTensorBoard, self).on_epoch_begin(epoch, logs)

  def on_epoch_end(self,
                   epoch: int,
                   logs: Optional[MutableMapping[str, Any]] = None) -> None:
    if logs is None:
      logs = {}
    metrics = self._calculate_metrics()
    logs.update(metrics)
    super(CustomTensorBoard, self).on_epoch_end(epoch, logs)

  def _calculate_metrics(self) -> MutableMapping[str, Any]:
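    """Computes the additional metrics to merge into the training logs."""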
    logs = {}
    # TODO(b/149030439): disable LR reporting.
    # if self._track_lr:
    #   logs['learning_rate'] = self._calculate_lr()
    return logs

  def _calculate_lr(self) -> float:
    """Calculates the learning rate given the current step."""
    return get_scalar_from_tensor(
        self._get_base_optimizer()._decayed_lr(var_dtype=tf.float32))

  def _get_base_optimizer(self) -> tf.keras.optimizers.Optimizer:
    """Get the base optimizer used by the current model."""

    optimizer = self.model.optimizer

    # The optimizer might be wrapped by another class, so unwrap it
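    # (For example, a mixed-precision `LossScaleOptimizer` stores the wrapped
    # optimizer in its `_optimizer` attribute.)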
    while hasattr(optimizer, '_optimizer'):
      optimizer = optimizer._optimizer  # pylint:disable=protected-access

    return optimizer