# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Functions to build DetectionModel training optimizers."""

import tensorflow.compat.v1 as tf
from object_detection.utils import learning_schedules

# tf.contrib does not exist in TF 2.x, so this import is best-effort: the
# MovingAverageOptimizer it provides is only needed when the config sets
# `use_moving_average`, and using it without a successful import surfaces
# as a NameError on `tf_opt` at that point.
try:
  from tensorflow.contrib import opt as tf_opt  # pylint: disable=g-import-not-at-top
except Exception:  # pylint: disable=broad-except
  # Catch Exception rather than a bare `except:` so KeyboardInterrupt and
  # SystemExit still propagate while any import-time failure is tolerated.
  pass

def build_optimizers_tf_v1(optimizer_config, global_step=None):
  """Create a TF v1 compatible optimizer based on config.

  Args:
    optimizer_config: A Optimizer proto message.
    global_step: A variable representing the current step.
      If None, defaults to tf.train.get_or_create_global_step()

  Returns:
    An optimizer and a list of variables for summary.

  Raises:
    ValueError: when using an unsupported input data type.
  """
  chosen_type = optimizer_config.WhichOneof('optimizer')
  summary_vars = []
  opt = None

  if chosen_type == 'rms_prop_optimizer':
    rms_cfg = optimizer_config.rms_prop_optimizer
    lr = _create_learning_rate(rms_cfg.learning_rate,
                               global_step=global_step)
    summary_vars.append(lr)
    opt = tf.train.RMSPropOptimizer(
        lr,
        decay=rms_cfg.decay,
        momentum=rms_cfg.momentum_optimizer_value,
        epsilon=rms_cfg.epsilon)
  elif chosen_type == 'momentum_optimizer':
    momentum_cfg = optimizer_config.momentum_optimizer
    lr = _create_learning_rate(momentum_cfg.learning_rate,
                               global_step=global_step)
    summary_vars.append(lr)
    opt = tf.train.MomentumOptimizer(
        lr,
        momentum=momentum_cfg.momentum_optimizer_value)
  elif chosen_type == 'adam_optimizer':
    adam_cfg = optimizer_config.adam_optimizer
    lr = _create_learning_rate(adam_cfg.learning_rate,
                               global_step=global_step)
    summary_vars.append(lr)
    opt = tf.train.AdamOptimizer(lr, epsilon=adam_cfg.epsilon)

  if opt is None:
    raise ValueError('Optimizer %s not supported.' % chosen_type)

  if optimizer_config.use_moving_average:
    # Wrap the base optimizer so variable moving averages are maintained.
    opt = tf_opt.MovingAverageOptimizer(
        opt, average_decay=optimizer_config.moving_average_decay)

  return opt, summary_vars


def build_optimizers_tf_v2(optimizer_config, global_step=None):
  """Create a TF v2 compatible optimizer based on config.

  Args:
    optimizer_config: A Optimizer proto message.
    global_step: A variable representing the current step.
      If None, defaults to tf.train.get_or_create_global_step()

  Returns:
    An optimizer and a list of variables for summary.

  Raises:
    ValueError: when using an unsupported input data type.
  """
  chosen_type = optimizer_config.WhichOneof('optimizer')
  summary_vars = []
  opt = None

  if chosen_type == 'rms_prop_optimizer':
    rms_cfg = optimizer_config.rms_prop_optimizer
    lr = _create_learning_rate(rms_cfg.learning_rate,
                               global_step=global_step)
    summary_vars.append(lr)
    opt = tf.keras.optimizers.RMSprop(
        lr,
        decay=rms_cfg.decay,
        momentum=rms_cfg.momentum_optimizer_value,
        epsilon=rms_cfg.epsilon)
  elif chosen_type == 'momentum_optimizer':
    momentum_cfg = optimizer_config.momentum_optimizer
    lr = _create_learning_rate(momentum_cfg.learning_rate,
                               global_step=global_step)
    summary_vars.append(lr)
    opt = tf.keras.optimizers.SGD(
        lr,
        momentum=momentum_cfg.momentum_optimizer_value)
  elif chosen_type == 'adam_optimizer':
    adam_cfg = optimizer_config.adam_optimizer
    lr = _create_learning_rate(adam_cfg.learning_rate,
                               global_step=global_step)
    summary_vars.append(lr)
    opt = tf.keras.optimizers.Adam(lr, epsilon=adam_cfg.epsilon)

  if opt is None:
    raise ValueError('Optimizer %s not supported.' % chosen_type)

  if optimizer_config.use_moving_average:
    # tf.contrib's MovingAverageOptimizer has no Keras equivalent here.
    raise ValueError('Moving average not supported in eager mode.')

  return opt, summary_vars


def build(config, global_step=None):
  """Create an optimizer based on config, dispatching on execution mode.

  Selects the Keras-based builder when running eagerly (TF 2 behavior) and
  the tf.train-based builder otherwise (graph mode).

  Args:
    config: A Optimizer proto message.
    global_step: A variable representing the current step.
      If None, defaults to tf.train.get_or_create_global_step()

  Returns:
    An optimizer and a list of variables for summary.
  """
  builder = (build_optimizers_tf_v2 if tf.executing_eagerly()
             else build_optimizers_tf_v1)
  return builder(config, global_step)


def _create_learning_rate(learning_rate_config, global_step=None):
  """Create optimizer learning rate based on config.

  Args:
    learning_rate_config: A LearningRate proto message.
    global_step: A variable representing the current step.
      If None, defaults to tf.train.get_or_create_global_step()

  Returns:
    A learning rate.

  Raises:
    ValueError: when using an unsupported input data type.
  """
  if global_step is None:
    global_step = tf.train.get_or_create_global_step()

  schedule_type = learning_rate_config.WhichOneof('learning_rate')
  lr = None

  if schedule_type == 'constant_learning_rate':
    constant_cfg = learning_rate_config.constant_learning_rate
    lr = tf.constant(constant_cfg.learning_rate, dtype=tf.float32,
                     name='learning_rate')
  elif schedule_type == 'exponential_decay_learning_rate':
    exp_cfg = learning_rate_config.exponential_decay_learning_rate
    lr = learning_schedules.exponential_decay_with_burnin(
        global_step,
        exp_cfg.initial_learning_rate,
        exp_cfg.decay_steps,
        exp_cfg.decay_factor,
        burnin_learning_rate=exp_cfg.burnin_learning_rate,
        burnin_steps=exp_cfg.burnin_steps,
        min_learning_rate=exp_cfg.min_learning_rate,
        staircase=exp_cfg.staircase)
  elif schedule_type == 'manual_step_learning_rate':
    manual_cfg = learning_rate_config.manual_step_learning_rate
    if not manual_cfg.schedule:
      raise ValueError('Empty learning rate schedule.')
    # Boundaries come from the schedule entries; the rate sequence has one
    # extra leading element (the initial rate before the first boundary).
    boundaries = [entry.step for entry in manual_cfg.schedule]
    rates = [manual_cfg.initial_learning_rate]
    rates += [entry.learning_rate for entry in manual_cfg.schedule]
    lr = learning_schedules.manual_stepping(
        global_step, boundaries, rates, manual_cfg.warmup)
  elif schedule_type == 'cosine_decay_learning_rate':
    cos_cfg = learning_rate_config.cosine_decay_learning_rate
    lr = learning_schedules.cosine_decay_with_warmup(
        global_step,
        cos_cfg.learning_rate_base,
        cos_cfg.total_steps,
        cos_cfg.warmup_learning_rate,
        cos_cfg.warmup_steps,
        cos_cfg.hold_base_rate_steps)

  if lr is None:
    raise ValueError('Learning_rate %s not supported.' % schedule_type)

  return lr