# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common modules for callbacks."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import os
from absl import logging

import tensorflow as tf
from typing import Any, List, MutableMapping, Optional

from official.utils.misc import keras_utils


def get_callbacks(model_checkpoint: bool = True,
                  include_tensorboard: bool = True,
                  time_history: bool = True,
                  track_lr: bool = True,
                  write_model_weights: bool = True,
                  initial_step: int = 0,
                  batch_size: int = 0,
                  log_steps: int = 0,
                  model_dir: Optional[str] = None
                 ) -> List[tf.keras.callbacks.Callback]:
  """Get all callbacks."""
  model_dir = model_dir or ''
  callbacks = []
  if model_checkpoint:
    ckpt_full_path = os.path.join(model_dir, 'model.ckpt-{epoch:04d}')
    callbacks.append(tf.keras.callbacks.ModelCheckpoint(
        ckpt_full_path, save_weights_only=True, verbose=1))
  if include_tensorboard:
    callbacks.append(CustomTensorBoard(
        log_dir=model_dir,
        track_lr=track_lr,
        initial_step=initial_step,
        write_images=write_model_weights))
  if time_history:
    callbacks.append(keras_utils.TimeHistory(
        batch_size,
        log_steps,
        logdir=model_dir if include_tensorboard else None))
  return callbacks
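
# A minimal usage sketch (illustrative only): the returned list can be passed
# straight to `tf.keras.Model.fit`. The `model`, `train_dataset`, and
# `model_dir` values below are hypothetical.
#
#   callbacks = get_callbacks(
#       track_lr=True,
#       batch_size=128,
#       log_steps=100,
#       model_dir='/tmp/model_dir')
#   model.fit(train_dataset, epochs=90, callbacks=callbacks)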


def get_scalar_from_tensor(t: tf.Tensor) -> float:
  """Utility function to convert a Tensor to a scalar."""
  t = tf.keras.backend.get_value(t)
  if callable(t):
    return t()
  else:
    return t
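
# A small sketch of what `get_scalar_from_tensor` returns, assuming eager
# execution (values are illustrative):
#
#   get_scalar_from_tensor(tf.constant(0.1))  # -> 0.1 (a numpy scalar)
#   get_scalar_from_tensor(tf.Variable(3))    # -> 3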


class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
  """A customized TensorBoard callback that tracks additional datapoints.

  Metrics tracked:
  - Global learning rate

  Attributes:
    log_dir: the path of the directory where the log files to be parsed by
      TensorBoard are saved.
    track_lr: `bool`, whether to track the global learning rate.
    initial_step: the initial step, used for preemption recovery.
    **kwargs: additional keyword arguments for backwards compatibility; a
      possible key is `period`.
  """
  # TODO(b/146499062): track params, flops, log lr, l2 loss,
  # classification loss

  def __init__(self,
               log_dir: str,
               track_lr: bool = False,
               initial_step: int = 0,
               **kwargs):
    super(CustomTensorBoard, self).__init__(log_dir=log_dir, **kwargs)
    self.step = initial_step
    self._track_lr = track_lr

  def on_batch_begin(self,
                     batch: int,
                     logs: Optional[MutableMapping[str, Any]] = None) -> None:
    self.step += 1
    if logs is None:
      logs = {}
    logs.update(self._calculate_metrics())
    super(CustomTensorBoard, self).on_batch_begin(batch, logs)

  def on_epoch_begin(self,
                     epoch: int,
                     logs: Optional[MutableMapping[str, Any]] = None) -> None:
    if logs is None:
      logs = {}
    metrics = self._calculate_metrics()
    logs.update(metrics)
    for k, v in metrics.items():
      logging.info('Current %s: %f', k, v)
    super(CustomTensorBoard, self).on_epoch_begin(epoch, logs)

  def on_epoch_end(self,
                   epoch: int,
                   logs: Optional[MutableMapping[str, Any]] = None) -> None:
    if logs is None:
      logs = {}
    metrics = self._calculate_metrics()
    logs.update(metrics)
    super(CustomTensorBoard, self).on_epoch_end(epoch, logs)

  def _calculate_metrics(self) -> MutableMapping[str, Any]:
    """Calculates metrics (currently only the learning rate) to log."""
    logs = {}
    if self._track_lr:
      logs['learning_rate'] = self._calculate_lr()
    return logs

  def _calculate_lr(self) -> float:
    """Calculates the learning rate given the current step."""
    lr = self._get_base_optimizer().lr
    if callable(lr):
      lr = lr(self.step)
    return get_scalar_from_tensor(lr)
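
  # A sketch of the two cases handled above, assuming a standard Keras
  # optimizer (the schedule parameters are illustrative):
  #
  #   tf.keras.optimizers.SGD(learning_rate=0.1)       # lr is a variable
  #   schedule = tf.keras.optimizers.schedules.ExponentialDecay(
  #       initial_learning_rate=0.1, decay_steps=1000, decay_rate=0.96)
  #   tf.keras.optimizers.SGD(learning_rate=schedule)  # lr is callable and is
  #                                                    # evaluated at self.step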

  def _get_base_optimizer(self) -> tf.keras.optimizers.Optimizer:
    """Get the base optimizer used by the current model."""

    optimizer = self.model.optimizer

    # The optimizer might be wrapped by another class, so unwrap it
    while hasattr(optimizer, '_optimizer'):
      optimizer = optimizer._optimizer  # pylint:disable=protected-access

    return optimizer
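
# A sketch of the unwrapping above, assuming a mixed-precision setup in which
# a loss-scale wrapper stores the inner optimizer in `_optimizer` (a private
# implementation detail that may differ across TF versions):
#
#   base = tf.keras.optimizers.SGD(0.1)
#   wrapped = tf.keras.mixed_precision.LossScaleOptimizer(base)
#   # _get_base_optimizer() would walk `wrapped._optimizer` back to `base`
#   # so that the learning rate is read from the underlying optimizer.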