ModelZoo / ResNet50_tensorflow · commit afd5579f

Merge remote-tracking branch 'upstream/master' into context_tf2

Authored Jul 22, 2020 by Kaushik Shivakumar
Parents: dcd96e02, 567bd18d

Changes: 89 files in the merge; showing 9 changed files with 549 additions and 9 deletions (+549 / -9).
Changed files shown on this page:

research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py            +117  -0
research/object_detection/models/center_net_mobilenet_v2_feature_extractor_tf2_test.py    +46  -0
research/object_detection/models/center_net_resnet_v1_fpn_feature_extractor.py            +30  -0
research/object_detection/models/center_net_resnet_v1_fpn_feature_extractor_tf2_test.py    +2  -0
research/object_detection/models/keras_models/resnet_v1.py                               +151  -7
research/object_detection/models/keras_models/resnet_v1_tf2_test.py                       +43  -1
research/object_detection/protos/input_reader.proto                                        +5  -1
research/object_detection/utils/spatial_transform_ops.py                                  +94  -0
research/object_detection/utils/spatial_transform_ops_test.py                             +61  -0
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py (new file, mode 100644)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MobileNet V2 [1] feature extractor for CenterNet [2] meta architecture.

[1]: https://arxiv.org/abs/1801.04381
[2]: https://arxiv.org/abs/1904.07850
"""

import tensorflow.compat.v1 as tf

from object_detection.meta_architectures import center_net_meta_arch
from object_detection.models.keras_models import mobilenet_v2 as mobilenetv2


class CenterNetMobileNetV2FeatureExtractor(
    center_net_meta_arch.CenterNetFeatureExtractor):
  """The MobileNet V2 feature extractor for CenterNet."""

  def __init__(self,
               mobilenet_v2_net,
               channel_means=(0., 0., 0.),
               channel_stds=(1., 1., 1.),
               bgr_ordering=False):
"""Intializes the feature extractor.
Args:
mobilenet_v2_net: The underlying mobilenet_v2 network to use.
channel_means: A tuple of floats, denoting the mean of each channel
which will be subtracted from it.
channel_stds: A tuple of floats, denoting the standard deviation of each
channel. Each channel will be divided by its standard deviation value.
bgr_ordering: bool, if set will change the channel ordering to be in the
[blue, red, green] order.
"""
    super(CenterNetMobileNetV2FeatureExtractor, self).__init__(
        channel_means=channel_means,
        channel_stds=channel_stds,
        bgr_ordering=bgr_ordering)

    self._network = mobilenet_v2_net

    output = self._network(self._network.input)

    # TODO(nkhadke): Try out MobileNet+FPN next (skip connections are cheap and
    # should help with performance).
    # MobileNet by itself transforms a 224x224x3 volume into a 7x7x1280, which
    # leads to a stride of 32. We perform upsampling to get it to a target
    # stride of 4.
    for num_filters in [256, 128, 64]:
      # 1. We use a simple convolution instead of a deformable convolution
      conv = tf.keras.layers.Conv2D(
          filters=num_filters, kernel_size=1, strides=1, padding='same')
      output = conv(output)
      output = tf.keras.layers.BatchNormalization()(output)
      output = tf.keras.layers.ReLU()(output)

      # 2. We use the default initialization for the convolution layers
      # instead of initializing it to do bilinear upsampling.
      conv_transpose = tf.keras.layers.Conv2DTranspose(
          filters=num_filters, kernel_size=3, strides=2, padding='same')
      output = conv_transpose(output)
      output = tf.keras.layers.BatchNormalization()(output)
      output = tf.keras.layers.ReLU()(output)

    self._network = tf.keras.models.Model(
        inputs=self._network.input, outputs=output)

  def preprocess(self, resized_inputs):
    resized_inputs = super(CenterNetMobileNetV2FeatureExtractor,
                           self).preprocess(resized_inputs)
    return tf.keras.applications.mobilenet_v2.preprocess_input(resized_inputs)

  def load_feature_extractor_weights(self, path):
    self._network.load_weights(path)

  def get_base_model(self):
    return self._network

  def call(self, inputs):
    return [self._network(inputs)]

  @property
  def out_stride(self):
    """The stride in the output image of the network."""
    return 4

  @property
  def num_feature_outputs(self):
    """The number of feature outputs returned by the feature extractor."""
    return 1

  def get_model(self):
    return self._network


def mobilenet_v2(channel_means, channel_stds, bgr_ordering):
  """The MobileNetV2 backbone for CenterNet."""

  # We set 'is_training' to True for now.
  network = mobilenetv2.mobilenet_v2(True, include_top=False)
  return CenterNetMobileNetV2FeatureExtractor(
      network,
      channel_means=channel_means,
      channel_stds=channel_stds,
      bgr_ordering=bgr_ordering)
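For reference, a minimal usage sketch of the new extractor (not part of the commit). It assumes the object_detection package is importable and a TF2 eager environment; the expected output shape comes from the test file below.

import numpy as np

from object_detection.models import center_net_mobilenet_v2_feature_extractor

# Build the backbone through the builder added in this file; the builder sets
# batchnorm training mode to True and drops the classification head.
extractor = center_net_mobilenet_v2_feature_extractor.mobilenet_v2(
    channel_means=(0., 0., 0.),
    channel_stds=(1., 1., 1.),
    bgr_ordering=False)

# A 224x224 input is reduced by the MobileNet stride of 32 and then upsampled
# back to the CenterNet output stride of 4, i.e. 224 / 4 = 56.
images = np.zeros((8, 224, 224, 3), dtype=np.float32)
features = extractor(extractor.preprocess(images))
print(features[0].shape)  # expected: (8, 56, 56, 64)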
research/object_detection/models/center_net_mobilenet_v2_feature_extractor_tf2_test.py (new file, mode 100644)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Testing mobilenet_v2 feature extractor for CenterNet."""

import unittest
import numpy as np
import tensorflow.compat.v1 as tf

from object_detection.models import center_net_mobilenet_v2_feature_extractor
from object_detection.models.keras_models import mobilenet_v2
from object_detection.utils import test_case
from object_detection.utils import tf_version


@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class CenterNetMobileNetV2FeatureExtractorTest(test_case.TestCase):

  def test_center_net_mobilenet_v2_feature_extractor(self):
    net = mobilenet_v2.mobilenet_v2(True, include_top=False)
    model = center_net_mobilenet_v2_feature_extractor.CenterNetMobileNetV2FeatureExtractor(
        net)

    def graph_fn():
      img = np.zeros((8, 224, 224, 3), dtype=np.float32)
      processed_img = model.preprocess(img)
      return model(processed_img)

    outputs = self.execute(graph_fn, [])
    self.assertEqual(outputs.shape, (8, 56, 56, 64))


if __name__ == '__main__':
  tf.test.main()
research/object_detection/models/center_net_resnet_v1_fpn_feature_extractor.py
@@ -21,9 +21,14 @@
 import tensorflow.compat.v1 as tf
 
 from object_detection.meta_architectures.center_net_meta_arch import CenterNetFeatureExtractor
+from object_detection.models.keras_models import resnet_v1
 
 
 _RESNET_MODEL_OUTPUT_LAYERS = {
+    'resnet_v1_18': ['conv2_block2_out', 'conv3_block2_out',
+                     'conv4_block2_out', 'conv5_block2_out'],
+    'resnet_v1_34': ['conv2_block3_out', 'conv3_block4_out',
+                     'conv4_block6_out', 'conv5_block3_out'],
     'resnet_v1_50': ['conv2_block3_out', 'conv3_block4_out',
                      'conv4_block6_out', 'conv5_block3_out'],
     'resnet_v1_101': ['conv2_block3_out', 'conv3_block4_out',

@@ -69,6 +74,10 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
       self._base_model = tf.keras.applications.ResNet50(weights=None)
     elif resnet_type == 'resnet_v1_101':
       self._base_model = tf.keras.applications.ResNet101(weights=None)
+    elif resnet_type == 'resnet_v1_18':
+      self._base_model = resnet_v1.resnet_v1_18(weights=None)
+    elif resnet_type == 'resnet_v1_34':
+      self._base_model = resnet_v1.resnet_v1_34(weights=None)
     else:
       raise ValueError('Unknown Resnet Model {}'.format(resnet_type))
     output_layers = _RESNET_MODEL_OUTPUT_LAYERS[resnet_type]

@@ -174,3 +183,24 @@ def resnet_v1_50_fpn(channel_means, channel_stds, bgr_ordering):
       channel_means=channel_means,
       channel_stds=channel_stds,
       bgr_ordering=bgr_ordering)
+
+
+def resnet_v1_34_fpn(channel_means, channel_stds, bgr_ordering):
+  """The ResNet v1 34 FPN feature extractor."""
+
+  return CenterNetResnetV1FpnFeatureExtractor(
+      resnet_type='resnet_v1_34',
+      channel_means=channel_means,
+      channel_stds=channel_stds,
+      bgr_ordering=bgr_ordering)
+
+
+def resnet_v1_18_fpn(channel_means, channel_stds, bgr_ordering):
+  """The ResNet v1 18 FPN feature extractor."""
+
+  return CenterNetResnetV1FpnFeatureExtractor(
+      resnet_type='resnet_v1_18',
+      channel_means=channel_means,
+      channel_stds=channel_stds,
+      bgr_ordering=bgr_ordering)
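For reference, a minimal sketch of how the two new FPN builders could be called (not part of the commit; the module alias and argument values are illustrative).

from object_detection.models import center_net_resnet_v1_fpn_feature_extractor as fpn

# The new builders mirror the existing resnet_v1_50_fpn entry point and select
# the lighter backbones through the resnet_type lookup added above.
extractor_18 = fpn.resnet_v1_18_fpn(
    channel_means=(0., 0., 0.), channel_stds=(1., 1., 1.), bgr_ordering=False)
extractor_34 = fpn.resnet_v1_34_fpn(
    channel_means=(0., 0., 0.), channel_stds=(1., 1., 1.), bgr_ordering=False)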
research/object_detection/models/center_net_resnet_v1_fpn_feature_extractor_tf2_test.py
@@ -31,6 +31,8 @@ class CenterNetResnetV1FpnFeatureExtractorTest(test_case.TestCase,
   @parameterized.parameters(
       {'resnet_type': 'resnet_v1_50'},
       {'resnet_type': 'resnet_v1_101'},
+      {'resnet_type': 'resnet_v1_18'},
+      {'resnet_type': 'resnet_v1_34'},
   )
   def test_correct_output_size(self, resnet_type):
     """Verify that shape of features returned by the backbone is correct."""
research/object_detection/models/keras_models/resnet_v1.py
@@ -21,6 +21,7 @@ from __future__ import print_function
 import tensorflow.compat.v1 as tf
 
+from tensorflow.python.keras.applications import resnet
 from object_detection.core import freezable_batch_norm
 from object_detection.models.keras_models import model_utils

@@ -95,11 +96,11 @@ class _LayersOverride(object):
     self.regularizer = tf.keras.regularizers.l2(weight_decay)
     self.initializer = tf.variance_scaling_initializer()
 
-  def _FixedPaddingLayer(self, kernel_size, rate=1):
+  def _FixedPaddingLayer(self, kernel_size, rate=1):  # pylint: disable=invalid-name
     return tf.keras.layers.Lambda(
         lambda x: _fixed_padding(x, kernel_size, rate))
 
-  def Conv2D(self, filters, kernel_size, **kwargs):
+  def Conv2D(self, filters, kernel_size, **kwargs):  # pylint: disable=invalid-name
     """Builds a Conv2D layer according to the current Object Detection config.
 
     Overrides the Keras Resnet application's convolutions with ones that

@@ -141,7 +142,7 @@ class _LayersOverride(object):
     else:
       return tf.keras.layers.Conv2D(filters, kernel_size, **kwargs)
 
-  def Activation(self, *args, **kwargs):  # pylint: disable=unused-argument
+  def Activation(self, *args, **kwargs):  # pylint: disable=unused-argument,invalid-name
     """Builds an activation layer.
 
     Overrides the Keras application Activation layer specified by the

@@ -163,7 +164,7 @@ class _LayersOverride(object):
     else:
       return tf.keras.layers.Lambda(tf.nn.relu, name=name)
 
-  def BatchNormalization(self, **kwargs):
+  def BatchNormalization(self, **kwargs):  # pylint: disable=invalid-name
     """Builds a normalization layer.
 
     Overrides the Keras application batch norm with the norm specified by the

@@ -191,7 +192,7 @@ class _LayersOverride(object):
         momentum=self._default_batchnorm_momentum,
         **kwargs)
 
-  def Input(self, shape):
+  def Input(self, shape):  # pylint: disable=invalid-name
     """Builds an Input layer.
 
     Overrides the Keras application Input layer with one that uses a

@@ -219,7 +220,7 @@ class _LayersOverride(object):
           input=input_tensor, shape=[None] + shape)
       return model_utils.input_layer(shape, placeholder_with_default)
 
-  def MaxPooling2D(self, pool_size, **kwargs):
+  def MaxPooling2D(self, pool_size, **kwargs):  # pylint: disable=invalid-name
     """Builds a MaxPooling2D layer with default padding as 'SAME'.
 
     This is specified by the default resnet arg_scope in slim.

@@ -237,7 +238,7 @@ class _LayersOverride(object):
   # Add alias as Keras also has it.
   MaxPool2D = MaxPooling2D  # pylint: disable=invalid-name
 
-  def ZeroPadding2D(self, padding, **kwargs):  # pylint: disable=unused-argument
+  def ZeroPadding2D(self, padding, **kwargs):  # pylint: disable=unused-argument,invalid-name
     """Replaces explicit padding in the Keras application with a no-op.
 
     Args:

@@ -395,3 +396,146 @@ def resnet_v1_152(batchnorm_training,
   return tf.keras.applications.resnet.ResNet152(
       layers=layers_override, **kwargs)
 # pylint: enable=invalid-name
+
+
+# The following codes are based on the existing keras ResNet model pattern:
+# google3/third_party/tensorflow/python/keras/applications/resnet.py
+def block_basic(x,
+                filters,
+                kernel_size=3,
+                stride=1,
+                conv_shortcut=False,
+                name=None):
+  """A residual block for ResNet18/34.
+
+  Arguments:
+    x: input tensor.
+    filters: integer, filters of the bottleneck layer.
+    kernel_size: default 3, kernel size of the bottleneck layer.
+    stride: default 1, stride of the first layer.
+    conv_shortcut: default False, use convolution shortcut if True, otherwise
+      identity shortcut.
+    name: string, block label.
+
+  Returns:
+    Output tensor for the residual block.
+  """
+  layers = tf.keras.layers
+  bn_axis = 3 if tf.keras.backend.image_data_format() == 'channels_last' else 1
+
+  preact = layers.BatchNormalization(
+      axis=bn_axis, epsilon=1.001e-5, name=name + '_preact_bn')(x)
+  preact = layers.Activation('relu', name=name + '_preact_relu')(preact)
+
+  if conv_shortcut:
+    shortcut = layers.Conv2D(
+        filters, 1, strides=1, name=name + '_0_conv')(preact)
+  else:
+    shortcut = layers.MaxPooling2D(1, strides=stride)(x) if stride > 1 else x
+
+  x = layers.ZeroPadding2D(
+      padding=((1, 1), (1, 1)), name=name + '_1_pad')(preact)
+  x = layers.Conv2D(
+      filters, kernel_size, strides=1, use_bias=False,
+      name=name + '_1_conv')(x)
+  x = layers.BatchNormalization(
+      axis=bn_axis, epsilon=1.001e-5, name=name + '_1_bn')(x)
+  x = layers.Activation('relu', name=name + '_1_relu')(x)
+
+  x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name=name + '_2_pad')(x)
+  x = layers.Conv2D(
+      filters, kernel_size, strides=stride, use_bias=False,
+      name=name + '_2_conv')(x)
+  x = layers.BatchNormalization(
+      axis=bn_axis, epsilon=1.001e-5, name=name + '_2_bn')(x)
+  x = layers.Activation('relu', name=name + '_2_relu')(x)
+  x = layers.Add(name=name + '_out')([shortcut, x])
+  return x
+
+
+def stack_basic(x, filters, blocks, stride1=2, name=None):
+  """A set of stacked residual blocks for ResNet18/34.
+
+  Arguments:
+    x: input tensor.
+    filters: integer, filters of the bottleneck layer in a block.
+    blocks: integer, blocks in the stacked blocks.
+    stride1: default 2, stride of the first layer in the first block.
+    name: string, stack label.
+
+  Returns:
+    Output tensor for the stacked blocks.
+  """
+  x = block_basic(x, filters, conv_shortcut=True, name=name + '_block1')
+  for i in range(2, blocks):
+    x = block_basic(x, filters, name=name + '_block' + str(i))
+  x = block_basic(
+      x, filters, stride=stride1, name=name + '_block' + str(blocks))
+  return x
+
+
+def resnet_v1_18(include_top=True,
+                 weights='imagenet',
+                 input_tensor=None,
+                 input_shape=None,
+                 pooling=None,
+                 classes=1000,
+                 classifier_activation='softmax'):
+  """Instantiates the ResNet18 architecture."""
+
+  def stack_fn(x):
+    x = stack_basic(x, 64, 2, stride1=1, name='conv2')
+    x = stack_basic(x, 128, 2, name='conv3')
+    x = stack_basic(x, 256, 2, name='conv4')
+    return stack_basic(x, 512, 2, name='conv5')
+
+  return resnet.ResNet(
+      stack_fn,
+      True,
+      True,
+      'resnet18',
+      include_top,
+      weights,
+      input_tensor,
+      input_shape,
+      pooling,
+      classes,
+      classifier_activation=classifier_activation)
+
+
+def resnet_v1_34(include_top=True,
+                 weights='imagenet',
+                 input_tensor=None,
+                 input_shape=None,
+                 pooling=None,
+                 classes=1000,
+                 classifier_activation='softmax'):
+  """Instantiates the ResNet34 architecture."""
+
+  def stack_fn(x):
+    x = stack_basic(x, 64, 3, stride1=1, name='conv2')
+    x = stack_basic(x, 128, 4, name='conv3')
+    x = stack_basic(x, 256, 6, name='conv4')
+    return stack_basic(x, 512, 3, name='conv5')
+
+  return resnet.ResNet(
+      stack_fn,
+      True,
+      True,
+      'resnet34',
+      include_top,
+      weights,
+      input_tensor,
+      input_shape,
+      pooling,
+      classes,
+      classifier_activation=classifier_activation)
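For reference, a minimal sketch of using the new ResNet-18 builder and its per-stack output layers (not part of the commit; it mirrors the ResnetShapeTest added in the test file below, from which the layer names and expected shapes are taken).

import numpy as np
import tensorflow.compat.v1 as tf

from object_detection.models.keras_models import resnet_v1

# ResNet-18 uses two basic blocks per stack, so the last block of each stack
# is named conv{2..5}_block2_out (ResNet-34 uses 3/4/6/3 blocks instead).
model = resnet_v1.resnet_v1_18(weights=None)
layer_names = ['conv2_block2_out', 'conv3_block2_out',
               'conv4_block2_out', 'conv5_block2_out']
outputs = [model.get_layer(name).output for name in layer_names]
backbone = tf.keras.models.Model(inputs=model.input, outputs=outputs)

# For a 64x64 input the four stacks produce strides 4, 8, 16 and 32.
features = backbone(np.zeros((2, 64, 64, 3), dtype=np.float32))
print([f.shape.as_list() for f in features])
# expected: [[2, 16, 16, 64], [2, 8, 8, 128], [2, 4, 4, 256], [2, 2, 2, 512]]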
research/object_detection/models/keras_models/resnet_v1_tf2_test.py
@@ -20,12 +20,13 @@ object detection. To verify the consistency of the two models, we compare:
 2. Number of global variables.
 """
 
 import unittest
+from absl.testing import parameterized
 import numpy as np
 from six.moves import zip
 import tensorflow.compat.v1 as tf
 
 from google.protobuf import text_format
 from object_detection.builders import hyperparams_builder
 from object_detection.models.keras_models import resnet_v1
 from object_detection.protos import hyperparams_pb2

@@ -180,5 +181,46 @@ class ResnetV1Test(test_case.TestCase):
     self.assertEqual(len(variables), var_num)
 
 
+class ResnetShapeTest(test_case.TestCase, parameterized.TestCase):
+
+  @unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
+  @parameterized.parameters(
+      {'resnet_type': 'resnet_v1_34',
+       'output_layer_names': ['conv2_block3_out', 'conv3_block4_out',
+                              'conv4_block6_out', 'conv5_block3_out']},
+      {'resnet_type': 'resnet_v1_18',
+       'output_layer_names': ['conv2_block2_out', 'conv3_block2_out',
+                              'conv4_block2_out', 'conv5_block2_out']})
+  def test_output_shapes(self, resnet_type, output_layer_names):
+    if resnet_type == 'resnet_v1_34':
+      model = resnet_v1.resnet_v1_34(weights=None)
+    else:
+      model = resnet_v1.resnet_v1_18(weights=None)
+    outputs = [
+        model.get_layer(output_layer_name).output
+        for output_layer_name in output_layer_names
+    ]
+    resnet_model = tf.keras.models.Model(inputs=model.input, outputs=outputs)
+    outputs = resnet_model(np.zeros((2, 64, 64, 3), dtype=np.float32))
+
+    # Check the shape of 'conv2_block3_out':
+    self.assertEqual(outputs[0].shape, [2, 16, 16, 64])
+    # Check the shape of 'conv3_block4_out':
+    self.assertEqual(outputs[1].shape, [2, 8, 8, 128])
+    # Check the shape of 'conv4_block6_out':
+    self.assertEqual(outputs[2].shape, [2, 4, 4, 256])
+    # Check the shape of 'conv5_block3_out':
+    self.assertEqual(outputs[3].shape, [2, 2, 2, 512])
+
+
 if __name__ == '__main__':
   tf.test.main()
research/object_detection/protos/input_reader.proto
@@ -31,7 +31,7 @@ enum InputType {
   TF_SEQUENCE_EXAMPLE = 2;    // TfSequenceExample Input
 }
 
-// Next id: 32
+// Next id: 33
 message InputReader {
   // Name of input reader. Typically used to describe the dataset that is read
   // by this input reader.

@@ -133,6 +133,10 @@ message InputReader {
   // Whether input data type is tf.Examples or tf.SequenceExamples
   optional InputType input_type = 30 [default = TF_EXAMPLE];
 
+  // Which frame to choose from the input if Sequence Example. -1 indicates
+  // random choice.
+  optional int32 frame_index = 32 [default = -1];
+
   oneof input_reader {
     TFRecordInputReader tf_record_input_reader = 8;
     ExternalInputReader external_input_reader = 9;
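For reference, a short sketch of how the new frame_index field could be set in an InputReader config (not part of the commit; it assumes the generated input_reader_pb2 module, and the value 2 is illustrative).

from google.protobuf import text_format
from object_detection.protos import input_reader_pb2

# frame_index selects a fixed frame from each tf.SequenceExample; the default
# of -1 keeps the previous behavior of choosing a random frame.
config_text = """
  input_type: TF_SEQUENCE_EXAMPLE
  frame_index: 2
"""
input_reader = text_format.Parse(config_text, input_reader_pb2.InputReader())
print(input_reader.frame_index)  # 2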
research/object_detection/utils/spatial_transform_ops.py
@@ -411,6 +411,56 @@ def multilevel_roi_align(features, boxes, box_levels, output_size,
   return features_per_box
 
 
+def multilevel_native_crop_and_resize(images, boxes, box_levels,
+                                      crop_size, scope=None):
+  """Multilevel native crop and resize.
+
+  Same as `multilevel_matmul_crop_and_resize` but uses
+  tf.image.crop_and_resize.
+
+  Args:
+    images: A list of 4-D tensor of shape
+      [batch, image_height, image_width, depth] representing features of
+      different size.
+    boxes: A `Tensor` of type `float32`.
+      A 3-D tensor of shape `[batch, num_boxes, 4]`. The boxes are specified in
+      normalized coordinates and are of the form `[y1, x1, y2, x2]`. A
+      normalized coordinate value of `y` is mapped to the image coordinate at
+      `y * (image_height - 1)`, so as the `[0, 1]` interval of normalized image
+      height is mapped to `[0, image_height - 1]` in image height coordinates.
+      We do allow y1 > y2, in which case the sampled crop is an up-down flipped
+      version of the original image. The width dimension is treated similarly.
+      Normalized coordinates outside the `[0, 1]` range are allowed, in which
+      case we use `extrapolation_value` to extrapolate the input image values.
+    box_levels: A 2-D tensor of shape [batch, num_boxes] representing the level
+      of the box.
+    crop_size: A list of two integers `[crop_height, crop_width]`. All
+      cropped image patches are resized to this size. The aspect ratio of the
+      image content is not preserved. Both `crop_height` and `crop_width` need
+      to be positive.
+    scope: A name for the operation (optional).
+
+  Returns:
+    A 5-D float tensor of shape `[batch, num_boxes, crop_height, crop_width,
+    depth]`
+  """
+  if box_levels is None:
+    return native_crop_and_resize(images[0], boxes, crop_size, scope)
+  with tf.name_scope('MultiLevelNativeCropAndResize'):
+    cropped_feature_list = []
+    for level, image in enumerate(images):
+      # For each level, crop the feature according to all boxes
+      # set the cropped feature not at this level to 0 tensor.
+      # Consider more efficient way of computing cropped features.
+      cropped = native_crop_and_resize(image, boxes, crop_size, scope)
+      cond = tf.tile(
+          tf.equal(box_levels, level)[:, :, tf.newaxis],
+          [1, 1] + [tf.math.reduce_prod(cropped.shape.as_list()[2:])])
+      cond = tf.reshape(cond, cropped.shape)
+      cropped_final = tf.where(cond, cropped, tf.zeros_like(cropped))
+      cropped_feature_list.append(cropped_final)
+    return tf.math.reduce_sum(cropped_feature_list, axis=0)
+
+
 def native_crop_and_resize(image, boxes, crop_size, scope=None):
   """Same as `matmul_crop_and_resize` but uses tf.image.crop_and_resize."""
   def get_box_inds(proposals):

@@ -431,6 +481,50 @@ def native_crop_and_resize(image, boxes, crop_size, scope=None):
   return tf.reshape(cropped_regions, final_shape)
 
 
+def multilevel_matmul_crop_and_resize(images, boxes, box_levels, crop_size,
+                                      extrapolation_value=0.0, scope=None):
+  """Multilevel matmul crop and resize.
+
+  Same as `matmul_crop_and_resize` but crop images according to box levels.
+
+  Args:
+    images: A list of 4-D tensor of shape
+      [batch, image_height, image_width, depth] representing features of
+      different size.
+    boxes: A `Tensor` of type `float32` or 'bfloat16'.
+      A 3-D tensor of shape `[batch, num_boxes, 4]`. The boxes are specified in
+      normalized coordinates and are of the form `[y1, x1, y2, x2]`. A
+      normalized coordinate value of `y` is mapped to the image coordinate at
+      `y * (image_height - 1)`, so as the `[0, 1]` interval of normalized image
+      height is mapped to `[0, image_height - 1]` in image height coordinates.
+      We do allow y1 > y2, in which case the sampled crop is an up-down flipped
+      version of the original image. The width dimension is treated similarly.
+      Normalized coordinates outside the `[0, 1]` range are allowed, in which
+      case we use `extrapolation_value` to extrapolate the input image values.
+    box_levels: A 2-D tensor of shape [batch, num_boxes] representing the level
+      of the box.
+    crop_size: A list of two integers `[crop_height, crop_width]`. All
+      cropped image patches are resized to this size. The aspect ratio of the
+      image content is not preserved. Both `crop_height` and `crop_width` need
+      to be positive.
+    extrapolation_value: A float value to use for extrapolation.
+    scope: A name for the operation (optional).
+
+  Returns:
+    A 5-D float tensor of shape `[batch, num_boxes, crop_height, crop_width,
+    depth]`
+  """
+  with tf.name_scope(scope, 'MultiLevelMatMulCropAndResize'):
+    if box_levels is None:
+      box_levels = tf.zeros(tf.shape(boxes)[:2], dtype=tf.int32)
+    return multilevel_roi_align(images,
+                                boxes,
+                                box_levels,
+                                crop_size,
+                                align_corners=True,
+                                extrapolation_value=extrapolation_value)
+
+
 def matmul_crop_and_resize(image, boxes, crop_size, extrapolation_value=0.0,
                            scope=None):
   """Matrix multiplication based implementation of the crop and resize op.
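For reference, a minimal sketch of calling the new multilevel op directly (not part of the commit; it assumes a TF2 eager environment, and the feature-map sizes, boxes, and levels are illustrative).

import tensorflow.compat.v1 as tf

from object_detection.utils import spatial_transform_ops as spatial_ops

# Two feature maps of different resolution (levels 0 and 1), one box per level.
level0 = tf.ones((1, 8, 8, 3), dtype=tf.float32)
level1 = tf.ones((1, 4, 4, 3), dtype=tf.float32)
boxes = tf.constant([[[0., 0., 1., 1.],
                      [0., 0., .5, .5]]], dtype=tf.float32)
box_levels = tf.constant([[0, 1]], dtype=tf.int32)

# Each box is cropped from every level, but crops whose level does not match
# box_levels are zeroed out before the per-level results are summed.
crops = spatial_ops.multilevel_native_crop_and_resize(
    [level0, level1], boxes, box_levels, crop_size=[2, 2])
print(crops.shape)  # (1, 2, 2, 2, 3): [batch, num_boxes, crop_h, crop_w, depth]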
research/object_detection/utils/spatial_transform_ops_test.py
@@ -512,6 +512,38 @@ class MatMulCropAndResizeTest(test_case.TestCase):
     crop_output = self.execute(graph_fn, [image, boxes])
     self.assertAllClose(crop_output, expected_output)
 
+  def testMultilevelMatMulCropAndResize(self):
+    def graph_fn(image1, image2, boxes, box_levels):
+      return spatial_ops.multilevel_matmul_crop_and_resize([image1, image2],
+                                                           boxes,
+                                                           box_levels,
+                                                           crop_size=[2, 2])
+
+    image = [np.array([[[[1, 0], [2, 0], [3, 0]],
+                        [[4, 0], [5, 0], [6, 0]],
+                        [[7, 0], [8, 0], [9, 0]]],
+                       [[[1, 0], [2, 0], [3, 0]],
+                        [[4, 0], [5, 0], [6, 0]],
+                        [[7, 0], [8, 0], [9, 0]]]],
+                      dtype=np.float32),
+             np.array([[[[1, 0], [2, 1], [3, 2]],
+                        [[4, 3], [5, 4], [6, 5]],
+                        [[7, 6], [8, 7], [9, 8]]],
+                       [[[1, 0], [2, 1], [3, 2]],
+                        [[4, 3], [5, 4], [6, 5]],
+                        [[7, 6], [8, 7], [9, 8]]]],
+                      dtype=np.float32)]
+    boxes = np.array([[[1, 1, 0, 0],
+                       [.5, .5, 0, 0]],
+                      [[0, 0, 1, 1],
+                       [0, 0, .5, .5]]], dtype=np.float32)
+    box_levels = np.array([[0, 1], [1, 1]], dtype=np.int32)
+    expected_output = [[[[[9, 0], [7, 0]],
+                         [[3, 0], [1, 0]]],
+                        [[[5, 4], [4, 3]],
+                         [[2, 1], [1, 0]]]],
+                       [[[[1, 0], [3, 2]],
+                         [[7, 6], [9, 8]]],
+                        [[[1, 0], [2, 1]],
+                         [[4, 3], [5, 4]]]]]
+    crop_output = self.execute(graph_fn, image + [boxes, box_levels])
+    self.assertAllClose(crop_output, expected_output)
+
 
 class NativeCropAndResizeTest(test_case.TestCase):

@@ -537,6 +569,35 @@ class NativeCropAndResizeTest(test_case.TestCase):
     crop_output = self.execute_cpu(graph_fn, [image, boxes])
     self.assertAllClose(crop_output, expected_output)
 
+  def testMultilevelBatchCropAndResize3x3To2x2_2Channels(self):
+    def graph_fn(image1, image2, boxes, box_levels):
+      return spatial_ops.multilevel_native_crop_and_resize([image1, image2],
+                                                           boxes,
+                                                           box_levels,
+                                                           crop_size=[2, 2])
+
+    image = [np.array([[[[1, 0], [2, 1], [3, 2]],
+                        [[4, 3], [5, 4], [6, 5]],
+                        [[7, 6], [8, 7], [9, 8]]],
+                       [[[1, 0], [2, 1], [3, 2]],
+                        [[4, 3], [5, 4], [6, 5]],
+                        [[7, 6], [8, 7], [9, 8]]]],
+                      dtype=np.float32),
+             np.array([[[[1, 0], [2, 1]],
+                        [[4, 3], [5, 4]]],
+                       [[[1, 0], [2, 1]],
+                        [[4, 3], [5, 4]]]],
+                      dtype=np.float32)]
+    boxes = np.array([[[0, 0, 1, 1],
+                       [0, 0, .5, .5]],
+                      [[1, 1, 0, 0],
+                       [.5, .5, 0, 0]]], dtype=np.float32)
+    box_levels = np.array([[0, 1], [0, 0]], dtype=np.float32)
+    expected_output = [[[[[1, 0], [3, 2]],
+                         [[7, 6], [9, 8]]],
+                        [[[1, 0], [1.5, 0.5]],
+                         [[2.5, 1.5], [3, 2]]]],
+                       [[[[9, 8], [7, 6]],
+                         [[3, 2], [1, 0]]],
+                        [[[5, 4], [4, 3]],
+                         [[2, 1], [1, 0]]]]]
+    crop_output = self.execute_cpu(graph_fn, image + [boxes, box_levels])
+    self.assertAllClose(crop_output, expected_output)
+
 
 if __name__ == '__main__':
   tf.test.main()