ModelZoo / ResNet50_tensorflow · Commit 2b566593

Authored Jan 19, 2021 by Austin Myers; committed by TF Object Detection Team, Jan 19, 2021.
Adds FreezableSyncBatchNormalization to the Object Detection API.
PiperOrigin-RevId: 352601466

Parent commit: 7fcb79bb

Showing 4 changed files with 125 additions and 18 deletions (+125 / -18):

  research/object_detection/builders/hyperparams_builder.py        +19   -4
  research/object_detection/core/freezable_batch_norm_tf2_test.py  +34  -14
  research/object_detection/core/freezable_sync_batch_norm.py      +70   -0
  research/object_detection/protos/hyperparams.proto                +2   -0
research/object_detection/builders/hyperparams_builder.py  (+19, -4)

@@ -20,7 +20,11 @@ import tf_slim as slim
 from object_detection.core import freezable_batch_norm
 from object_detection.protos import hyperparams_pb2
 from object_detection.utils import context_manager
+from object_detection.utils import tf_version
+# pylint: disable=g-import-not-at-top
+if tf_version.is_tf2():
+  from object_detection.core import freezable_sync_batch_norm
 # pylint: enable=g-import-not-at-top

@@ -60,9 +64,14 @@ class KerasLayerHyperparams(object):
                        'hyperparams_pb.Hyperparams.')
     self._batch_norm_params = None
+    self._use_sync_batch_norm = False
     if hyperparams_config.HasField('batch_norm'):
       self._batch_norm_params = _build_keras_batch_norm_params(
           hyperparams_config.batch_norm)
+    elif hyperparams_config.HasField('sync_batch_norm'):
+      self._use_sync_batch_norm = True
+      self._batch_norm_params = _build_keras_batch_norm_params(
+          hyperparams_config.sync_batch_norm)
     self._force_use_bias = hyperparams_config.force_use_bias
     self._activation_fn = _build_activation_fn(hyperparams_config.activation)

@@ -133,10 +142,12 @@ class KerasLayerHyperparams(object):
         is False)
     """
     if self.use_batch_norm():
-      return freezable_batch_norm.FreezableBatchNorm(
-          training=training,
-          **self.batch_norm_params(**overrides)
-      )
+      if self._use_sync_batch_norm:
+        return freezable_sync_batch_norm.FreezableSyncBatchNorm(
+            training=training, **self.batch_norm_params(**overrides))
+      else:
+        return freezable_batch_norm.FreezableBatchNorm(
+            training=training, **self.batch_norm_params(**overrides))
     else:
       return tf.keras.layers.Lambda(tf.identity)

@@ -219,6 +230,10 @@ def build(hyperparams_config, is_training):
     raise ValueError('Hyperparams force_use_bias only supported by '
                      'KerasLayerHyperparams.')
+  if hyperparams_config.HasField('sync_batch_norm'):
+    raise ValueError('Hyperparams sync_batch_norm only supported by '
+                     'KerasLayerHyperparams.')
+
   normalizer_fn = None
   batch_norm_params = None
   if hyperparams_config.HasField('batch_norm'):
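
For orientation, a minimal usage sketch (not part of this commit) of how a sync_batch_norm field flows through KerasLayerHyperparams: the proto field and builder API are taken from the diff above, while the regularizer, initializer, and batch-norm values are illustrative assumptions. It presumes a TF 2.x runtime, since the sync variant is only imported when tf_version.is_tf2() is true.

    # Hypothetical sketch: build a normalization layer from a Hyperparams proto
    # that sets the new sync_batch_norm field. Field names come from this
    # commit; the hyperparameter values below are illustrative only.
    from google.protobuf import text_format

    from object_detection.builders import hyperparams_builder
    from object_detection.protos import hyperparams_pb2

    hyperparams_text = """
      regularizer { l2_regularizer { weight: 0.0004 } }
      initializer { truncated_normal_initializer { stddev: 0.03 } }
      sync_batch_norm { decay: 0.997 scale: true epsilon: 0.001 }
    """
    hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(hyperparams_text, hyperparams_proto)

    keras_hyperparams = hyperparams_builder.KerasLayerHyperparams(
        hyperparams_proto)
    # With sync_batch_norm set, build_batch_norm() should return a
    # FreezableSyncBatchNorm layer rather than a FreezableBatchNorm layer.
    norm_layer = keras_hyperparams.build_batch_norm(training=True)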
research/object_detection/core/freezable_batch_norm_tf2_test.py  (+34, -14)

@@ -17,25 +17,40 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 import unittest
+from absl.testing import parameterized
 import numpy as np
 from six.moves import zip
-import tensorflow.compat.v1 as tf
+import tensorflow as tf

 from object_detection.core import freezable_batch_norm
 from object_detection.utils import tf_version
+
+# pylint: disable=g-import-not-at-top
+if tf_version.is_tf2():
+  from object_detection.core import freezable_sync_batch_norm
+# pylint: enable=g-import-not-at-top


 @unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
-class FreezableBatchNormTest(tf.test.TestCase):
+class FreezableBatchNormTest(tf.test.TestCase, parameterized.TestCase):
   """Tests for FreezableBatchNorm operations."""

-  def _build_model(self, training=None):
+  def _build_model(self, use_sync_batch_norm, training=None):
     model = tf.keras.models.Sequential()
-    norm = freezable_batch_norm.FreezableBatchNorm(training=training,
-                                                   input_shape=(10,),
-                                                   momentum=0.8)
+    norm = None
+    if use_sync_batch_norm:
+      norm = freezable_sync_batch_norm.FreezableSyncBatchNorm(
+          training=training, input_shape=(10,), momentum=0.8)
+    else:
+      norm = freezable_batch_norm.FreezableBatchNorm(training=training,
+                                                     input_shape=(10,),
+                                                     momentum=0.8)
     model.add(norm)
     return model, norm

@@ -43,8 +58,9 @@ class FreezableBatchNormTest(tf.test.TestCase):
     for source, target in zip(source_weights, target_weights):
       target.assign(source)

-  def _train_freezable_batch_norm(self, training_mean, training_var):
-    model, _ = self._build_model()
+  def _train_freezable_batch_norm(self, training_mean, training_var,
+                                  use_sync_batch_norm):
+    model, _ = self._build_model(use_sync_batch_norm=use_sync_batch_norm)
     model.compile(loss='mse', optimizer='sgd')

     # centered on training_mean, variance training_var

@@ -72,7 +88,8 @@ class FreezableBatchNormTest(tf.test.TestCase):
     np.testing.assert_allclose(out.numpy().mean(), 0.0, atol=1.5e-1)
     np.testing.assert_allclose(out.numpy().std(), 1.0, atol=1.5e-1)

-  def test_batchnorm_freezing_training_none(self):
+  @parameterized.parameters(True, False)
+  def test_batchnorm_freezing_training_none(self, use_sync_batch_norm):
     training_mean = 5.0
     training_var = 10.0

@@ -81,12 +98,13 @@ class FreezableBatchNormTest(tf.test.TestCase):
     # Initially train the batch norm, and save the weights
     trained_weights = self._train_freezable_batch_norm(training_mean,
-                                                       training_var)
+                                                       training_var,
+                                                       use_sync_batch_norm)

     # Load the batch norm weights, freezing training to True.
     # Apply the batch norm layer to testing data and ensure it is normalized
     # according to the batch statistics.
-    model, norm = self._build_model(training=True)
+    model, norm = self._build_model(use_sync_batch_norm, training=True)
     self._copy_weights(trained_weights, model.weights)

     # centered on testing_mean, variance testing_var

@@ -136,7 +154,8 @@ class FreezableBatchNormTest(tf.test.TestCase):
                                       testing_mean, testing_var, training_arg,
                                       training_mean, training_var)

-  def test_batchnorm_freezing_training_false(self):
+  @parameterized.parameters(True, False)
+  def test_batchnorm_freezing_training_false(self, use_sync_batch_norm):
     training_mean = 5.0
     training_var = 10.0

@@ -145,12 +164,13 @@ class FreezableBatchNormTest(tf.test.TestCase):
     # Initially train the batch norm, and save the weights
     trained_weights = self._train_freezable_batch_norm(training_mean,
-                                                       training_var)
+                                                       training_var,
+                                                       use_sync_batch_norm)

     # Load the batch norm back up, freezing training to False.
     # Apply the batch norm layer to testing data and ensure it is normalized
     # according to the training data's statistics.
-    model, norm = self._build_model(training=False)
+    model, norm = self._build_model(use_sync_batch_norm, training=False)
     self._copy_weights(trained_weights, model.weights)

     # centered on testing_mean, variance testing_var
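
The test class now mixes in parameterized.TestCase so every case runs once with use_sync_batch_norm=True and once with False. For readers unfamiliar with absl's parameterized tests, a self-contained sketch of that mechanism follows; the class and test names are hypothetical and not part of the commit.

    # Minimal sketch of absl parameterized tests, mirroring the decorator usage
    # in the diff above. Class and test names are hypothetical.
    import unittest

    from absl.testing import parameterized


    class NormFlagTest(parameterized.TestCase):

      @parameterized.parameters(True, False)
      def test_runs_once_per_flag(self, use_sync_batch_norm):
        # The method body executes twice, receiving True and then False.
        self.assertIn(use_sync_batch_norm, (True, False))


    if __name__ == '__main__':
      unittest.main()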
research/object_detection/core/freezable_sync_batch_norm.py  (new file, +70 lines)

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A freezable batch norm layer that uses Keras sync batch normalization."""
import tensorflow as tf


class FreezableSyncBatchNorm(
    tf.keras.layers.experimental.SyncBatchNormalization):
  """Sync Batch normalization layer (Ioffe and Szegedy, 2014).

  This is a `freezable` batch norm layer that supports setting the `training`
  parameter in the __init__ method rather than having to set it either via
  the Keras learning phase or via the `call` method parameter. This layer will
  forward all other parameters to the Keras `SyncBatchNormalization` layer.

  This class is necessary because Object Detection model training sometimes
  requires batch normalization layers to be `frozen` and used as if it were
  evaluation time, despite still training (and potentially using dropout
  layers).

  Like the default Keras SyncBatchNormalization layer, this will normalize the
  activations of the previous layer at each batch, i.e. it applies a
  transformation that maintains the mean activation close to 0 and the
  activation standard deviation close to 1.

  Input shape:
    Arbitrary. Use the keyword argument `input_shape`
    (tuple of integers, does not include the samples axis)
    when using this layer as the first layer in a model.

  Output shape:
    Same shape as input.

  References:
    - [Batch Normalization: Accelerating Deep Network Training by Reducing
      Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
  """

  def __init__(self, training=None, **kwargs):
    """Constructor.

    Args:
      training: If False, the layer will normalize using the moving average and
        std. dev, without updating the learned avg and std. dev.
        If None or True, the layer will follow the keras SyncBatchNormalization
        layer strategy of checking the Keras learning phase at `call` time to
        decide what to do.
      **kwargs: The keyword arguments to forward to the keras
        SyncBatchNormalization layer constructor.
    """
    super(FreezableSyncBatchNorm, self).__init__(**kwargs)
    self._training = training

  def call(self, inputs, training=None):
    # Override the call arg only if the batchnorm is frozen. (Ignore None.)
    if self._training is False:  # pylint: disable=g-bool-id-comparison
      training = self._training
    return super(FreezableSyncBatchNorm, self).call(inputs, training=training)
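
A hypothetical usage sketch (not part of the commit) of the freezing behaviour the new layer provides; it assumes a TF 2.x runtime, and the shapes, momentum, and data below are illustrative only.

    # Hypothetical sketch: a FreezableSyncBatchNorm constructed with
    # training=False keeps using its moving mean/variance even when the
    # surrounding model is called in training mode.
    import numpy as np
    import tensorflow as tf

    from object_detection.core import freezable_sync_batch_norm

    frozen_norm = freezable_sync_batch_norm.FreezableSyncBatchNorm(
        training=False,   # frozen: never switch to per-batch statistics
        input_shape=(10,),
        momentum=0.8)
    model = tf.keras.models.Sequential([frozen_norm, tf.keras.layers.Dense(1)])

    x = np.random.normal(loc=5.0, scale=3.0, size=(32, 10)).astype(np.float32)
    # Despite training=True at call time, the frozen layer normalizes with its
    # moving averages and does not update them.
    _ = model(x, training=True)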
research/object_detection/protos/hyperparams.proto  (+2, -0)

@@ -42,6 +42,8 @@ message Hyperparams {
   // Note that if nothing below is selected, then no normalization is applied
   // BatchNorm hyperparameters.
   BatchNorm batch_norm = 5;
+  // SyncBatchNorm hyperparameters (KerasLayerHyperparams only).
+  BatchNorm sync_batch_norm = 9;

   // GroupNorm hyperparameters. This is only supported on a subset of models.
   // Note that the current implementation of group norm instantiated in
   // tf.contrib.group.layers.group_norm() only supports fixed_size_resizer
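
With the new field in place, a Keras-based config could presumably request synchronized batch normalization by swapping `batch_norm` for `sync_batch_norm` in its hyperparams block. A hypothetical textproto fragment follows; the values are illustrative and not taken from this commit.

    # Hypothetical Hyperparams fragment using the new field.
    sync_batch_norm {
      decay: 0.997
      scale: true
      epsilon: 0.001
    }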