Commit 2b566593 authored by Austin Myers, committed by TF Object Detection Team

Adds FreezableSyncBatchNormalization to the Object Detection API.

PiperOrigin-RevId: 352601466
parent 7fcb79bb
@@ -20,7 +20,11 @@ import tf_slim as slim
from object_detection.core import freezable_batch_norm
from object_detection.protos import hyperparams_pb2
from object_detection.utils import context_manager
from object_detection.utils import tf_version

# pylint: disable=g-import-not-at-top
if tf_version.is_tf2():
  from object_detection.core import freezable_sync_batch_norm
# pylint: enable=g-import-not-at-top
@@ -60,9 +64,14 @@ class KerasLayerHyperparams(object):
                       'hyperparams_pb.Hyperparams.')

    self._batch_norm_params = None
    self._use_sync_batch_norm = False
    if hyperparams_config.HasField('batch_norm'):
      self._batch_norm_params = _build_keras_batch_norm_params(
          hyperparams_config.batch_norm)
    elif hyperparams_config.HasField('sync_batch_norm'):
      self._use_sync_batch_norm = True
      self._batch_norm_params = _build_keras_batch_norm_params(
          hyperparams_config.sync_batch_norm)

    self._force_use_bias = hyperparams_config.force_use_bias
    self._activation_fn = _build_activation_fn(hyperparams_config.activation)
@@ -133,10 +142,12 @@ class KerasLayerHyperparams(object):
      is False)
    """
    if self.use_batch_norm():
      if self._use_sync_batch_norm:
        return freezable_sync_batch_norm.FreezableSyncBatchNorm(
            training=training, **self.batch_norm_params(**overrides))
      else:
        return freezable_batch_norm.FreezableBatchNorm(
            training=training, **self.batch_norm_params(**overrides))
    else:
      return tf.keras.layers.Lambda(tf.identity)
@@ -219,6 +230,10 @@ def build(hyperparams_config, is_training):
    raise ValueError('Hyperparams force_use_bias only supported by '
                     'KerasLayerHyperparams.')

  if hyperparams_config.HasField('sync_batch_norm'):
    raise ValueError('Hyperparams sync_batch_norm only supported by '
                     'KerasLayerHyperparams.')

  normalizer_fn = None
  batch_norm_params = None
  if hyperparams_config.HasField('batch_norm'):
......
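A minimal usage sketch (not part of this commit, and assuming the TF2 Keras path): a Hyperparams config that sets the new sync_batch_norm field can be passed to KerasLayerHyperparams, and its batch norm builder (build_batch_norm in the current API) should then return a FreezableSyncBatchNorm instead of a FreezableBatchNorm.

from google.protobuf import text_format

from object_detection.builders import hyperparams_builder
from object_detection.protos import hyperparams_pb2

config_text = """
  regularizer { l2_regularizer { weight: 0.0004 } }
  initializer { truncated_normal_initializer {} }
  sync_batch_norm {
    decay: 0.997
    epsilon: 0.001
    scale: true
  }
"""
hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Parse(config_text, hyperparams_proto)

keras_hyperparams = hyperparams_builder.KerasLayerHyperparams(hyperparams_proto)
# training=None defers the train/eval decision to the Keras learning phase.
norm_layer = keras_hyperparams.build_batch_norm(training=None)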
@@ -17,25 +17,40 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
from absl.testing import parameterized
import numpy as np
from six.moves import zip
import tensorflow as tf

from object_detection.core import freezable_batch_norm
from object_detection.utils import tf_version

# pylint: disable=g-import-not-at-top
if tf_version.is_tf2():
  from object_detection.core import freezable_sync_batch_norm
# pylint: enable=g-import-not-at-top
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class FreezableBatchNormTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for FreezableBatchNorm operations."""

  def _build_model(self, use_sync_batch_norm, training=None):
    model = tf.keras.models.Sequential()
    norm = None
    if use_sync_batch_norm:
      norm = freezable_sync_batch_norm.FreezableSyncBatchNorm(training=training,
                                                              input_shape=(10,),
                                                              momentum=0.8)
    else:
      norm = freezable_batch_norm.FreezableBatchNorm(training=training,
                                                     input_shape=(10,),
                                                     momentum=0.8)
    model.add(norm)
    return model, norm
@@ -43,8 +58,9 @@ class FreezableBatchNormTest(tf.test.TestCase):
    for source, target in zip(source_weights, target_weights):
      target.assign(source)

  def _train_freezable_batch_norm(self, training_mean, training_var,
                                  use_sync_batch_norm):
    model, _ = self._build_model(use_sync_batch_norm=use_sync_batch_norm)
    model.compile(loss='mse', optimizer='sgd')

    # centered on training_mean, variance training_var
@@ -72,7 +88,8 @@ class FreezableBatchNormTest(tf.test.TestCase):
    np.testing.assert_allclose(out.numpy().mean(), 0.0, atol=1.5e-1)
    np.testing.assert_allclose(out.numpy().std(), 1.0, atol=1.5e-1)
  @parameterized.parameters(True, False)
  def test_batchnorm_freezing_training_none(self, use_sync_batch_norm):
    training_mean = 5.0
    training_var = 10.0
@@ -81,12 +98,13 @@ class FreezableBatchNormTest(tf.test.TestCase):
    # Initially train the batch norm, and save the weights
    trained_weights = self._train_freezable_batch_norm(training_mean,
                                                       training_var,
                                                       use_sync_batch_norm)

    # Load the batch norm weights, freezing training to True.
    # Apply the batch norm layer to testing data and ensure it is normalized
    # according to the batch statistics.
    model, norm = self._build_model(use_sync_batch_norm, training=True)
    self._copy_weights(trained_weights, model.weights)

    # centered on testing_mean, variance testing_var
@@ -136,7 +154,8 @@ class FreezableBatchNormTest(tf.test.TestCase):
        testing_mean, testing_var, training_arg,
        training_mean, training_var)

  @parameterized.parameters(True, False)
  def test_batchnorm_freezing_training_false(self, use_sync_batch_norm):
    training_mean = 5.0
    training_var = 10.0
@@ -145,12 +164,13 @@ class FreezableBatchNormTest(tf.test.TestCase):
    # Initially train the batch norm, and save the weights
    trained_weights = self._train_freezable_batch_norm(training_mean,
                                                       training_var,
                                                       use_sync_batch_norm)

    # Load the batch norm back up, freezing training to False.
    # Apply the batch norm layer to testing data and ensure it is normalized
    # according to the training data's statistics.
    model, norm = self._build_model(use_sync_batch_norm, training=False)
    self._copy_weights(trained_weights, model.weights)

    # centered on testing_mean, variance testing_var
......
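For reference, a toy sketch (hypothetical, not part of the test file) of the parameterized pattern used above: absl's parameterized.parameters runs the decorated test once per argument, so each test now exercises both the FreezableBatchNorm and FreezableSyncBatchNorm code paths.

from absl.testing import parameterized
import tensorflow as tf


class ExampleParameterizedTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(True, False)
  def test_runs_once_per_flag_value(self, use_sync_batch_norm):
    # Hypothetical body; the real tests above build and train the
    # corresponding batch norm layer for each flag value.
    self.assertIsInstance(use_sync_batch_norm, bool)


if __name__ == '__main__':
  tf.test.main()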
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A freezable batch norm layer that uses Keras sync batch normalization."""
import tensorflow as tf
class FreezableSyncBatchNorm(
    tf.keras.layers.experimental.SyncBatchNormalization):
  """Sync Batch normalization layer (Ioffe and Szegedy, 2014).

  This is a `freezable` batch norm layer that supports setting the `training`
  parameter in the __init__ method rather than having to set it either via
  the Keras learning phase or via the `call` method parameter. This layer will
  forward all other parameters to the Keras `SyncBatchNormalization` layer.

  This class is necessary because Object Detection model training sometimes
  requires batch normalization layers to be `frozen` and used as if it were
  evaluation time, despite still training (and potentially using dropout
  layers).

  Like the default Keras SyncBatchNormalization layer, this will normalize the
  activations of the previous layer at each batch, i.e. it applies a
  transformation that maintains the mean activation close to 0 and the
  activation standard deviation close to 1.

  Input shape:
    Arbitrary. Use the keyword argument `input_shape`
    (tuple of integers, does not include the samples axis)
    when using this layer as the first layer in a model.

  Output shape:
    Same shape as input.

  References:
    - [Batch Normalization: Accelerating Deep Network Training by Reducing
      Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
  """

  def __init__(self, training=None, **kwargs):
    """Constructor.

    Args:
      training: If False, the layer will normalize using the moving average and
        std. dev, without updating the learned avg and std. dev.
        If None or True, the layer will follow the keras SyncBatchNormalization
        layer strategy of checking the Keras learning phase at `call` time to
        decide what to do.
      **kwargs: The keyword arguments to forward to the keras
        SyncBatchNormalization layer constructor.
    """
    super(FreezableSyncBatchNorm, self).__init__(**kwargs)
    self._training = training

  def call(self, inputs, training=None):
    # Override the call arg only if the batchnorm is frozen. (Ignore None)
    if self._training is False:  # pylint: disable=g-bool-id-comparison
      training = self._training
    return super(FreezableSyncBatchNorm, self).call(inputs, training=training)
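A short usage sketch (illustrative, not part of this file): constructing the layer with training=False freezes it, so it normalizes with its moving mean and variance and does not update them, even when the surrounding model is called with training=True.

import numpy as np
import tensorflow as tf

from object_detection.core import freezable_sync_batch_norm

frozen_norm = freezable_sync_batch_norm.FreezableSyncBatchNorm(
    training=False, input_shape=(4,), momentum=0.8)
model = tf.keras.models.Sequential([frozen_norm])

inputs = np.random.normal(loc=5.0, scale=3.0, size=(8, 4)).astype(np.float32)
# Even with training=True at call time, the frozen layer keeps using (and does
# not update) its moving statistics, mirroring FreezableBatchNorm's behavior.
outputs = model(inputs, training=True)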
@@ -42,6 +42,8 @@ message Hyperparams {
  // Note that if nothing below is selected, then no normalization is applied
  // BatchNorm hyperparameters.
  BatchNorm batch_norm = 5;

  // SyncBatchNorm hyperparameters (KerasLayerHyperparams only).
  BatchNorm sync_batch_norm = 9;

  // GroupNorm hyperparameters. This is only supported on a subset of models.
  // Note that the current implementation of group norm instantiated in
  // tf.contrib.group.layers.group_norm() only supports fixed_size_resizer
......
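Because the new field reuses the BatchNorm message, an existing batch_norm block can be switched to synchronized batch norm by changing only the field name; a hypothetical illustration of setting the field programmatically:

from object_detection.protos import hyperparams_pb2

hyperparams = hyperparams_pb2.Hyperparams()
# sync_batch_norm is a BatchNorm message, so the same options
# (decay, center, scale, epsilon, train) apply.
hyperparams.sync_batch_norm.decay = 0.997
hyperparams.sync_batch_norm.epsilon = 0.001
hyperparams.sync_batch_norm.scale = True
assert isinstance(hyperparams.sync_batch_norm, hyperparams_pb2.BatchNorm)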