optimizer_builder.py
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Functions to build DetectionModel training optimizers."""

import tensorflow as tf
from tensorflow.contrib import opt as tf_opt
from object_detection.utils import learning_schedules


def build_optimizers_tf_v1(optimizer_config, global_step=None):
  """Create a TF v1 compatible optimizer based on config.

  Args:
    optimizer_config: An Optimizer proto message.
    global_step: A variable representing the current step.
      If None, defaults to tf.train.get_or_create_global_step()

  Returns:
    An optimizer and a list of variables for summary.

  Raises:
    ValueError: when using an unsupported optimizer type.
  """
  optimizer_type = optimizer_config.WhichOneof('optimizer')
  optimizer = None

  summary_vars = []
  if optimizer_type == 'rms_prop_optimizer':
    config = optimizer_config.rms_prop_optimizer
    learning_rate = _create_learning_rate(config.learning_rate,
                                          global_step=global_step)
    summary_vars.append(learning_rate)
    optimizer = tf.train.RMSPropOptimizer(
        learning_rate,
        decay=config.decay,
        momentum=config.momentum_optimizer_value,
        epsilon=config.epsilon)

  if optimizer_type == 'momentum_optimizer':
    config = optimizer_config.momentum_optimizer
    learning_rate = _create_learning_rate(config.learning_rate,
                                          global_step=global_step)
    summary_vars.append(learning_rate)
    optimizer = tf.train.MomentumOptimizer(
        learning_rate,
        momentum=config.momentum_optimizer_value)

  if optimizer_type == 'adam_optimizer':
    config = optimizer_config.adam_optimizer
    learning_rate = _create_learning_rate(config.learning_rate,
                                          global_step=global_step)
    summary_vars.append(learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate, epsilon=config.epsilon)
  if optimizer is None:
    raise ValueError('Optimizer %s not supported.' % optimizer_type)

  if optimizer_config.use_moving_average:
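    # Note: MovingAverageOptimizer keeps an exponential moving average of the
    # trained variables. Checkpoints that should contain the averaged values
    # must be written with the saver returned by its swapping_saver() method
    # rather than a plain tf.train.Saver.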
    optimizer = tf_opt.MovingAverageOptimizer(
        optimizer, average_decay=optimizer_config.moving_average_decay)

  return optimizer, summary_vars
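

# A minimal usage sketch (assumes the generated proto module
# object_detection.protos.optimizer_pb2 is importable; the config values
# below are illustrative, not taken from any released pipeline):
#
#   from google.protobuf import text_format
#   from object_detection.protos import optimizer_pb2
#
#   optimizer_proto = text_format.Parse("""
#       rms_prop_optimizer {
#         learning_rate {
#           constant_learning_rate { learning_rate: 0.004 }
#         }
#         momentum_optimizer_value: 0.9
#         decay: 0.9
#         epsilon: 1.0
#       }
#       use_moving_average: false
#   """, optimizer_pb2.Optimizer())
#   optimizer, summary_vars = build_optimizers_tf_v1(optimizer_proto)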


def build_optimizers_tf_v2(optimizer_config, global_step=None):
  """Create a TF v2 compatible optimizer based on config.

  Args:
    optimizer_config: An Optimizer proto message.
    global_step: A variable representing the current step.
      If None, defaults to tf.train.get_or_create_global_step()

  Returns:
    An optimizer and a list of variables for summary.

  Raises:
    ValueError: when using an unsupported optimizer type.
  """
  optimizer_type = optimizer_config.WhichOneof('optimizer')
  optimizer = None

  summary_vars = []
  if optimizer_type == 'rms_prop_optimizer':
    config = optimizer_config.rms_prop_optimizer
    learning_rate = _create_learning_rate(config.learning_rate,
                                          global_step=global_step)
    summary_vars.append(learning_rate)
    optimizer = tf.keras.optimizers.RMSprop(
        learning_rate,
        decay=config.decay,
        momentum=config.momentum_optimizer_value,
        epsilon=config.epsilon)

  if optimizer_type == 'momentum_optimizer':
    config = optimizer_config.momentum_optimizer
    learning_rate = _create_learning_rate(config.learning_rate,
                                          global_step=global_step)
    summary_vars.append(learning_rate)
    optimizer = tf.keras.optimizers.SGD(
        learning_rate,
        momentum=config.momentum_optimizer_value)

  if optimizer_type == 'adam_optimizer':
    config = optimizer_config.adam_optimizer
    learning_rate = _create_learning_rate(config.learning_rate,
                                          global_step=global_step)
    summary_vars.append(learning_rate)
    optimizer = tf.keras.optimizers.Adam(learning_rate, epsilon=config.epsilon)

  if optimizer is None:
    raise ValueError('Optimizer %s not supported.' % optimizer_type)

  if optimizer_config.use_moving_average:
    raise ValueError('Moving average not supported in eager mode.')

  return optimizer, summary_vars
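

# A minimal sketch of driving the returned tf.keras optimizer in eager mode;
# `model`, `loss_fn` and `batch` are hypothetical placeholders:
#
#   optimizer, summary_vars = build_optimizers_tf_v2(optimizer_proto)
#   with tf.GradientTape() as tape:
#     loss = loss_fn(model(batch))
#   grads = tape.gradient(loss, model.trainable_variables)
#   optimizer.apply_gradients(zip(grads, model.trainable_variables))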


def build(config, global_step=None):
  """Create an optimizer based on config, for eager or graph execution.

  Delegates to build_optimizers_tf_v2 when executing eagerly and to
  build_optimizers_tf_v1 otherwise.

  Args:
    config: An Optimizer proto message.
    global_step: A variable representing the current step.
      If None, defaults to tf.train.get_or_create_global_step()

  Returns:
    An optimizer and a list of variables for summary.
  """
  if tf.executing_eagerly():
    return build_optimizers_tf_v2(config, global_step)
  else:
    return build_optimizers_tf_v1(config, global_step)


def _create_learning_rate(learning_rate_config, global_step=None):
  """Create optimizer learning rate based on config.

  Args:
    learning_rate_config: A LearningRate proto message.
    global_step: A variable representing the current step.
      If None, defaults to tf.train.get_or_create_global_step()

  Returns:
    A learning rate.

  Raises:
    ValueError: when using an unsupported learning rate type.
  """
  if global_step is None:
    global_step = tf.train.get_or_create_global_step()
  learning_rate = None
  learning_rate_type = learning_rate_config.WhichOneof('learning_rate')
  if learning_rate_type == 'constant_learning_rate':
    config = learning_rate_config.constant_learning_rate
    learning_rate = tf.constant(config.learning_rate, dtype=tf.float32,
                                name='learning_rate')

  if learning_rate_type == 'exponential_decay_learning_rate':
    config = learning_rate_config.exponential_decay_learning_rate
    learning_rate = learning_schedules.exponential_decay_with_burnin(
        global_step,
        config.initial_learning_rate,
        config.decay_steps,
        config.decay_factor,
        burnin_learning_rate=config.burnin_learning_rate,
        burnin_steps=config.burnin_steps,
        min_learning_rate=config.min_learning_rate,
        staircase=config.staircase)

  if learning_rate_type == 'manual_step_learning_rate':
    config = learning_rate_config.manual_step_learning_rate
    if not config.schedule:
      raise ValueError('Empty learning rate schedule.')
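    # manual_stepping interprets the schedule as step boundaries:
    # initial_learning_rate applies until the first boundary, and each
    # schedule entry's learning_rate takes effect from its step onward.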
    learning_rate_step_boundaries = [x.step for x in config.schedule]
    learning_rate_sequence = [config.initial_learning_rate]
    learning_rate_sequence += [x.learning_rate for x in config.schedule]
    learning_rate = learning_schedules.manual_stepping(
        global_step, learning_rate_step_boundaries,
        learning_rate_sequence, config.warmup)
  if learning_rate_type == 'cosine_decay_learning_rate':
    config = learning_rate_config.cosine_decay_learning_rate
    learning_rate = learning_schedules.cosine_decay_with_warmup(
        global_step,
        config.learning_rate_base,
        config.total_steps,
        config.warmup_learning_rate,
        config.warmup_steps,
        config.hold_base_rate_steps)
  if learning_rate is None:
    raise ValueError('Learning_rate %s not supported.' % learning_rate_type)

  return learning_rate
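

# A minimal usage sketch (assumes object_detection.protos.optimizer_pb2 is
# importable; the schedule values are illustrative only):
#
#   from google.protobuf import text_format
#   from object_detection.protos import optimizer_pb2
#
#   lr_proto = text_format.Parse("""
#       exponential_decay_learning_rate {
#         initial_learning_rate: 0.004
#         decay_steps: 800720
#         decay_factor: 0.95
#       }
#   """, optimizer_pb2.LearningRate())
#   learning_rate = _create_learning_rate(lr_proto)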