Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
07484704
Commit
07484704
authored
Jul 24, 2020
by
Vighnesh Birodkar
Committed by
TF Object Detection Team
Jul 24, 2020
Browse files
Better error message when loading a wrong checkpoint type in CenterNet.
PiperOrigin-RevId: 322967458
parent
bdaa525b
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
60 additions
and
26 deletions
+60
-26
research/object_detection/meta_architectures/center_net_meta_arch.py
...ject_detection/meta_architectures/center_net_meta_arch.py
+15
-14
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
...ction/meta_architectures/center_net_meta_arch_tf2_test.py
+11
-0
research/object_detection/models/center_net_hourglass_feature_extractor.py
...etection/models/center_net_hourglass_feature_extractor.py
+8
-2
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py
...ction/models/center_net_mobilenet_v2_feature_extractor.py
+8
-2
research/object_detection/models/center_net_resnet_feature_extractor.py
...t_detection/models/center_net_resnet_feature_extractor.py
+9
-4
research/object_detection/models/center_net_resnet_v1_fpn_feature_extractor.py
...tion/models/center_net_resnet_v1_fpn_feature_extractor.py
+9
-4
No files found.
research/object_detection/meta_architectures/center_net_meta_arch.py
View file @
07484704
...
...
@@ -118,6 +118,19 @@ class CenterNetFeatureExtractor(tf.keras.Model):
"""Ther number of feature outputs returned by the feature extractor."""
pass
@
abc
.
abstractmethod
def
get_sub_model
(
self
,
sub_model_type
):
"""Returns the underlying keras model for the given sub_model_type.
This function is useful when we only want to get a subset of weights to
be restored from a checkpoint.
Args:
sub_model_type: string, the type of sub model. Currently, CenterNet
feature extractors support 'detection' and 'classification'.
"""
pass
def
make_prediction_net
(
num_out_channels
,
kernel_size
=
3
,
num_filters
=
256
,
bias_fill
=
None
):
...
...
@@ -2762,20 +2775,8 @@ class CenterNetMetaArch(model.DetectionModel):
A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
"""
if
fine_tune_checkpoint_type
==
'classification'
:
return
{
'feature_extractor'
:
self
.
_feature_extractor
.
get_base_model
()}
elif
fine_tune_checkpoint_type
==
'detection'
:
return
{
'feature_extractor'
:
self
.
_feature_extractor
.
get_model
()}
elif
fine_tune_checkpoint_type
==
'fine_tune'
:
feature_extractor_model
=
tf
.
train
.
Checkpoint
(
_feature_extractor
=
self
.
_feature_extractor
)
return
{
'model'
:
feature_extractor_model
}
else
:
raise
ValueError
(
'Not supported fine tune checkpoint type - {}'
.
format
(
fine_tune_checkpoint_type
))
sub_model
=
self
.
_feature_extractor
.
get_sub_model
(
fine_tune_checkpoint_type
)
return
{
'feature_extractor'
:
sub_model
}
def
updates
(
self
):
raise
RuntimeError
(
'This model is intended to be used with model_lib_v2 '
...
...
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
View file @
07484704
...
...
@@ -17,7 +17,9 @@
from
__future__
import
division
import
functools
import
re
import
unittest
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow.compat.v1
as
tf
...
...
@@ -1788,6 +1790,15 @@ class CenterNetMetaArchRestoreTest(test_case.TestCase):
self
.
assertIsInstance
(
restore_from_objects_map
[
'feature_extractor'
],
tf
.
keras
.
Model
)
def
test_retore_map_error
(
self
):
"""Test that restoring unsupported checkpoint type raises an error."""
model
=
build_center_net_meta_arch
(
build_resnet
=
True
)
msg
=
(
"Sub model detection is not defined for ResNet."
"Supported types are ['classification']."
)
with
self
.
assertRaisesRegex
(
ValueError
,
re
.
escape
(
msg
)):
model
.
restore_from_objects
(
'detection'
)
class
DummyFeatureExtractor
(
cnma
.
CenterNetFeatureExtractor
):
...
...
research/object_detection/models/center_net_hourglass_feature_extractor.py
View file @
07484704
...
...
@@ -62,8 +62,14 @@ class CenterNetHourglassFeatureExtractor(
"""Ther number of feature outputs returned by the feature extractor."""
return
self
.
_network
.
num_hourglasses
def
get_model
(
self
):
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'detection'
:
return
self
.
_network
else
:
supported_types
=
[
'detection'
]
raise
ValueError
(
(
'Sub model {} is not defined for Hourglass.'
.
format
(
sub_model_type
)
+
'Supported types are {}.'
.
format
(
supported_types
)))
def
hourglass_104
(
channel_means
,
channel_stds
,
bgr_ordering
):
...
...
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py
View file @
07484704
...
...
@@ -101,8 +101,14 @@ class CenterNetMobileNetV2FeatureExtractor(
"""The number of feature outputs returned by the feature extractor."""
return
1
def
get_model
(
self
):
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'detection'
:
return
self
.
_network
else
:
supported_types
=
[
'detection'
]
raise
ValueError
(
(
'Sub model {} is not defined for MobileNet.'
.
format
(
sub_model_type
)
+
'Supported types are {}.'
.
format
(
supported_types
)))
def
mobilenet_v2
(
channel_means
,
channel_stds
,
bgr_ordering
):
...
...
research/object_detection/models/center_net_resnet_feature_extractor.py
View file @
07484704
...
...
@@ -101,10 +101,6 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
def
load_feature_extractor_weights
(
self
,
path
):
self
.
_base_model
.
load_weights
(
path
)
def
get_base_model
(
self
):
"""Get base resnet model for inspection and testing."""
return
self
.
_base_model
def
call
(
self
,
inputs
):
"""Returns image features extracted by the backbone.
...
...
@@ -127,6 +123,15 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
def
out_stride
(
self
):
return
4
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'classification'
:
return
self
.
_base_model
else
:
supported_types
=
[
'classification'
]
raise
ValueError
(
(
'Sub model {} is not defined for ResNet.'
.
format
(
sub_model_type
)
+
'Supported types are {}.'
.
format
(
supported_types
)))
def
resnet_v2_101
(
channel_means
,
channel_stds
,
bgr_ordering
):
"""The ResNet v2 101 feature extractor."""
...
...
research/object_detection/models/center_net_resnet_v1_fpn_feature_extractor.py
View file @
07484704
...
...
@@ -137,10 +137,6 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
def
load_feature_extractor_weights
(
self
,
path
):
self
.
_base_model
.
load_weights
(
path
)
def
get_base_model
(
self
):
"""Get base resnet model for inspection and testing."""
return
self
.
_base_model
def
call
(
self
,
inputs
):
"""Returns image features extracted by the backbone.
...
...
@@ -163,6 +159,15 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
def
out_stride
(
self
):
return
4
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'classification'
:
return
self
.
_base_model
else
:
supported_types
=
[
'classification'
]
raise
ValueError
(
(
'Sub model {} is not defined for ResNet FPN.'
.
format
(
sub_model_type
)
+
'Supported types are {}.'
.
format
(
supported_types
)))
def
resnet_v1_101_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
):
"""The ResNet v1 101 FPN feature extractor."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment