Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
705dbf33
Commit
705dbf33
authored
Jul 10, 2020
by
syiming
Browse files
merge
parent
4faea59a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
52 additions
and
27 deletions
+52
-27
research/object_detection/models/faster_rcnn_resnet_v1_fpn_keras_feature_extractor.py
...dels/faster_rcnn_resnet_v1_fpn_keras_feature_extractor.py
+52
-27
No files found.
research/object_detection/models/faster_rcnn_resnet_v1_fpn_keras_feature_extractor.py
View file @
705dbf33
...
...
@@ -20,6 +20,7 @@ import tensorflow.compat.v1 as tf
from
object_detection.meta_architectures
import
faster_rcnn_meta_arch
from
object_detection.models
import
feature_map_generators
from
object_detection.models.keras_models
import
resnet_v1
from
object_detection.utils
import
ops
_RESNET_MODEL_OUTPUT_LAYERS
=
{
...
...
@@ -31,6 +32,49 @@ _RESNET_MODEL_OUTPUT_LAYERS = {
'conv4_block36_out'
,
'conv5_block3_out'
],
}
class
ResnetFPN
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
,
backbone_classifier
,
fpn_features_generator
,
coarse_feature_layers
,
fpn_min_level
,
resnet_block_names
,
base_fpn_max_level
):
super
(
ResnetFPN
,
self
).
__init__
()
self
.
classification_backbone
=
backbone_classifier
self
.
fpn_features_generator
=
fpn_features_generator
self
.
coarse_feature_layers
=
coarse_feature_layers
self
.
_fpn_min_level
=
fpn_min_level
self
.
_resnet_block_names
=
resnet_block_names
self
.
_base_fpn_max_level
=
base_fpn_max_level
def
call
(
self
,
inputs
):
inputs
=
ops
.
pad_to_multiple
(
inputs
,
32
)
backbone_outputs
=
self
.
classification_backbone
(
inputs
)
feature_block_list
=
[]
for
level
in
range
(
self
.
_fpn_min_level
,
self
.
_base_fpn_max_level
+
1
):
feature_block_list
.
append
(
'block{}'
.
format
(
level
-
1
))
feature_block_map
=
dict
(
list
(
zip
(
self
.
_resnet_block_names
,
backbone_outputs
)))
fpn_input_image_features
=
[
(
feature_block
,
feature_block_map
[
feature_block
])
for
feature_block
in
feature_block_list
]
fpn_features
=
self
.
fpn_features_generator
(
fpn_input_image_features
)
feature_maps
=
[]
for
level
in
range
(
self
.
_fpn_min_level
,
self
.
_base_fpn_max_level
+
1
):
feature_maps
.
append
(
fpn_features
[
'top_down_block{}'
.
format
(
level
-
1
)])
last_feature_map
=
fpn_features
[
'top_down_block{}'
.
format
(
self
.
_base_fpn_max_level
-
1
)]
for
coarse_feature_layers
in
self
.
coarse_feature_layers
:
for
layer
in
coarse_feature_layers
:
last_feature_map
=
layer
(
last_feature_map
)
feature_maps
.
append
(
last_feature_map
)
return
feature_maps
class
FasterRCNNResnetV1FpnKerasFeatureExtractor
(
faster_rcnn_meta_arch
.
FasterRCNNKerasFeatureExtractor
):
...
...
@@ -155,9 +199,7 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
self
.
classification_backbone
=
tf
.
keras
.
Model
(
inputs
=
full_resnet_v1_model
.
inputs
,
outputs
=
outputs
)
backbone_outputs
=
self
.
classification_backbone
(
full_resnet_v1_model
.
inputs
)
# construct FPN feature generator
self
.
_base_fpn_max_level
=
min
(
self
.
_fpn_max_level
,
5
)
self
.
_num_levels
=
self
.
_base_fpn_max_level
+
1
-
self
.
_fpn_min_level
self
.
_fpn_features_generator
=
(
...
...
@@ -169,16 +211,6 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
freeze_batchnorm
=
self
.
_freeze_batchnorm
,
name
=
'FeatureMaps'
))
feature_block_list
=
[]
for
level
in
range
(
self
.
_fpn_min_level
,
self
.
_base_fpn_max_level
+
1
):
feature_block_list
.
append
(
'block{}'
.
format
(
level
-
1
))
feature_block_map
=
dict
(
list
(
zip
(
self
.
_resnet_block_names
,
backbone_outputs
)))
fpn_input_image_features
=
[
(
feature_block
,
feature_block_map
[
feature_block
])
for
feature_block
in
feature_block_list
]
fpn_features
=
self
.
_fpn_features_generator
(
fpn_input_image_features
)
# Construct coarse feature layers
for
i
in
range
(
self
.
_base_fpn_max_level
,
self
.
_fpn_max_level
):
layers
=
[]
...
...
@@ -200,19 +232,12 @@ class FasterRCNNResnetV1FpnKerasFeatureExtractor(
name
=
layer_name
))
self
.
_coarse_feature_layers
.
append
(
layers
)
feature_maps
=
[]
for
level
in
range
(
self
.
_fpn_min_level
,
self
.
_base_fpn_max_level
+
1
):
feature_maps
.
append
(
fpn_features
[
'top_down_block{}'
.
format
(
level
-
1
)])
last_feature_map
=
fpn_features
[
'top_down_block{}'
.
format
(
self
.
_base_fpn_max_level
-
1
)]
for
coarse_feature_layers
in
self
.
_coarse_feature_layers
:
for
layer
in
coarse_feature_layers
:
last_feature_map
=
layer
(
last_feature_map
)
feature_maps
.
append
(
last_feature_map
)
feature_extractor_model
=
tf
.
keras
.
models
.
Model
(
inputs
=
full_resnet_v1_model
.
inputs
,
outputs
=
feature_maps
)
feature_extractor_model
=
ResnetFPN
(
self
.
classification_backbone
,
self
.
_fpn_features_generator
,
self
.
_coarse_feature_layers
,
self
.
_fpn_min_level
,
self
.
_resnet_block_names
,
self
.
_base_fpn_max_level
)
return
feature_extractor_model
def
get_box_classifier_feature_extractor_model
(
self
,
name
=
None
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment