# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Builder function to construct tf-slim arg_scope for convolution, fc ops."""
17
18
import tensorflow.compat.v1 as tf
import tf_slim as slim
19

20
from object_detection.core import freezable_batch_norm
21
from object_detection.protos import hyperparams_pb2
22
from object_detection.utils import context_manager
23
from object_detection.utils import tf_version
24

25
26
27
# pylint: disable=g-import-not-at-top
if tf_version.is_tf2():
  from object_detection.core import freezable_sync_batch_norm
28
# pylint: enable=g-import-not-at-top
29
30


class KerasLayerHyperparams(object):
  """Hyperparameter configuration for Keras layers in Object Detection models.

  Wraps a `hyperparams_pb2.Hyperparams` proto and exposes the Keras-style
  construction arguments (initializers, regularizers, activation, batch norm)
  derived from it.
  """

  def __init__(self, hyperparams_config):
    """Builds keras hyperparameter config for layers based on the proto config.

    It automatically converts from Slim layer hyperparameter configs to
    Keras layer hyperparameters. Namely, it:
    - Builds Keras initializers/regularizers instead of Slim ones
    - sets weights_regularizer/initializer to kernel_regularizer/initializer
    - converts batchnorm decay to momentum
    - converts Slim l2 regularizer weights to the equivalent Keras l2 weights

    Contains a hyperparameter configuration for ops that specifies kernel
    initializer, kernel regularizer, activation. Also contains parameters for
    batch norm operators based on the configuration.

    Note that if the batch_norm parameters are not specified in the config
    (i.e. left to default) then batch norm is excluded from the config.

    Args:
      hyperparams_config: hyperparams.proto object containing
        hyperparameters.

    Raises:
      ValueError: if hyperparams_config is not of type hyperparams.Hyperparams.
    """
    if not isinstance(hyperparams_config, hyperparams_pb2.Hyperparams):
      raise ValueError('hyperparams_config not of type '
                       'hyperparams_pb.Hyperparams.')

    # `None` means no batch norm is configured (see use_batch_norm()).
    self._batch_norm_params = None
    self._use_sync_batch_norm = False
    if hyperparams_config.HasField('batch_norm'):
      self._batch_norm_params = _build_keras_batch_norm_params(
          hyperparams_config.batch_norm)
    elif hyperparams_config.HasField('sync_batch_norm'):
      self._use_sync_batch_norm = True
      self._batch_norm_params = _build_keras_batch_norm_params(
          hyperparams_config.sync_batch_norm)

    self._force_use_bias = hyperparams_config.force_use_bias
    self._activation_fn = _build_activation_fn(hyperparams_config.activation)

    # TODO(kaftan): Unclear if these kwargs apply to separable & depthwise conv
    # (Those might use depthwise_* instead of kernel_*)
    # We should probably switch to using build_conv2d_layer and
    # build_depthwise_conv2d_layer methods instead.
    self._op_params = {
        'kernel_regularizer': _build_keras_regularizer(
            hyperparams_config.regularizer),
        'kernel_initializer': _build_initializer(
            hyperparams_config.initializer, build_for_keras=True),
        # Reuse the activation built above instead of rebuilding it.
        'activation': self._activation_fn,
    }

  def use_batch_norm(self):
    """Returns True if a batch_norm or sync_batch_norm config was given."""
    return self._batch_norm_params is not None

  def force_use_bias(self):
    """Returns the `force_use_bias` flag from the config."""
    return self._force_use_bias

  def use_bias(self):
    """Returns whether layers should create their own bias term.

    A bias is redundant when a centering batch norm follows the layer, so it
    is dropped in that case unless `force_use_bias` is set.
    """
    return (self._force_use_bias or not
            (self.use_batch_norm() and self.batch_norm_params()['center']))

  def batch_norm_params(self, **overrides):
    """Returns a dict containing batchnorm layer construction hyperparameters.

    Optionally overrides values in the batchnorm hyperparam dict. Overrides
    only apply to individual calls of this method, and do not affect
    future calls.

    Args:
      **overrides: keyword arguments to override in the hyperparams dictionary

    Returns: dict containing the layer construction keyword arguments, with
      values overridden by the `overrides` keyword arguments. Contains only
      the overrides when batch norm is not configured.
    """
    if self._batch_norm_params is None:
      new_batch_norm_params = dict()
    else:
      new_batch_norm_params = self._batch_norm_params.copy()
    new_batch_norm_params.update(overrides)
    return new_batch_norm_params

  def build_batch_norm(self, training=None, **overrides):
    """Returns a Batch Normalization layer with the appropriate hyperparams.

    If the hyperparams are configured to not use batch normalization,
    this will return a Keras Lambda layer that only applies tf.Identity,
    without doing any normalization.

    Optionally overrides values in the batch_norm hyperparam dict. Overrides
    only apply to individual calls of this method, and do not affect
    future calls.

    Args:
      training: if True, the normalization layer will normalize using the batch
       statistics. If False, the normalization layer will be frozen and will
       act as if it is being used for inference. If None, the layer
       will look up the Keras learning phase at `call` time to decide what to
       do.
      **overrides: batch normalization construction args to override from the
        batch_norm hyperparams dictionary.

    Returns: Either a FreezableBatchNorm / FreezableSyncBatchNorm layer (if
      use_batch_norm() is True), or a Keras Lambda layer that applies the
      identity (if use_batch_norm() is False)

    Raises:
      ValueError: if sync_batch_norm is configured outside of TF2, where the
        sync batch norm module is not imported.
    """
    if not self.use_batch_norm():
      return tf.keras.layers.Lambda(tf.identity)

    if self._use_sync_batch_norm:
      # freezable_sync_batch_norm is only imported when running under TF2;
      # fail with a clear message instead of a NameError otherwise.
      if not tf_version.is_tf2():
        raise ValueError('sync_batch_norm is only supported in TF2.')
      return freezable_sync_batch_norm.FreezableSyncBatchNorm(
          training=training, **self.batch_norm_params(**overrides))

    return freezable_batch_norm.FreezableBatchNorm(
        training=training, **self.batch_norm_params(**overrides))

  def build_activation_layer(self, name='activation'):
    """Returns a Keras layer that applies the desired activation function.

    Args:
      name: The name to assign the Keras layer.
    Returns: A Keras lambda layer that applies the activation function
      specified in the hyperparam config, or applies the identity if the
      activation function is None.
    """
    if self._activation_fn:
      return tf.keras.layers.Lambda(self._activation_fn, name=name)
    else:
      return tf.keras.layers.Lambda(tf.identity, name=name)

  def params(self, include_activation=False, **overrides):
    """Returns a dict containing the layer construction hyperparameters to use.

    Optionally overrides values in the returned dict. Overrides
    only apply to individual calls of this method, and do not affect
    future calls.

    Args:
      include_activation: If False, activation in the returned dictionary will
        be set to `None`, and the activation must be applied via a separate
        layer created by `build_activation_layer`. If True, `activation` in the
        output param dictionary will be set to the activation function
        specified in the hyperparams config.
      **overrides: keyword arguments to override in the hyperparams dictionary.

    Returns: dict containing the layer construction keyword arguments, with
      values overridden by the `overrides` keyword arguments.
    """
    new_params = self._op_params.copy()
    new_params['activation'] = None
    if include_activation:
      new_params['activation'] = self._activation_fn
    new_params['use_bias'] = self.use_bias()
    new_params.update(**overrides)
    return new_params


def build(hyperparams_config, is_training):
  """Builds tf-slim arg_scope for convolution ops based on the config.

  Returns an arg_scope to use for convolution ops containing weights
  initializer, weights regularizer, activation function, batch norm function
  and batch norm parameters based on the configuration.

  Note that if no normalization parameters are specified in the config,
  (i.e. left to default) then both batch norm and group norm are excluded
  from the arg_scope.

  The batch norm parameters are set for updates based on `is_training` argument
  and conv_hyperparams_config.batch_norm.train parameter. During training, they
  are updated only if batch_norm.train parameter is true. However, during eval,
  no updates are made to the batch norm variables. In both cases, their current
  values are used during forward pass.

  Args:
    hyperparams_config: hyperparams.proto object containing
      hyperparameters.
    is_training: Whether the network is in training mode.

  Returns:
    arg_scope_fn: A function to construct tf-slim arg_scope containing
      hyperparameters for ops.

  Raises:
    ValueError: if hyperparams_config is not of type hyperparams.Hyperparams,
      or if it sets an option (force_use_bias, sync_batch_norm) that is only
      supported by KerasLayerHyperparams.
  """
  if not isinstance(hyperparams_config, hyperparams_pb2.Hyperparams):
    raise ValueError('hyperparams_config not of type '
                     'hyperparams_pb.Hyperparams.')

  # These options are only implemented on the Keras path.
  if hyperparams_config.force_use_bias:
    raise ValueError('Hyperparams force_use_bias only supported by '
                     'KerasLayerHyperparams.')

  if hyperparams_config.HasField('sync_batch_norm'):
    raise ValueError('Hyperparams sync_batch_norm only supported by '
                     'KerasLayerHyperparams.')

  normalizer_fn = None
  batch_norm_params = None
  if hyperparams_config.HasField('batch_norm'):
    normalizer_fn = slim.batch_norm
    batch_norm_params = _build_batch_norm_params(
        hyperparams_config.batch_norm, is_training)
  if hyperparams_config.HasField('group_norm'):
    normalizer_fn = slim.group_norm

  affected_ops = [slim.conv2d, slim.separable_conv2d, slim.conv2d_transpose]
  if hyperparams_config.HasField('op') and (
      hyperparams_config.op == hyperparams_pb2.Hyperparams.FC):
    affected_ops = [slim.fully_connected]

  def scope_fn():
    # Only open a batch_norm arg_scope when batch norm is configured.
    with (slim.arg_scope([slim.batch_norm], **batch_norm_params)
          if batch_norm_params is not None else
          context_manager.IdentityContextManager()):
      with slim.arg_scope(
          affected_ops,
          weights_regularizer=_build_slim_regularizer(
              hyperparams_config.regularizer),
          weights_initializer=_build_initializer(
              hyperparams_config.initializer),
          activation_fn=_build_activation_fn(hyperparams_config.activation),
          normalizer_fn=normalizer_fn) as sc:
        return sc

  return scope_fn


def _build_activation_fn(activation_fn):
  """Builds a callable activation from config.

  Args:
    activation_fn: hyperparams_pb2.Hyperparams.activation

  Returns:
    Callable activation function, or None when the config is
    `Hyperparams.NONE`.

  Raises:
    ValueError: On unknown activation function.
  """
  if activation_fn == hyperparams_pb2.Hyperparams.NONE:
    return None
  if activation_fn == hyperparams_pb2.Hyperparams.RELU:
    return tf.nn.relu
  if activation_fn == hyperparams_pb2.Hyperparams.RELU_6:
    return tf.nn.relu6
  if activation_fn == hyperparams_pb2.Hyperparams.SWISH:
    return tf.nn.swish
  raise ValueError('Unknown activation function: {}'.format(activation_fn))


289
def _build_slim_regularizer(regularizer):
290
291
292
293
294
295
296
297
298
299
300
301
302
  """Builds a tf-slim regularizer from config.

  Args:
    regularizer: hyperparams_pb2.Hyperparams.regularizer proto.

  Returns:
    tf-slim regularizer.

  Raises:
    ValueError: On unknown regularizer.
  """
  regularizer_oneof = regularizer.WhichOneof('regularizer_oneof')
  if  regularizer_oneof == 'l1_regularizer':
303
    return slim.l1_regularizer(scale=float(regularizer.l1_regularizer.weight))
304
  if regularizer_oneof == 'l2_regularizer':
305
    return slim.l2_regularizer(scale=float(regularizer.l2_regularizer.weight))
306
307
  if regularizer_oneof is None:
    return None
308
309
310
  raise ValueError('Unknown regularizer function: {}'.format(regularizer_oneof))


311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
def _build_keras_regularizer(regularizer):
  """Builds a keras regularizer from config.

  Args:
    regularizer: hyperparams_pb2.Hyperparams.regularizer proto.

  Returns:
    Keras regularizer.

  Raises:
    ValueError: On unknown regularizer.
  """
  regularizer_oneof = regularizer.WhichOneof('regularizer_oneof')
  if  regularizer_oneof == 'l1_regularizer':
    return tf.keras.regularizers.l1(float(regularizer.l1_regularizer.weight))
  if regularizer_oneof == 'l2_regularizer':
    # The Keras L2 regularizer weight differs from the Slim L2 regularizer
    # weight by a factor of 2
    return tf.keras.regularizers.l2(
        float(regularizer.l2_regularizer.weight * 0.5))
331
332
  if regularizer_oneof is None:
    return None
333
334
335
336
  raise ValueError('Unknown regularizer function: {}'.format(regularizer_oneof))


def _build_initializer(initializer, build_for_keras=False):
337
338
339
340
  """Build a tf initializer from config.

  Args:
    initializer: hyperparams_pb2.Hyperparams.regularizer proto.
341
342
    build_for_keras: Whether the initializers should be built for Keras
      operators. If false builds for Slim.
343
344
345
346
347
348
349
350
351
352
353
354

  Returns:
    tf initializer.

  Raises:
    ValueError: On unknown initializer.
  """
  initializer_oneof = initializer.WhichOneof('initializer_oneof')
  if initializer_oneof == 'truncated_normal_initializer':
    return tf.truncated_normal_initializer(
        mean=initializer.truncated_normal_initializer.mean,
        stddev=initializer.truncated_normal_initializer.stddev)
355
356
357
358
  if initializer_oneof == 'random_normal_initializer':
    return tf.random_normal_initializer(
        mean=initializer.random_normal_initializer.mean,
        stddev=initializer.random_normal_initializer.stddev)
359
360
361
362
363
364
  if initializer_oneof == 'variance_scaling_initializer':
    enum_descriptor = (hyperparams_pb2.VarianceScalingInitializer.
                       DESCRIPTOR.enum_types_by_name['Mode'])
    mode = enum_descriptor.values_by_number[initializer.
                                            variance_scaling_initializer.
                                            mode].name
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
    if build_for_keras:
      if initializer.variance_scaling_initializer.uniform:
        return tf.variance_scaling_initializer(
            scale=initializer.variance_scaling_initializer.factor,
            mode=mode.lower(),
            distribution='uniform')
      else:
        # In TF 1.9 release and earlier, the truncated_normal distribution was
        # not supported correctly. So, in these earlier versions of tensorflow,
        # the ValueError will be raised, and we manually truncate the
        # distribution scale.
        #
        # It is insufficient to just set distribution to `normal` from the
        # start, because the `normal` distribution in newer Tensorflow versions
        # creates a truncated distribution, whereas it created untruncated
        # distributions in older versions.
        try:
          return tf.variance_scaling_initializer(
              scale=initializer.variance_scaling_initializer.factor,
              mode=mode.lower(),
              distribution='truncated_normal')
        except ValueError:
          truncate_constant = 0.87962566103423978
          truncated_scale = initializer.variance_scaling_initializer.factor / (
              truncate_constant * truncate_constant
          )
          return tf.variance_scaling_initializer(
              scale=truncated_scale,
              mode=mode.lower(),
              distribution='normal')

    else:
      return slim.variance_scaling_initializer(
          factor=initializer.variance_scaling_initializer.factor,
          mode=mode,
          uniform=initializer.variance_scaling_initializer.uniform)
401
402
  if initializer_oneof is None:
    return None
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
  raise ValueError('Unknown initializer function: {}'.format(
      initializer_oneof))


def _build_batch_norm_params(batch_norm, is_training):
  """Build a dictionary of batch_norm params from config.

  Args:
    batch_norm: hyperparams_pb2.ConvHyperparams.batch_norm proto.
    is_training: Whether the models is in training mode.

  Returns:
    A dictionary containing batch_norm parameters.
  """
  batch_norm_params = {
      'decay': batch_norm.decay,
      'center': batch_norm.center,
      'scale': batch_norm.scale,
      'epsilon': batch_norm.epsilon,
422
423
424
      # Remove is_training parameter from here and deprecate it in the proto
      # once we refactor Faster RCNN models to set is_training through an outer
      # arg_scope in the meta architecture.
425
426
427
      'is_training': is_training and batch_norm.train,
  }
  return batch_norm_params
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449


def _build_keras_batch_norm_params(batch_norm):
  """Build a dictionary of Keras BatchNormalization params from config.

  Args:
    batch_norm: hyperparams_pb2.ConvHyperparams.batch_norm proto.

  Returns:
    A dictionary containing Keras BatchNormalization parameters.
  """
  # Note: Although decay is defined to be 1 - momentum in batch_norm,
  # decay in the slim batch_norm layers was erroneously defined and is
  # actually the same as momentum in the Keras batch_norm layers.
  # For context, see: github.com/keras-team/keras/issues/6839
  batch_norm_params = {
      'momentum': batch_norm.decay,
      'center': batch_norm.center,
      'scale': batch_norm.scale,
      'epsilon': batch_norm.epsilon,
  }
  return batch_norm_params