Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
07484704
Commit
07484704
authored
Jul 24, 2020
by
Vighnesh Birodkar
Committed by
TF Object Detection Team
Jul 24, 2020
Browse files
Better error message when loading a wrong checkpoint type in CenterNet.
PiperOrigin-RevId: 322967458
parent
bdaa525b
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
60 additions
and
26 deletions
+60
-26
research/object_detection/meta_architectures/center_net_meta_arch.py
...ject_detection/meta_architectures/center_net_meta_arch.py
+15
-14
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
...ction/meta_architectures/center_net_meta_arch_tf2_test.py
+11
-0
research/object_detection/models/center_net_hourglass_feature_extractor.py
...etection/models/center_net_hourglass_feature_extractor.py
+8
-2
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py
...ction/models/center_net_mobilenet_v2_feature_extractor.py
+8
-2
research/object_detection/models/center_net_resnet_feature_extractor.py
...t_detection/models/center_net_resnet_feature_extractor.py
+9
-4
research/object_detection/models/center_net_resnet_v1_fpn_feature_extractor.py
...tion/models/center_net_resnet_v1_fpn_feature_extractor.py
+9
-4
No files found.
research/object_detection/meta_architectures/center_net_meta_arch.py
View file @
07484704
...
@@ -118,6 +118,19 @@ class CenterNetFeatureExtractor(tf.keras.Model):
...
@@ -118,6 +118,19 @@ class CenterNetFeatureExtractor(tf.keras.Model):
"""Ther number of feature outputs returned by the feature extractor."""
"""Ther number of feature outputs returned by the feature extractor."""
pass
pass
@
abc
.
abstractmethod
def
get_sub_model
(
self
,
sub_model_type
):
"""Returns the underlying keras model for the given sub_model_type.
This function is useful when we only want to get a subset of weights to
be restored from a checkpoint.
Args:
sub_model_type: string, the type of sub model. Currently, CenterNet
feature extractors support 'detection' and 'classification'.
"""
pass
def
make_prediction_net
(
num_out_channels
,
kernel_size
=
3
,
num_filters
=
256
,
def
make_prediction_net
(
num_out_channels
,
kernel_size
=
3
,
num_filters
=
256
,
bias_fill
=
None
):
bias_fill
=
None
):
...
@@ -2762,20 +2775,8 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -2762,20 +2775,8 @@ class CenterNetMetaArch(model.DetectionModel):
A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
"""
"""
if
fine_tune_checkpoint_type
==
'classification'
:
sub_model
=
self
.
_feature_extractor
.
get_sub_model
(
fine_tune_checkpoint_type
)
return
{
'feature_extractor'
:
self
.
_feature_extractor
.
get_base_model
()}
return
{
'feature_extractor'
:
sub_model
}
elif
fine_tune_checkpoint_type
==
'detection'
:
return
{
'feature_extractor'
:
self
.
_feature_extractor
.
get_model
()}
elif
fine_tune_checkpoint_type
==
'fine_tune'
:
feature_extractor_model
=
tf
.
train
.
Checkpoint
(
_feature_extractor
=
self
.
_feature_extractor
)
return
{
'model'
:
feature_extractor_model
}
else
:
raise
ValueError
(
'Not supported fine tune checkpoint type - {}'
.
format
(
fine_tune_checkpoint_type
))
def
updates
(
self
):
def
updates
(
self
):
raise
RuntimeError
(
'This model is intended to be used with model_lib_v2 '
raise
RuntimeError
(
'This model is intended to be used with model_lib_v2 '
...
...
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
View file @
07484704
...
@@ -17,7 +17,9 @@
...
@@ -17,7 +17,9 @@
from
__future__
import
division
from
__future__
import
division
import
functools
import
functools
import
re
import
unittest
import
unittest
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
import
numpy
as
np
import
numpy
as
np
import
tensorflow.compat.v1
as
tf
import
tensorflow.compat.v1
as
tf
...
@@ -1788,6 +1790,15 @@ class CenterNetMetaArchRestoreTest(test_case.TestCase):
...
@@ -1788,6 +1790,15 @@ class CenterNetMetaArchRestoreTest(test_case.TestCase):
self
.
assertIsInstance
(
restore_from_objects_map
[
'feature_extractor'
],
self
.
assertIsInstance
(
restore_from_objects_map
[
'feature_extractor'
],
tf
.
keras
.
Model
)
tf
.
keras
.
Model
)
def
test_retore_map_error
(
self
):
"""Test that restoring unsupported checkpoint type raises an error."""
model
=
build_center_net_meta_arch
(
build_resnet
=
True
)
msg
=
(
"Sub model detection is not defined for ResNet."
"Supported types are ['classification']."
)
with
self
.
assertRaisesRegex
(
ValueError
,
re
.
escape
(
msg
)):
model
.
restore_from_objects
(
'detection'
)
class
DummyFeatureExtractor
(
cnma
.
CenterNetFeatureExtractor
):
class
DummyFeatureExtractor
(
cnma
.
CenterNetFeatureExtractor
):
...
...
research/object_detection/models/center_net_hourglass_feature_extractor.py
View file @
07484704
...
@@ -62,8 +62,14 @@ class CenterNetHourglassFeatureExtractor(
...
@@ -62,8 +62,14 @@ class CenterNetHourglassFeatureExtractor(
"""Ther number of feature outputs returned by the feature extractor."""
"""Ther number of feature outputs returned by the feature extractor."""
return
self
.
_network
.
num_hourglasses
return
self
.
_network
.
num_hourglasses
def
get_model
(
self
):
def
get_sub_model
(
self
,
sub_model_type
):
return
self
.
_network
if
sub_model_type
==
'detection'
:
return
self
.
_network
else
:
supported_types
=
[
'detection'
]
raise
ValueError
(
(
'Sub model {} is not defined for Hourglass.'
.
format
(
sub_model_type
)
+
'Supported types are {}.'
.
format
(
supported_types
)))
def
hourglass_104
(
channel_means
,
channel_stds
,
bgr_ordering
):
def
hourglass_104
(
channel_means
,
channel_stds
,
bgr_ordering
):
...
...
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py
View file @
07484704
...
@@ -101,8 +101,14 @@ class CenterNetMobileNetV2FeatureExtractor(
...
@@ -101,8 +101,14 @@ class CenterNetMobileNetV2FeatureExtractor(
"""The number of feature outputs returned by the feature extractor."""
"""The number of feature outputs returned by the feature extractor."""
return
1
return
1
def
get_model
(
self
):
def
get_sub_model
(
self
,
sub_model_type
):
return
self
.
_network
if
sub_model_type
==
'detection'
:
return
self
.
_network
else
:
supported_types
=
[
'detection'
]
raise
ValueError
(
(
'Sub model {} is not defined for MobileNet.'
.
format
(
sub_model_type
)
+
'Supported types are {}.'
.
format
(
supported_types
)))
def
mobilenet_v2
(
channel_means
,
channel_stds
,
bgr_ordering
):
def
mobilenet_v2
(
channel_means
,
channel_stds
,
bgr_ordering
):
...
...
research/object_detection/models/center_net_resnet_feature_extractor.py
View file @
07484704
...
@@ -101,10 +101,6 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
...
@@ -101,10 +101,6 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
def
load_feature_extractor_weights
(
self
,
path
):
def
load_feature_extractor_weights
(
self
,
path
):
self
.
_base_model
.
load_weights
(
path
)
self
.
_base_model
.
load_weights
(
path
)
def
get_base_model
(
self
):
"""Get base resnet model for inspection and testing."""
return
self
.
_base_model
def
call
(
self
,
inputs
):
def
call
(
self
,
inputs
):
"""Returns image features extracted by the backbone.
"""Returns image features extracted by the backbone.
...
@@ -127,6 +123,15 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
...
@@ -127,6 +123,15 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
def
out_stride
(
self
):
def
out_stride
(
self
):
return
4
return
4
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'classification'
:
return
self
.
_base_model
else
:
supported_types
=
[
'classification'
]
raise
ValueError
(
(
'Sub model {} is not defined for ResNet.'
.
format
(
sub_model_type
)
+
'Supported types are {}.'
.
format
(
supported_types
)))
def
resnet_v2_101
(
channel_means
,
channel_stds
,
bgr_ordering
):
def
resnet_v2_101
(
channel_means
,
channel_stds
,
bgr_ordering
):
"""The ResNet v2 101 feature extractor."""
"""The ResNet v2 101 feature extractor."""
...
...
research/object_detection/models/center_net_resnet_v1_fpn_feature_extractor.py
View file @
07484704
...
@@ -137,10 +137,6 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
...
@@ -137,10 +137,6 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
def
load_feature_extractor_weights
(
self
,
path
):
def
load_feature_extractor_weights
(
self
,
path
):
self
.
_base_model
.
load_weights
(
path
)
self
.
_base_model
.
load_weights
(
path
)
def
get_base_model
(
self
):
"""Get base resnet model for inspection and testing."""
return
self
.
_base_model
def
call
(
self
,
inputs
):
def
call
(
self
,
inputs
):
"""Returns image features extracted by the backbone.
"""Returns image features extracted by the backbone.
...
@@ -163,6 +159,15 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
...
@@ -163,6 +159,15 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
def
out_stride
(
self
):
def
out_stride
(
self
):
return
4
return
4
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'classification'
:
return
self
.
_base_model
else
:
supported_types
=
[
'classification'
]
raise
ValueError
(
(
'Sub model {} is not defined for ResNet FPN.'
.
format
(
sub_model_type
)
+
'Supported types are {}.'
.
format
(
supported_types
)))
def
resnet_v1_101_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
):
def
resnet_v1_101_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
):
"""The ResNet v1 101 FPN feature extractor."""
"""The ResNet v1 101 FPN feature extractor."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment