optimizer_builder.py 7.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Functions to build DetectionModel training optimizers."""

18
import tensorflow.compat.v1 as tf
19

20
from object_detection.utils import learning_schedules
21
22
23
24
25
26
from object_detection.utils import tf_version

# pylint: disable=g-import-not-at-top
if tf_version.is_tf2():
  from official.modeling.optimization import ema_optimizer
# pylint: enable=g-import-not-at-top
27

28
29
30
31
32
try:
  from tensorflow.contrib import opt as tf_opt  # pylint: disable=g-import-not-at-top
except:  # pylint: disable=bare-except
  pass

33

34
35
def build_optimizers_tf_v1(optimizer_config, global_step=None):
  """Create a TF v1 compatible optimizer based on config.

  Args:
    optimizer_config: A Optimizer proto message.
    global_step: A variable representing the current step.
      If None, defaults to tf.train.get_or_create_global_step()

  Returns:
    A tuple (optimizer, summary_vars). summary_vars is a list holding the
    learning rate that was created for the optimizer.

  Raises:
    ValueError: when the configured optimizer type is not supported, or when
      use_moving_average is set but `tensorflow.contrib` could not be imported.
  """
  optimizer_type = optimizer_config.WhichOneof('optimizer')
  optimizer = None

  summary_vars = []
  if optimizer_type == 'rms_prop_optimizer':
    config = optimizer_config.rms_prop_optimizer
    learning_rate = _create_learning_rate(config.learning_rate,
                                          global_step=global_step)
    summary_vars.append(learning_rate)
    optimizer = tf.train.RMSPropOptimizer(
        learning_rate,
        decay=config.decay,
        momentum=config.momentum_optimizer_value,
        epsilon=config.epsilon)

  elif optimizer_type == 'momentum_optimizer':
    config = optimizer_config.momentum_optimizer
    learning_rate = _create_learning_rate(config.learning_rate,
                                          global_step=global_step)
    summary_vars.append(learning_rate)
    optimizer = tf.train.MomentumOptimizer(
        learning_rate,
        momentum=config.momentum_optimizer_value)

  elif optimizer_type == 'adam_optimizer':
    config = optimizer_config.adam_optimizer
    learning_rate = _create_learning_rate(config.learning_rate,
                                          global_step=global_step)
    summary_vars.append(learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate, epsilon=config.epsilon)

  if optimizer is None:
    raise ValueError('Optimizer %s not supported.' % optimizer_type)

  if optimizer_config.use_moving_average:
    try:
      moving_average_optimizer_cls = tf_opt.MovingAverageOptimizer
    except NameError:
      # `tf_opt` is only bound when the module-level import of
      # `tensorflow.contrib` succeeded; failures of that import are swallowed
      # at load time, so report the missing dependency clearly here.
      raise ValueError(
          'use_moving_average requires tensorflow.contrib, which could not '
          'be imported in this TensorFlow installation.')
    optimizer = moving_average_optimizer_cls(
        optimizer, average_decay=optimizer_config.moving_average_decay)

  return optimizer, summary_vars
88
89


90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
def build_optimizers_tf_v2(optimizer_config, global_step=None):
  """Create a TF v2 compatible optimizer based on config.

  Args:
    optimizer_config: A Optimizer proto message.
    global_step: A variable representing the current step.
      If None, defaults to tf.train.get_or_create_global_step()

  Returns:
    A tuple (optimizer, summary_vars). summary_vars is a list holding the
    learning rate that was created for the optimizer.

  Raises:
    ValueError: when the configured optimizer type is not supported.
  """
  optimizer_type = optimizer_config.WhichOneof('optimizer')
  optimizer = None

  summary_vars = []
  if optimizer_type == 'rms_prop_optimizer':
    config = optimizer_config.rms_prop_optimizer
    learning_rate = _create_learning_rate(config.learning_rate,
                                          global_step=global_step)
    summary_vars.append(learning_rate)
    # The proto's `decay` carries the TF1 RMSPropOptimizer meaning (the
    # discounting factor for the gradient history, as used by
    # build_optimizers_tf_v1). Keras names that factor `rho`; its `decay`
    # keyword is instead a legacy per-step learning-rate decay (and is
    # rejected by the newer Keras optimizers), so it must not be used here.
    optimizer = tf.keras.optimizers.RMSprop(
        learning_rate,
        rho=config.decay,
        momentum=config.momentum_optimizer_value,
        epsilon=config.epsilon)

  elif optimizer_type == 'momentum_optimizer':
    config = optimizer_config.momentum_optimizer
    learning_rate = _create_learning_rate(config.learning_rate,
                                          global_step=global_step)
    summary_vars.append(learning_rate)
    optimizer = tf.keras.optimizers.SGD(
        learning_rate,
        momentum=config.momentum_optimizer_value)

  elif optimizer_type == 'adam_optimizer':
    config = optimizer_config.adam_optimizer
    learning_rate = _create_learning_rate(config.learning_rate,
                                          global_step=global_step)
    summary_vars.append(learning_rate)
    optimizer = tf.keras.optimizers.Adam(learning_rate, epsilon=config.epsilon)

  if optimizer is None:
    raise ValueError('Optimizer %s not supported.' % optimizer_type)

  if optimizer_config.use_moving_average:
    optimizer = ema_optimizer.ExponentialMovingAverage(
        optimizer=optimizer,
        average_decay=optimizer_config.moving_average_decay)

  return optimizer, summary_vars


def build(config, global_step=None):
  """Builds an optimizer suited to the current TensorFlow execution mode.

  When TensorFlow is executing eagerly, the Keras-based builder
  `build_optimizers_tf_v2` is used; otherwise `build_optimizers_tf_v1` is used.

  Args:
    config: An Optimizer proto message.
    global_step: A variable representing the current step, or None.

  Returns:
    Whatever the selected builder returns: an (optimizer, summary_vars) tuple.
  """
  selected_builder = (build_optimizers_tf_v2 if tf.executing_eagerly()
                      else build_optimizers_tf_v1)
  return selected_builder(config, global_step)


154
def _create_learning_rate(learning_rate_config, global_step=None):
  """Create optimizer learning rate based on config.

  Args:
    learning_rate_config: A LearningRate proto message.
    global_step: A variable representing the current step.
      If None, defaults to tf.train.get_or_create_global_step()

  Returns:
    A learning rate: a float32 constant tensor for constant_learning_rate,
    otherwise whatever the corresponding `learning_schedules` helper returns.

  Raises:
    ValueError: when the configured learning rate type is not supported, or
      when a manual_step_learning_rate has an empty schedule.
  """
  if global_step is None:
    global_step = tf.train.get_or_create_global_step()

  learning_rate = None
  learning_rate_type = learning_rate_config.WhichOneof('learning_rate')
  if learning_rate_type == 'constant_learning_rate':
    config = learning_rate_config.constant_learning_rate
    learning_rate = tf.constant(config.learning_rate, dtype=tf.float32,
                                name='learning_rate')

  elif learning_rate_type == 'exponential_decay_learning_rate':
    config = learning_rate_config.exponential_decay_learning_rate
    learning_rate = learning_schedules.exponential_decay_with_burnin(
        global_step,
        config.initial_learning_rate,
        config.decay_steps,
        config.decay_factor,
        burnin_learning_rate=config.burnin_learning_rate,
        burnin_steps=config.burnin_steps,
        min_learning_rate=config.min_learning_rate,
        staircase=config.staircase)

  elif learning_rate_type == 'manual_step_learning_rate':
    config = learning_rate_config.manual_step_learning_rate
    if not config.schedule:
      raise ValueError('Empty learning rate schedule.')
    # Boundaries are the schedule steps; the rate sequence is the initial rate
    # followed by one rate per schedule entry, so it is one longer than the
    # boundary list.
    learning_rate_step_boundaries = [x.step for x in config.schedule]
    learning_rate_sequence = [config.initial_learning_rate]
    learning_rate_sequence += [x.learning_rate for x in config.schedule]
    learning_rate = learning_schedules.manual_stepping(
        global_step, learning_rate_step_boundaries,
        learning_rate_sequence, config.warmup)

  elif learning_rate_type == 'cosine_decay_learning_rate':
    config = learning_rate_config.cosine_decay_learning_rate
    learning_rate = learning_schedules.cosine_decay_with_warmup(
        global_step,
        config.learning_rate_base,
        config.total_steps,
        config.warmup_learning_rate,
        config.warmup_steps,
        config.hold_base_rate_steps)

  if learning_rate is None:
    raise ValueError('Learning_rate %s not supported.' % learning_rate_type)

  return learning_rate