Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
4c83bd7f
Commit
4c83bd7f
authored
Jul 16, 2020
by
Yu-hui Chen
Committed by
TF Object Detection Team
Jul 17, 2020
Browse files
Internal change
PiperOrigin-RevId: 321666729
parent
2e6ccf02
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
194 additions
and
8 deletions
+194
-8
research/object_detection/models/keras_models/resnet_v1.py
research/object_detection/models/keras_models/resnet_v1.py
+151
-7
research/object_detection/models/keras_models/resnet_v1_tf2_test.py
...bject_detection/models/keras_models/resnet_v1_tf2_test.py
+43
-1
No files found.
research/object_detection/models/keras_models/resnet_v1.py
View file @
4c83bd7f
...
@@ -21,6 +21,7 @@ from __future__ import print_function
...
@@ -21,6 +21,7 @@ from __future__ import print_function
import
tensorflow.compat.v1
as
tf
import
tensorflow.compat.v1
as
tf
from
tensorflow.python.keras.applications
import
resnet
from
object_detection.core
import
freezable_batch_norm
from
object_detection.core
import
freezable_batch_norm
from
object_detection.models.keras_models
import
model_utils
from
object_detection.models.keras_models
import
model_utils
...
@@ -95,11 +96,11 @@ class _LayersOverride(object):
...
@@ -95,11 +96,11 @@ class _LayersOverride(object):
self
.
regularizer
=
tf
.
keras
.
regularizers
.
l2
(
weight_decay
)
self
.
regularizer
=
tf
.
keras
.
regularizers
.
l2
(
weight_decay
)
self
.
initializer
=
tf
.
variance_scaling_initializer
()
self
.
initializer
=
tf
.
variance_scaling_initializer
()
def
_FixedPaddingLayer
(
self
,
kernel_size
,
rate
=
1
):
def
_FixedPaddingLayer
(
self
,
kernel_size
,
rate
=
1
):
# pylint: disable=invalid-name
return
tf
.
keras
.
layers
.
Lambda
(
return
tf
.
keras
.
layers
.
Lambda
(
lambda
x
:
_fixed_padding
(
x
,
kernel_size
,
rate
))
lambda
x
:
_fixed_padding
(
x
,
kernel_size
,
rate
))
def
Conv2D
(
self
,
filters
,
kernel_size
,
**
kwargs
):
def
Conv2D
(
self
,
filters
,
kernel_size
,
**
kwargs
):
# pylint: disable=invalid-name
"""Builds a Conv2D layer according to the current Object Detection config.
"""Builds a Conv2D layer according to the current Object Detection config.
Overrides the Keras Resnet application's convolutions with ones that
Overrides the Keras Resnet application's convolutions with ones that
...
@@ -141,7 +142,7 @@ class _LayersOverride(object):
...
@@ -141,7 +142,7 @@ class _LayersOverride(object):
else
:
else
:
return
tf
.
keras
.
layers
.
Conv2D
(
filters
,
kernel_size
,
**
kwargs
)
return
tf
.
keras
.
layers
.
Conv2D
(
filters
,
kernel_size
,
**
kwargs
)
def
Activation
(
self
,
*
args
,
**
kwargs
):
# pylint: disable=unused-argument
def
Activation
(
self
,
*
args
,
**
kwargs
):
# pylint: disable=unused-argument
,invalid-name
"""Builds an activation layer.
"""Builds an activation layer.
Overrides the Keras application Activation layer specified by the
Overrides the Keras application Activation layer specified by the
...
@@ -163,7 +164,7 @@ class _LayersOverride(object):
...
@@ -163,7 +164,7 @@ class _LayersOverride(object):
else
:
else
:
return
tf
.
keras
.
layers
.
Lambda
(
tf
.
nn
.
relu
,
name
=
name
)
return
tf
.
keras
.
layers
.
Lambda
(
tf
.
nn
.
relu
,
name
=
name
)
def
BatchNormalization
(
self
,
**
kwargs
):
def
BatchNormalization
(
self
,
**
kwargs
):
# pylint: disable=invalid-name
"""Builds a normalization layer.
"""Builds a normalization layer.
Overrides the Keras application batch norm with the norm specified by the
Overrides the Keras application batch norm with the norm specified by the
...
@@ -191,7 +192,7 @@ class _LayersOverride(object):
...
@@ -191,7 +192,7 @@ class _LayersOverride(object):
momentum
=
self
.
_default_batchnorm_momentum
,
momentum
=
self
.
_default_batchnorm_momentum
,
**
kwargs
)
**
kwargs
)
def
Input
(
self
,
shape
):
def
Input
(
self
,
shape
):
# pylint: disable=invalid-name
"""Builds an Input layer.
"""Builds an Input layer.
Overrides the Keras application Input layer with one that uses a
Overrides the Keras application Input layer with one that uses a
...
@@ -219,7 +220,7 @@ class _LayersOverride(object):
...
@@ -219,7 +220,7 @@ class _LayersOverride(object):
input
=
input_tensor
,
shape
=
[
None
]
+
shape
)
input
=
input_tensor
,
shape
=
[
None
]
+
shape
)
return
model_utils
.
input_layer
(
shape
,
placeholder_with_default
)
return
model_utils
.
input_layer
(
shape
,
placeholder_with_default
)
def
MaxPooling2D
(
self
,
pool_size
,
**
kwargs
):
def
MaxPooling2D
(
self
,
pool_size
,
**
kwargs
):
# pylint: disable=invalid-name
"""Builds a MaxPooling2D layer with default padding as 'SAME'.
"""Builds a MaxPooling2D layer with default padding as 'SAME'.
This is specified by the default resnet arg_scope in slim.
This is specified by the default resnet arg_scope in slim.
...
@@ -237,7 +238,7 @@ class _LayersOverride(object):
...
@@ -237,7 +238,7 @@ class _LayersOverride(object):
# Add alias as Keras also has it.
# Add alias as Keras also has it.
MaxPool2D
=
MaxPooling2D
# pylint: disable=invalid-name
MaxPool2D
=
MaxPooling2D
# pylint: disable=invalid-name
def
ZeroPadding2D
(
self
,
padding
,
**
kwargs
):
# pylint: disable=unused-argument
def
ZeroPadding2D
(
self
,
padding
,
**
kwargs
):
# pylint: disable=unused-argument
,invalid-name
"""Replaces explicit padding in the Keras application with a no-op.
"""Replaces explicit padding in the Keras application with a no-op.
Args:
Args:
...
@@ -395,3 +396,146 @@ def resnet_v1_152(batchnorm_training,
...
@@ -395,3 +396,146 @@ def resnet_v1_152(batchnorm_training,
return
tf
.
keras
.
applications
.
resnet
.
ResNet152
(
return
tf
.
keras
.
applications
.
resnet
.
ResNet152
(
layers
=
layers_override
,
**
kwargs
)
layers
=
layers_override
,
**
kwargs
)
# pylint: enable=invalid-name
# pylint: enable=invalid-name
# The following codes are based on the existing keras ResNet model pattern:
# google3/third_party/tensorflow/python/keras/applications/resnet.py
def
block_basic
(
x
,
filters
,
kernel_size
=
3
,
stride
=
1
,
conv_shortcut
=
False
,
name
=
None
):
"""A residual block for ResNet18/34.
Arguments:
x: input tensor.
filters: integer, filters of the bottleneck layer.
kernel_size: default 3, kernel size of the bottleneck layer.
stride: default 1, stride of the first layer.
conv_shortcut: default False, use convolution shortcut if True, otherwise
identity shortcut.
name: string, block label.
Returns:
Output tensor for the residual block.
"""
layers
=
tf
.
keras
.
layers
bn_axis
=
3
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
else
1
preact
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
epsilon
=
1.001e-5
,
name
=
name
+
'_preact_bn'
)(
x
)
preact
=
layers
.
Activation
(
'relu'
,
name
=
name
+
'_preact_relu'
)(
preact
)
if
conv_shortcut
:
shortcut
=
layers
.
Conv2D
(
filters
,
1
,
strides
=
1
,
name
=
name
+
'_0_conv'
)(
preact
)
else
:
shortcut
=
layers
.
MaxPooling2D
(
1
,
strides
=
stride
)(
x
)
if
stride
>
1
else
x
x
=
layers
.
ZeroPadding2D
(
padding
=
((
1
,
1
),
(
1
,
1
)),
name
=
name
+
'_1_pad'
)(
preact
)
x
=
layers
.
Conv2D
(
filters
,
kernel_size
,
strides
=
1
,
use_bias
=
False
,
name
=
name
+
'_1_conv'
)(
x
)
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
epsilon
=
1.001e-5
,
name
=
name
+
'_1_bn'
)(
x
)
x
=
layers
.
Activation
(
'relu'
,
name
=
name
+
'_1_relu'
)(
x
)
x
=
layers
.
ZeroPadding2D
(
padding
=
((
1
,
1
),
(
1
,
1
)),
name
=
name
+
'_2_pad'
)(
x
)
x
=
layers
.
Conv2D
(
filters
,
kernel_size
,
strides
=
stride
,
use_bias
=
False
,
name
=
name
+
'_2_conv'
)(
x
)
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
epsilon
=
1.001e-5
,
name
=
name
+
'_2_bn'
)(
x
)
x
=
layers
.
Activation
(
'relu'
,
name
=
name
+
'_2_relu'
)(
x
)
x
=
layers
.
Add
(
name
=
name
+
'_out'
)([
shortcut
,
x
])
return
x
def
stack_basic
(
x
,
filters
,
blocks
,
stride1
=
2
,
name
=
None
):
"""A set of stacked residual blocks for ResNet18/34.
Arguments:
x: input tensor.
filters: integer, filters of the bottleneck layer in a block.
blocks: integer, blocks in the stacked blocks.
stride1: default 2, stride of the first layer in the first block.
name: string, stack label.
Returns:
Output tensor for the stacked blocks.
"""
x
=
block_basic
(
x
,
filters
,
conv_shortcut
=
True
,
name
=
name
+
'_block1'
)
for
i
in
range
(
2
,
blocks
):
x
=
block_basic
(
x
,
filters
,
name
=
name
+
'_block'
+
str
(
i
))
x
=
block_basic
(
x
,
filters
,
stride
=
stride1
,
name
=
name
+
'_block'
+
str
(
blocks
))
return
x
def
resnet_v1_18
(
include_top
=
True
,
weights
=
'imagenet'
,
input_tensor
=
None
,
input_shape
=
None
,
pooling
=
None
,
classes
=
1000
,
classifier_activation
=
'softmax'
):
"""Instantiates the ResNet18 architecture."""
def
stack_fn
(
x
):
x
=
stack_basic
(
x
,
64
,
2
,
stride1
=
1
,
name
=
'conv2'
)
x
=
stack_basic
(
x
,
128
,
2
,
name
=
'conv3'
)
x
=
stack_basic
(
x
,
256
,
2
,
name
=
'conv4'
)
return
stack_basic
(
x
,
512
,
2
,
name
=
'conv5'
)
return
resnet
.
ResNet
(
stack_fn
,
True
,
True
,
'resnet18'
,
include_top
,
weights
,
input_tensor
,
input_shape
,
pooling
,
classes
,
classifier_activation
=
classifier_activation
)
def
resnet_v1_34
(
include_top
=
True
,
weights
=
'imagenet'
,
input_tensor
=
None
,
input_shape
=
None
,
pooling
=
None
,
classes
=
1000
,
classifier_activation
=
'softmax'
):
"""Instantiates the ResNet34 architecture."""
def
stack_fn
(
x
):
x
=
stack_basic
(
x
,
64
,
3
,
stride1
=
1
,
name
=
'conv2'
)
x
=
stack_basic
(
x
,
128
,
4
,
name
=
'conv3'
)
x
=
stack_basic
(
x
,
256
,
6
,
name
=
'conv4'
)
return
stack_basic
(
x
,
512
,
3
,
name
=
'conv5'
)
return
resnet
.
ResNet
(
stack_fn
,
True
,
True
,
'resnet34'
,
include_top
,
weights
,
input_tensor
,
input_shape
,
pooling
,
classes
,
classifier_activation
=
classifier_activation
)
research/object_detection/models/keras_models/resnet_v1_tf2_test.py
View file @
4c83bd7f
...
@@ -20,12 +20,13 @@ object detection. To verify the consistency of the two models, we compare:
...
@@ -20,12 +20,13 @@ object detection. To verify the consistency of the two models, we compare:
2. Number of global variables.
2. Number of global variables.
"""
"""
import
unittest
import
unittest
from
absl.testing
import
parameterized
import
numpy
as
np
import
numpy
as
np
from
six.moves
import
zip
from
six.moves
import
zip
import
tensorflow.compat.v1
as
tf
import
tensorflow.compat.v1
as
tf
from
google.protobuf
import
text_format
from
google.protobuf
import
text_format
from
object_detection.builders
import
hyperparams_builder
from
object_detection.builders
import
hyperparams_builder
from
object_detection.models.keras_models
import
resnet_v1
from
object_detection.models.keras_models
import
resnet_v1
from
object_detection.protos
import
hyperparams_pb2
from
object_detection.protos
import
hyperparams_pb2
...
@@ -180,5 +181,46 @@ class ResnetV1Test(test_case.TestCase):
...
@@ -180,5 +181,46 @@ class ResnetV1Test(test_case.TestCase):
self
.
assertEqual
(
len
(
variables
),
var_num
)
self
.
assertEqual
(
len
(
variables
),
var_num
)
class
ResnetShapeTest
(
test_case
.
TestCase
,
parameterized
.
TestCase
):
@
unittest
.
skipIf
(
tf_version
.
is_tf1
(),
'Skipping TF2.X only test.'
)
@
parameterized
.
parameters
(
{
'resnet_type'
:
'resnet_v1_34'
,
'output_layer_names'
:
[
'conv2_block3_out'
,
'conv3_block4_out'
,
'conv4_block6_out'
,
'conv5_block3_out'
]
},
{
'resnet_type'
:
'resnet_v1_18'
,
'output_layer_names'
:
[
'conv2_block2_out'
,
'conv3_block2_out'
,
'conv4_block2_out'
,
'conv5_block2_out'
]
})
def
test_output_shapes
(
self
,
resnet_type
,
output_layer_names
):
if
resnet_type
==
'resnet_v1_34'
:
model
=
resnet_v1
.
resnet_v1_34
(
weights
=
None
)
else
:
model
=
resnet_v1
.
resnet_v1_18
(
weights
=
None
)
outputs
=
[
model
.
get_layer
(
output_layer_name
).
output
for
output_layer_name
in
output_layer_names
]
resnet_model
=
tf
.
keras
.
models
.
Model
(
inputs
=
model
.
input
,
outputs
=
outputs
)
outputs
=
resnet_model
(
np
.
zeros
((
2
,
64
,
64
,
3
),
dtype
=
np
.
float32
))
# Check the shape of 'conv2_block3_out':
self
.
assertEqual
(
outputs
[
0
].
shape
,
[
2
,
16
,
16
,
64
])
# Check the shape of 'conv3_block4_out':
self
.
assertEqual
(
outputs
[
1
].
shape
,
[
2
,
8
,
8
,
128
])
# Check the shape of 'conv4_block6_out':
self
.
assertEqual
(
outputs
[
2
].
shape
,
[
2
,
4
,
4
,
256
])
# Check the shape of 'conv5_block3_out':
self
.
assertEqual
(
outputs
[
3
].
shape
,
[
2
,
2
,
2
,
512
])
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment