Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
1ea5e1f6
Commit
1ea5e1f6
authored
Aug 03, 2020
by
TF Object Detection Team
Browse files
Merge pull request #8893 from syiming:move_to_keraslayers_fasterrcnn_fpn_keras_feature_extractor
PiperOrigin-RevId: 324632246
parents
507a8d3c
ea8cc8cf
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
99 additions
and
39 deletions
+99
-39
research/object_detection/models/faster_rcnn_resnet_v1_fpn_keras_feature_extractor.py
...dels/faster_rcnn_resnet_v1_fpn_keras_feature_extractor.py
+96
-36
research/object_detection/models/faster_rcnn_resnet_v1_fpn_keras_feature_extractor_tf2_test.py
...er_rcnn_resnet_v1_fpn_keras_feature_extractor_tf2_test.py
+1
-1
research/object_detection/utils/ops.py
research/object_detection/utils/ops.py
+2
-2
No files found.
research/object_detection/models/faster_rcnn_resnet_v1_fpn_keras_feature_extractor.py
View file @
1ea5e1f6
...
@@ -20,6 +20,7 @@ import tensorflow.compat.v1 as tf
...
@@ -20,6 +20,7 @@ import tensorflow.compat.v1 as tf
from
object_detection.meta_architectures
import
faster_rcnn_meta_arch
from
object_detection.meta_architectures
import
faster_rcnn_meta_arch
from
object_detection.models
import
feature_map_generators
from
object_detection.models
import
feature_map_generators
from
object_detection.models.keras_models
import
resnet_v1
from
object_detection.models.keras_models
import
resnet_v1
from
object_detection.utils
import
ops
_RESNET_MODEL_OUTPUT_LAYERS
=
{
_RESNET_MODEL_OUTPUT_LAYERS
=
{
...
@@ -32,6 +33,78 @@ _RESNET_MODEL_OUTPUT_LAYERS = {
...
@@ -32,6 +33,78 @@ _RESNET_MODEL_OUTPUT_LAYERS = {
}
}
class _ResnetFPN(tf.keras.layers.Layer):
  """Keras layer chaining a backbone, an FPN generator and coarse layers."""

  def __init__(self,
               backbone_classifier,
               fpn_features_generator,
               coarse_feature_layers,
               pad_to_multiple,
               fpn_min_level,
               resnet_block_names,
               base_fpn_max_level):
    """Stores the sub-components and level configuration used by `call`.

    Args:
      backbone_classifier: Callable backbone; it is invoked on the padded
        images and its outputs are paired, in order, with
        `resnet_block_names`.
      fpn_features_generator: Callable that receives a list of
        (block name, feature) pairs and returns a mapping that is indexed
        with 'top_down_block{N}' keys.
      coarse_feature_layers: Iterable of layer groups; the layers of each
        group are applied one after another to produce one extra feature
        map per group.
      pad_to_multiple: An integer multiple to pad the input image to; it is
        forwarded to `ops.pad_to_multiple`.
      fpn_min_level: The highest resolution feature map to use in FPN. The
        valid values are {2, 3, 4, 5} which map to Resnet v1 layers.
      resnet_block_names: A list of block names of resnet.
      base_fpn_max_level: Maximum level of FPN without coarse feature layers.
    """
    super(_ResnetFPN, self).__init__()
    self.classification_backbone = backbone_classifier
    self.fpn_features_generator = fpn_features_generator
    self.coarse_feature_layers = coarse_feature_layers
    self.pad_to_multiple = pad_to_multiple
    self._fpn_min_level = fpn_min_level
    self._resnet_block_names = resnet_block_names
    self._base_fpn_max_level = base_fpn_max_level

  def call(self, inputs):
    """Runs the backbone, the FPN generator and the coarse feature layers.

    Args:
      inputs: A [batch, height_out, width_out, channels] float32 tensor
        representing a batch of images.

    Returns:
      feature_maps: A list of tensors with shape [batch, height, width, depth]
        representing extracted features: one FPN map per level from
        `fpn_min_level` through `base_fpn_max_level`, followed by one map per
        coarse feature layer group.
    """
    padded_inputs = ops.pad_to_multiple(inputs, self.pad_to_multiple)
    backbone_outputs = self.classification_backbone(padded_inputs)

    # Level L of the pyramid is fed by (and named after) resnet block L - 1.
    fpn_levels = range(self._fpn_min_level, self._base_fpn_max_level + 1)
    block_names = ['block{}'.format(level - 1) for level in fpn_levels]
    outputs_by_block = dict(zip(self._resnet_block_names, backbone_outputs))
    fpn_input_image_features = [
        (block_name, outputs_by_block[block_name])
        for block_name in block_names
    ]
    fpn_features = self.fpn_features_generator(fpn_input_image_features)

    feature_maps = [
        fpn_features['top_down_block{}'.format(level - 1)]
        for level in fpn_levels
    ]

    # Each coarse group starts from the previously produced map, beginning
    # with the top-most FPN level.
    last_feature_map = fpn_features['top_down_block{}'.format(
        self._base_fpn_max_level - 1)]
    for layer_group in self.coarse_feature_layers:
      for layer in layer_group:
        last_feature_map = layer(last_feature_map)
      feature_maps.append(last_feature_map)
    return feature_maps
class
FasterRCNNResnetV1FpnKerasFeatureExtractor
(
class
FasterRCNNResnetV1FpnKerasFeatureExtractor
(
faster_rcnn_meta_arch
.
FasterRCNNKerasFeatureExtractor
):
faster_rcnn_meta_arch
.
FasterRCNNKerasFeatureExtractor
):
"""Faster RCNN Feature Extractor using Keras-based Resnet V1 FPN features."""
"""Faster RCNN Feature Extractor using Keras-based Resnet V1 FPN features."""
...
@@ -42,7 +115,8 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
...
@@ -42,7 +115,8 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
resnet_v1_base_model_name
,
resnet_v1_base_model_name
,
first_stage_features_stride
,
first_stage_features_stride
,
conv_hyperparams
,
conv_hyperparams
,
batch_norm_trainable
=
False
,
batch_norm_trainable
=
True
,
pad_to_multiple
=
32
,
weight_decay
=
0.0
,
weight_decay
=
0.0
,
fpn_min_level
=
2
,
fpn_min_level
=
2
,
fpn_max_level
=
6
,
fpn_max_level
=
6
,
...
@@ -60,6 +134,7 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
...
@@ -60,6 +134,7 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
containing convolution hyperparameters for the layers added on top of
containing convolution hyperparameters for the layers added on top of
the base feature extractor.
the base feature extractor.
batch_norm_trainable: See base class.
batch_norm_trainable: See base class.
pad_to_multiple: An integer multiple to pad input image.
weight_decay: See base class.
weight_decay: See base class.
fpn_min_level: the highest resolution feature map to use in FPN. The valid
fpn_min_level: the highest resolution feature map to use in FPN. The valid
values are {2, 3, 4, 5} which map to Resnet v1 layers.
values are {2, 3, 4, 5} which map to Resnet v1 layers.
...
@@ -93,6 +168,8 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
...
@@ -93,6 +168,8 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
self
.
_fpn_max_level
=
fpn_max_level
self
.
_fpn_max_level
=
fpn_max_level
self
.
_additional_layer_depth
=
additional_layer_depth
self
.
_additional_layer_depth
=
additional_layer_depth
self
.
_freeze_batchnorm
=
(
not
batch_norm_trainable
)
self
.
_freeze_batchnorm
=
(
not
batch_norm_trainable
)
self
.
_pad_to_multiple
=
pad_to_multiple
self
.
_override_base_feature_extractor_hyperparams
=
\
self
.
_override_base_feature_extractor_hyperparams
=
\
override_base_feature_extractor_hyperparams
override_base_feature_extractor_hyperparams
self
.
_resnet_block_names
=
[
'block1'
,
'block2'
,
'block3'
,
'block4'
]
self
.
_resnet_block_names
=
[
'block1'
,
'block2'
,
'block3'
,
'block4'
]
...
@@ -156,10 +233,7 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
...
@@ -156,10 +233,7 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
self
.
classification_backbone
=
tf
.
keras
.
Model
(
self
.
classification_backbone
=
tf
.
keras
.
Model
(
inputs
=
full_resnet_v1_model
.
inputs
,
inputs
=
full_resnet_v1_model
.
inputs
,
outputs
=
outputs
)
outputs
=
outputs
)
backbone_outputs
=
self
.
classification_backbone
(
full_resnet_v1_model
.
inputs
)
# construct FPN feature generator
self
.
_base_fpn_max_level
=
min
(
self
.
_fpn_max_level
,
5
)
self
.
_base_fpn_max_level
=
min
(
self
.
_fpn_max_level
,
5
)
self
.
_num_levels
=
self
.
_base_fpn_max_level
+
1
-
self
.
_fpn_min_level
self
.
_num_levels
=
self
.
_base_fpn_max_level
+
1
-
self
.
_fpn_min_level
self
.
_fpn_features_generator
=
(
self
.
_fpn_features_generator
=
(
...
@@ -171,16 +245,6 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
...
@@ -171,16 +245,6 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
freeze_batchnorm
=
self
.
_freeze_batchnorm
,
freeze_batchnorm
=
self
.
_freeze_batchnorm
,
name
=
'FeatureMaps'
))
name
=
'FeatureMaps'
))
feature_block_list
=
[]
for
level
in
range
(
self
.
_fpn_min_level
,
self
.
_base_fpn_max_level
+
1
):
feature_block_list
.
append
(
'block{}'
.
format
(
level
-
1
))
feature_block_map
=
dict
(
list
(
zip
(
self
.
_resnet_block_names
,
backbone_outputs
)))
fpn_input_image_features
=
[
(
feature_block
,
feature_block_map
[
feature_block
])
for
feature_block
in
feature_block_list
]
fpn_features
=
self
.
_fpn_features_generator
(
fpn_input_image_features
)
# Construct coarse feature layers
# Construct coarse feature layers
for
i
in
range
(
self
.
_base_fpn_max_level
,
self
.
_fpn_max_level
):
for
i
in
range
(
self
.
_base_fpn_max_level
,
self
.
_fpn_max_level
):
layers
=
[]
layers
=
[]
...
@@ -202,19 +266,13 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
...
@@ -202,19 +266,13 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
name
=
layer_name
))
name
=
layer_name
))
self
.
_coarse_feature_layers
.
append
(
layers
)
self
.
_coarse_feature_layers
.
append
(
layers
)
feature_maps
=
[]
feature_extractor_model
=
_ResnetFPN
(
self
.
classification_backbone
,
for
level
in
range
(
self
.
_fpn_min_level
,
self
.
_base_fpn_max_level
+
1
):
self
.
_fpn_features_generator
,
feature_maps
.
append
(
fpn_features
[
'top_down_block{}'
.
format
(
level
-
1
)])
self
.
_coarse_feature_layers
,
last_feature_map
=
fpn_features
[
'top_down_block{}'
.
format
(
self
.
_pad_to_multiple
,
self
.
_base_fpn_max_level
-
1
)]
self
.
_fpn_min_level
,
self
.
_resnet_block_names
,
for
coarse_feature_layers
in
self
.
_coarse_feature_layers
:
self
.
_base_fpn_max_level
)
for
layer
in
coarse_feature_layers
:
last_feature_map
=
layer
(
last_feature_map
)
feature_maps
.
append
(
last_feature_map
)
feature_extractor_model
=
tf
.
keras
.
models
.
Model
(
inputs
=
full_resnet_v1_model
.
inputs
,
outputs
=
feature_maps
)
return
feature_extractor_model
return
feature_extractor_model
def
get_box_classifier_feature_extractor_model
(
self
,
name
=
None
):
def
get_box_classifier_feature_extractor_model
(
self
,
name
=
None
):
...
@@ -233,16 +291,18 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
...
@@ -233,16 +291,18 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
And returns proposal_classifier_features:
And returns proposal_classifier_features:
A 4-D float tensor with shape
A 4-D float tensor with shape
[batch_size * self.max_num_proposals, 1024]
[batch_size * self.max_num_proposals,
1, 1,
1024]
representing box classifier features for each proposal.
representing box classifier features for each proposal.
"""
"""
with
tf
.
name_scope
(
name
):
with
tf
.
name_scope
(
name
):
with
tf
.
name_scope
(
'ResnetV1FPN'
):
with
tf
.
name_scope
(
'ResnetV1FPN'
):
# TODO(yiming): Add a batchnorm layer between two fc layers.
feature_extractor_model
=
tf
.
keras
.
models
.
Sequential
([
feature_extractor_model
=
tf
.
keras
.
models
.
Sequential
([
tf
.
keras
.
layers
.
Flatten
(),
tf
.
keras
.
layers
.
Flatten
(),
tf
.
keras
.
layers
.
Dense
(
units
=
1024
,
activation
=
'relu'
),
tf
.
keras
.
layers
.
Dense
(
units
=
1024
,
activation
=
'relu'
),
tf
.
keras
.
layers
.
Dense
(
units
=
1024
,
activation
=
'relu'
)
self
.
_conv_hyperparams
.
build_batch_norm
(
training
=
(
self
.
_is_training
and
not
self
.
_freeze_batchnorm
)),
tf
.
keras
.
layers
.
Dense
(
units
=
1024
,
activation
=
'relu'
),
tf
.
keras
.
layers
.
Reshape
((
1
,
1
,
1024
))
])
])
return
feature_extractor_model
return
feature_extractor_model
...
@@ -254,8 +314,8 @@ class FasterRCNNResnet50FpnKerasFeatureExtractor(
...
@@ -254,8 +314,8 @@ class FasterRCNNResnet50FpnKerasFeatureExtractor(
def
__init__
(
self
,
def
__init__
(
self
,
is_training
,
is_training
,
first_stage_features_stride
=
16
,
first_stage_features_stride
=
16
,
batch_norm_trainable
=
True
,
conv_hyperparams
=
None
,
conv_hyperparams
=
None
,
batch_norm_trainable
=
False
,
weight_decay
=
0.0
,
weight_decay
=
0.0
,
fpn_min_level
=
2
,
fpn_min_level
=
2
,
fpn_max_level
=
6
,
fpn_max_level
=
6
,
...
@@ -266,8 +326,8 @@ class FasterRCNNResnet50FpnKerasFeatureExtractor(
...
@@ -266,8 +326,8 @@ class FasterRCNNResnet50FpnKerasFeatureExtractor(
Args:
Args:
is_training: See base class.
is_training: See base class.
first_stage_features_stride: See base class.
first_stage_features_stride: See base class.
conv_hyperparams: See base class.
batch_norm_trainable: See base class.
batch_norm_trainable: See base class.
conv_hyperparams: See base class.
weight_decay: See base class.
weight_decay: See base class.
fpn_min_level: See base class.
fpn_min_level: See base class.
fpn_max_level: See base class.
fpn_max_level: See base class.
...
@@ -297,8 +357,8 @@ class FasterRCNNResnet101FpnKerasFeatureExtractor(
...
@@ -297,8 +357,8 @@ class FasterRCNNResnet101FpnKerasFeatureExtractor(
def
__init__
(
self
,
def
__init__
(
self
,
is_training
,
is_training
,
first_stage_features_stride
=
16
,
first_stage_features_stride
=
16
,
batch_norm_trainable
=
True
,
conv_hyperparams
=
None
,
conv_hyperparams
=
None
,
batch_norm_trainable
=
False
,
weight_decay
=
0.0
,
weight_decay
=
0.0
,
fpn_min_level
=
2
,
fpn_min_level
=
2
,
fpn_max_level
=
6
,
fpn_max_level
=
6
,
...
@@ -309,8 +369,8 @@ class FasterRCNNResnet101FpnKerasFeatureExtractor(
...
@@ -309,8 +369,8 @@ class FasterRCNNResnet101FpnKerasFeatureExtractor(
Args:
Args:
is_training: See base class.
is_training: See base class.
first_stage_features_stride: See base class.
first_stage_features_stride: See base class.
conv_hyperparams: See base class.
batch_norm_trainable: See base class.
batch_norm_trainable: See base class.
conv_hyperparams: See base class.
weight_decay: See base class.
weight_decay: See base class.
fpn_min_level: See base class.
fpn_min_level: See base class.
fpn_max_level: See base class.
fpn_max_level: See base class.
...
@@ -339,8 +399,8 @@ class FasterRCNNResnet152FpnKerasFeatureExtractor(
...
@@ -339,8 +399,8 @@ class FasterRCNNResnet152FpnKerasFeatureExtractor(
def
__init__
(
self
,
def
__init__
(
self
,
is_training
,
is_training
,
first_stage_features_stride
=
16
,
first_stage_features_stride
=
16
,
batch_norm_trainable
=
True
,
conv_hyperparams
=
None
,
conv_hyperparams
=
None
,
batch_norm_trainable
=
False
,
weight_decay
=
0.0
,
weight_decay
=
0.0
,
fpn_min_level
=
2
,
fpn_min_level
=
2
,
fpn_max_level
=
6
,
fpn_max_level
=
6
,
...
@@ -351,8 +411,8 @@ class FasterRCNNResnet152FpnKerasFeatureExtractor(
...
@@ -351,8 +411,8 @@ class FasterRCNNResnet152FpnKerasFeatureExtractor(
Args:
Args:
is_training: See base class.
is_training: See base class.
first_stage_features_stride: See base class.
first_stage_features_stride: See base class.
conv_hyperparams: See base class.
batch_norm_trainable: See base class.
batch_norm_trainable: See base class.
conv_hyperparams: See base class.
weight_decay: See base class.
weight_decay: See base class.
fpn_min_level: See base class.
fpn_min_level: See base class.
fpn_max_level: See base class.
fpn_max_level: See base class.
...
...
research/object_detection/models/faster_rcnn_resnet_v1_fpn_keras_feature_extractor_tf2_test.py
View file @
1ea5e1f6
...
@@ -91,4 +91,4 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractorTest(tf.test.TestCase):
...
@@ -91,4 +91,4 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractorTest(tf.test.TestCase):
model
(
proposal_feature_maps
))
model
(
proposal_feature_maps
))
features_shape
=
tf
.
shape
(
proposal_classifier_features
)
features_shape
=
tf
.
shape
(
proposal_classifier_features
)
self
.
assertAllEqual
(
features_shape
.
numpy
(),
[
3
,
1024
])
self
.
assertAllEqual
(
features_shape
.
numpy
(),
[
3
,
1
,
1
,
1024
])
research/object_detection/utils/ops.py
View file @
1ea5e1f6
...
@@ -216,13 +216,13 @@ def pad_to_multiple(tensor, multiple):
...
@@ -216,13 +216,13 @@ def pad_to_multiple(tensor, multiple):
height_pad
=
tf
.
zeros
([
height_pad
=
tf
.
zeros
([
batch_size
,
padded_tensor_height
-
tensor_height
,
tensor_width
,
batch_size
,
padded_tensor_height
-
tensor_height
,
tensor_width
,
tensor_depth
tensor_depth
])
]
,
dtype
=
tensor
.
dtype
)
tensor
=
tf
.
concat
([
tensor
,
height_pad
],
1
)
tensor
=
tf
.
concat
([
tensor
,
height_pad
],
1
)
if
padded_tensor_width
!=
tensor_width
:
if
padded_tensor_width
!=
tensor_width
:
width_pad
=
tf
.
zeros
([
width_pad
=
tf
.
zeros
([
batch_size
,
padded_tensor_height
,
padded_tensor_width
-
tensor_width
,
batch_size
,
padded_tensor_height
,
padded_tensor_width
-
tensor_width
,
tensor_depth
tensor_depth
])
]
,
dtype
=
tensor
.
dtype
)
tensor
=
tf
.
concat
([
tensor
,
width_pad
],
2
)
tensor
=
tf
.
concat
([
tensor
,
width_pad
],
2
)
return
tensor
return
tensor
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment