optimizer_builder.py 7.14 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Functions to build DetectionModel training optimizers."""

import tensorflow as tf
19
20


21
22
23
from object_detection.utils import learning_schedules


24
25
def build_optimizers_tf_v1(optimizer_config, global_step=None):
  """Create a TF v1 compatible optimizer based on config.

  Args:
    optimizer_config: An Optimizer proto message. Its `optimizer` oneof
      selects one of rms_prop_optimizer, momentum_optimizer or
      adam_optimizer.
    global_step: A variable representing the current step, used by the
      learning rate schedule. If None, defaults to
      tf.train.get_or_create_global_step().

  Returns:
    A tuple (optimizer, summary_vars):
      optimizer: a tf.train optimizer. It is wrapped in
        tf.contrib.opt.MovingAverageOptimizer when
        optimizer_config.use_moving_average is set.
      summary_vars: a list containing the learning rate, for summaries.

  Raises:
    ValueError: if the selected optimizer type is not supported, or if the
      learning rate config is invalid (raised by _create_learning_rate).
  """
  optimizer_type = optimizer_config.WhichOneof('optimizer')
  optimizer = None

  summary_vars = []
  if optimizer_type == 'rms_prop_optimizer':
    config = optimizer_config.rms_prop_optimizer
    learning_rate = _create_learning_rate(config.learning_rate,
                                          global_step=global_step)
    summary_vars.append(learning_rate)
    # In tf.train.RMSPropOptimizer, `decay` is the discounting factor for
    # the running average of squared gradients.
    optimizer = tf.train.RMSPropOptimizer(
        learning_rate,
        decay=config.decay,
        momentum=config.momentum_optimizer_value,
        epsilon=config.epsilon)

  elif optimizer_type == 'momentum_optimizer':
    config = optimizer_config.momentum_optimizer
    learning_rate = _create_learning_rate(config.learning_rate,
                                          global_step=global_step)
    summary_vars.append(learning_rate)
    optimizer = tf.train.MomentumOptimizer(
        learning_rate,
        momentum=config.momentum_optimizer_value)

  elif optimizer_type == 'adam_optimizer':
    config = optimizer_config.adam_optimizer
    learning_rate = _create_learning_rate(config.learning_rate,
                                          global_step=global_step)
    summary_vars.append(learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate)

  # Also covers an unset oneof, where WhichOneof returns None.
  if optimizer is None:
    raise ValueError('Optimizer %s not supported.' % optimizer_type)

  if optimizer_config.use_moving_average:
    optimizer = tf.contrib.opt.MovingAverageOptimizer(
        optimizer, average_decay=optimizer_config.moving_average_decay)

  return optimizer, summary_vars


def build_optimizers_tf_v2(optimizer_config, global_step=None):
  """Create a TF v2 (Keras) compatible optimizer based on config.

  Args:
    optimizer_config: An Optimizer proto message. Its `optimizer` oneof
      selects one of rms_prop_optimizer, momentum_optimizer or
      adam_optimizer.
    global_step: A variable representing the current step, used by the
      learning rate schedule. If None, defaults to
      tf.train.get_or_create_global_step().

  Returns:
    A tuple (optimizer, summary_vars):
      optimizer: a tf.keras.optimizers optimizer.
      summary_vars: a list containing the learning rate, for summaries.

  Raises:
    ValueError: if the selected optimizer type is not supported, if moving
      averages are requested (unsupported here), or if the learning rate
      config is invalid (raised by _create_learning_rate).
  """
  optimizer_type = optimizer_config.WhichOneof('optimizer')
  optimizer = None

  summary_vars = []
  # NOTE(review): the learning rate is whatever _create_learning_rate
  # returns. Under eager execution a tensor-valued schedule may be evaluated
  # only once, at build time. Confirm that learning_schedules returns
  # something the Keras optimizer re-evaluates per step.
  if optimizer_type == 'rms_prop_optimizer':
    config = optimizer_config.rms_prop_optimizer
    learning_rate = _create_learning_rate(config.learning_rate,
                                          global_step=global_step)
    summary_vars.append(learning_rate)
    # The proto's `decay` is the RMS discounting factor, which matches the
    # v1 tf.train.RMSPropOptimizer(decay=...) path above. In Keras that
    # factor is `rho`. Keras' own `decay` kwarg is inverse-time
    # learning-rate decay, which would wrongly shrink the learning rate on
    # every step.
    optimizer = tf.keras.optimizers.RMSprop(
        learning_rate,
        rho=config.decay,
        momentum=config.momentum_optimizer_value,
        epsilon=config.epsilon)

  elif optimizer_type == 'momentum_optimizer':
    config = optimizer_config.momentum_optimizer
    learning_rate = _create_learning_rate(config.learning_rate,
                                          global_step=global_step)
    summary_vars.append(learning_rate)
    optimizer = tf.keras.optimizers.SGD(
        learning_rate,
        momentum=config.momentum_optimizer_value)

  elif optimizer_type == 'adam_optimizer':
    config = optimizer_config.adam_optimizer
    learning_rate = _create_learning_rate(config.learning_rate,
                                          global_step=global_step)
    summary_vars.append(learning_rate)
    optimizer = tf.keras.optimizers.Adam(learning_rate)

  # Also covers an unset oneof, where WhichOneof returns None.
  if optimizer is None:
    raise ValueError('Optimizer %s not supported.' % optimizer_type)

  if optimizer_config.use_moving_average:
    raise ValueError('Moving average not supported in eager mode.')

  return optimizer, summary_vars


def build(config, global_step=None):
  """Build an optimizer from an Optimizer proto for the current TF mode.

  When eager execution is enabled, delegates to build_optimizers_tf_v2
  (Keras optimizers). Otherwise delegates to build_optimizers_tf_v1
  (tf.train optimizers).

  Args:
    config: An Optimizer proto message.
    global_step: A variable representing the current step, or None to let
      the learning rate builder fall back to the default global step.

  Returns:
    The (optimizer, summary_vars) tuple produced by the selected builder.
  """
  eager = tf.executing_eagerly()
  builder_fn = build_optimizers_tf_v2 if eager else build_optimizers_tf_v1
  return builder_fn(config, global_step)


def _create_learning_rate(learning_rate_config, global_step=None):
  """Create optimizer learning rate based on config.

  Args:
    learning_rate_config: A LearningRate proto message. Its `learning_rate`
      oneof selects one of constant_learning_rate,
      exponential_decay_learning_rate, manual_step_learning_rate or
      cosine_decay_learning_rate.
    global_step: A variable representing the current step.
      If None, defaults to tf.train.get_or_create_global_step().

  Returns:
    A learning rate. For constant_learning_rate this is a float32 constant
    named 'learning_rate'. For the other types it is the value returned by
    the corresponding object_detection.utils.learning_schedules function.

  Raises:
    ValueError: if manual_step_learning_rate has an empty schedule, or if
      the selected learning rate type is not supported.
  """
  if global_step is None:
    global_step = tf.train.get_or_create_global_step()

  learning_rate = None
  learning_rate_type = learning_rate_config.WhichOneof('learning_rate')
  if learning_rate_type == 'constant_learning_rate':
    config = learning_rate_config.constant_learning_rate
    learning_rate = tf.constant(config.learning_rate, dtype=tf.float32,
                                name='learning_rate')

  elif learning_rate_type == 'exponential_decay_learning_rate':
    config = learning_rate_config.exponential_decay_learning_rate
    learning_rate = learning_schedules.exponential_decay_with_burnin(
        global_step,
        config.initial_learning_rate,
        config.decay_steps,
        config.decay_factor,
        burnin_learning_rate=config.burnin_learning_rate,
        burnin_steps=config.burnin_steps,
        min_learning_rate=config.min_learning_rate,
        staircase=config.staircase)

  elif learning_rate_type == 'manual_step_learning_rate':
    config = learning_rate_config.manual_step_learning_rate
    if not config.schedule:
      raise ValueError('Empty learning rate schedule.')
    learning_rate_step_boundaries = [x.step for x in config.schedule]
    # One more rate than boundaries: initial_learning_rate first, then the
    # rate of each schedule entry, in schedule order.
    learning_rate_sequence = [config.initial_learning_rate]
    learning_rate_sequence += [x.learning_rate for x in config.schedule]
    learning_rate = learning_schedules.manual_stepping(
        global_step, learning_rate_step_boundaries,
        learning_rate_sequence, config.warmup)

  elif learning_rate_type == 'cosine_decay_learning_rate':
    config = learning_rate_config.cosine_decay_learning_rate
    learning_rate = learning_schedules.cosine_decay_with_warmup(
        global_step,
        config.learning_rate_base,
        config.total_steps,
        config.warmup_learning_rate,
        config.warmup_steps,
        config.hold_base_rate_steps)

  # Also covers an unset oneof, where WhichOneof returns None.
  if learning_rate is None:
    raise ValueError('Learning_rate %s not supported.' % learning_rate_type)

  return learning_rate