Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9733eeb0
Commit
9733eeb0
authored
Jan 13, 2021
by
Soroosh Yazdani
Committed by
TF Object Detection Team
Jan 13, 2021
Browse files
Updating center_net_mobilenet_v2_fpn_feature_extractor to support classification finetuning.
PiperOrigin-RevId: 351619683
parent
63ec7359
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
14 deletions
+18
-14
research/object_detection/models/center_net_mobilenet_v2_fpn_feature_extractor.py
...n/models/center_net_mobilenet_v2_fpn_feature_extractor.py
+18
-14
No files found.
research/object_detection/models/center_net_mobilenet_v2_fpn_feature_extractor.py
View file @
9733eeb0
...
@@ -58,18 +58,18 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
...
@@ -58,18 +58,18 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
channel_means
=
channel_means
,
channel_means
=
channel_means
,
channel_stds
=
channel_stds
,
channel_stds
=
channel_stds
,
bgr_ordering
=
bgr_ordering
)
bgr_ordering
=
bgr_ordering
)
self
.
_
network
=
mobilenet_v2_net
self
.
_
base_model
=
mobilenet_v2_net
output
=
self
.
_
network
(
self
.
_network
.
input
)
output
=
self
.
_
base_model
(
self
.
_base_model
.
input
)
# Add pyramid feature network on every layer that has stride 2.
# Add pyramid feature network on every layer that has stride 2.
skip_outputs
=
[
skip_outputs
=
[
self
.
_
network
.
get_layer
(
skip_layer_name
).
output
self
.
_
base_model
.
get_layer
(
skip_layer_name
).
output
for
skip_layer_name
in
_MOBILENET_V2_FPN_SKIP_LAYERS
for
skip_layer_name
in
_MOBILENET_V2_FPN_SKIP_LAYERS
]
]
self
.
_fpn_model
=
tf
.
keras
.
models
.
Model
(
self
.
_fpn_model
=
tf
.
keras
.
models
.
Model
(
inputs
=
self
.
_
network
.
input
,
outputs
=
skip_outputs
)
inputs
=
self
.
_
base_model
.
input
,
outputs
=
skip_outputs
)
fpn_outputs
=
self
.
_fpn_model
(
self
.
_
network
.
input
)
fpn_outputs
=
self
.
_fpn_model
(
self
.
_
base_model
.
input
)
# Construct the top-down feature maps -- we start with an output of
# Construct the top-down feature maps -- we start with an output of
# 7x7x1280, which we continually upsample, apply a residual on and merge.
# 7x7x1280, which we continually upsample, apply a residual on and merge.
...
@@ -108,8 +108,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
...
@@ -108,8 +108,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
output
=
top_down
output
=
top_down
self
.
_
network
=
tf
.
keras
.
models
.
Model
(
self
.
_
feature_extractor_model
=
tf
.
keras
.
models
.
Model
(
inputs
=
self
.
_
network
.
input
,
outputs
=
output
)
inputs
=
self
.
_
base_model
.
input
,
outputs
=
output
)
def
preprocess
(
self
,
resized_inputs
):
def
preprocess
(
self
,
resized_inputs
):
resized_inputs
=
super
(
CenterNetMobileNetV2FPNFeatureExtractor
,
resized_inputs
=
super
(
CenterNetMobileNetV2FPNFeatureExtractor
,
...
@@ -117,13 +117,20 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
...
@@ -117,13 +117,20 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
return
tf
.
keras
.
applications
.
mobilenet_v2
.
preprocess_input
(
resized_inputs
)
return
tf
.
keras
.
applications
.
mobilenet_v2
.
preprocess_input
(
resized_inputs
)
def
load_feature_extractor_weights
(
self
,
path
):
def
load_feature_extractor_weights
(
self
,
path
):
self
.
_
network
.
load_weights
(
path
)
self
.
_
base_model
.
load_weights
(
path
)
def
get_base_model
(
self
):
@
property
return
self
.
_network
def
supported_sub_model_types
(
self
):
return
[
'classification'
]
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'classification'
:
return
self
.
_base_model
else
:
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
def
call
(
self
,
inputs
):
def
call
(
self
,
inputs
):
return
[
self
.
_
network
(
inputs
)]
return
[
self
.
_
feature_extractor_model
(
inputs
)]
@
property
@
property
def
out_stride
(
self
):
def
out_stride
(
self
):
...
@@ -135,9 +142,6 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
...
@@ -135,9 +142,6 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
"""The number of feature outputs returned by the feature extractor."""
"""The number of feature outputs returned by the feature extractor."""
return
1
return
1
def
get_model
(
self
):
return
self
.
_network
def
mobilenet_v2_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
):
def
mobilenet_v2_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
):
"""The MobileNetV2+FPN backbone for CenterNet."""
"""The MobileNetV2+FPN backbone for CenterNet."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment