Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
fd6987fa
Commit
fd6987fa
authored
Aug 13, 2020
by
Vighnesh Birodkar
Committed by
TF Object Detection Team
Aug 13, 2020
Browse files
Support fine_tune in all CenterNet feature extractors.
PiperOrigin-RevId: 326528933
parent
f41f14e6
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
63 additions
and
31 deletions
+63
-31
research/object_detection/meta_architectures/center_net_meta_arch.py
...ject_detection/meta_architectures/center_net_meta_arch.py
+40
-9
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
...ction/meta_architectures/center_net_meta_arch_tf2_test.py
+3
-2
research/object_detection/models/center_net_hourglass_feature_extractor.py
...etection/models/center_net_hourglass_feature_extractor.py
+5
-4
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py
...ction/models/center_net_mobilenet_v2_feature_extractor.py
+5
-4
research/object_detection/models/center_net_resnet_feature_extractor.py
...t_detection/models/center_net_resnet_feature_extractor.py
+5
-6
research/object_detection/models/center_net_resnet_v1_fpn_feature_extractor.py
...tion/models/center_net_resnet_v1_fpn_feature_extractor.py
+5
-6
No files found.
research/object_detection/meta_architectures/center_net_meta_arch.py
View file @
fd6987fa
...
...
@@ -118,6 +118,12 @@ class CenterNetFeatureExtractor(tf.keras.Model):
"""Ther number of feature outputs returned by the feature extractor."""
pass
@
property
@
abc
.
abstractmethod
def
supported_sub_model_types
(
self
):
"""Valid sub model types supported by the get_sub_model function."""
pass
@
abc
.
abstractmethod
def
get_sub_model
(
self
,
sub_model_type
):
"""Returns the underlying keras model for the given sub_model_type.
...
...
@@ -2974,22 +2980,47 @@ class CenterNetMetaArch(model.DetectionModel):
fine_tune_checkpoint_type: whether to restore from a full detection
checkpoint (with compatible variable names) or to restore from a
classification checkpoint for initialization prior to training.
Valid values: `detection`, `classification`. Default 'detection'.
'detection': used when loading in the Hourglass model pre-trained on
other detection task.
'classification': used when loading in the ResNet model pre-trained on
image classification task. Note that only the image feature encoding
part is loaded but not those upsampling layers.
Valid values: `detection`, `classification`, `fine_tune`.
Default 'detection'.
'detection': used when loading models pre-trained on other detection
tasks. With this checkpoint type the weights of the feature extractor
are expected under the attribute 'feature_extractor'.
'classification': used when loading models pre-trained on an image
classification task. Note that only the encoder section of the network
is loaded and not the upsampling layers. With this checkpoint type,
the weights of only the encoder section are expected under the
attribute 'feature_extractor'.
'fine_tune': used when loading the entire CenterNet feature extractor
pre-trained on other tasks. The checkpoints saved during CenterNet
model training can be directly loaded using this mode.
model training can be directly loaded using this type. With this
checkpoint type, the weights of the feature extractor are expected
under the attribute 'model._feature_extractor'.
For more details, see the tensorflow section on Loading mechanics.
https://www.tensorflow.org/guide/checkpoint#loading_mechanics
Returns:
A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
"""
sub_model
=
self
.
_feature_extractor
.
get_sub_model
(
fine_tune_checkpoint_type
)
return
{
'feature_extractor'
:
sub_model
}
supported_types
=
self
.
_feature_extractor
.
supported_sub_model_types
supported_types
+=
[
'fine_tune'
]
if
fine_tune_checkpoint_type
not
in
supported_types
:
message
=
(
'Checkpoint type "{}" not supported for {}. '
'Supported types are {}'
)
raise
ValueError
(
message
.
format
(
fine_tune_checkpoint_type
,
self
.
_feature_extractor
.
__class__
.
__name__
,
supported_types
))
elif
fine_tune_checkpoint_type
==
'fine_tune'
:
feature_extractor_model
=
tf
.
train
.
Checkpoint
(
_feature_extractor
=
self
.
_feature_extractor
)
return
{
'model'
:
feature_extractor_model
}
else
:
return
{
'feature_extractor'
:
self
.
_feature_extractor
.
get_sub_model
(
fine_tune_checkpoint_type
)}
def
updates
(
self
):
raise
RuntimeError
(
'This model is intended to be used with model_lib_v2 '
...
...
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
View file @
fd6987fa
...
...
@@ -1917,8 +1917,9 @@ class CenterNetMetaArchRestoreTest(test_case.TestCase):
"""Test that restoring unsupported checkpoint type raises an error."""
model
=
build_center_net_meta_arch
(
build_resnet
=
True
)
msg
=
(
"Sub model detection is not defined for ResNet."
"Supported types are ['classification']."
)
msg
=
(
"Checkpoint type
\"
detection
\"
not supported for "
"CenterNetResnetFeatureExtractor. Supported types are "
"['classification', 'fine_tune']"
)
with
self
.
assertRaisesRegex
(
ValueError
,
re
.
escape
(
msg
)):
model
.
restore_from_objects
(
'detection'
)
...
...
research/object_detection/models/center_net_hourglass_feature_extractor.py
View file @
fd6987fa
...
...
@@ -62,14 +62,15 @@ class CenterNetHourglassFeatureExtractor(
"""Ther number of feature outputs returned by the feature extractor."""
return
self
.
_network
.
num_hourglasses
@
property
def
supported_sub_model_types
(
self
):
return
[
'detection'
]
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'detection'
:
return
self
.
_network
else
:
supported_types
=
[
'detection'
]
raise
ValueError
(
(
'Sub model {} is not defined for Hourglass.'
.
format
(
sub_model_type
)
+
'Supported types are {}.'
.
format
(
supported_types
)))
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
def
hourglass_104
(
channel_means
,
channel_stds
,
bgr_ordering
):
...
...
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py
View file @
fd6987fa
...
...
@@ -101,14 +101,15 @@ class CenterNetMobileNetV2FeatureExtractor(
"""The number of feature outputs returned by the feature extractor."""
return
1
@
property
def
supported_sub_model_types
(
self
):
return
[
'detection'
]
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'detection'
:
return
self
.
_network
else
:
supported_types
=
[
'detection'
]
raise
ValueError
(
(
'Sub model {} is not defined for MobileNet.'
.
format
(
sub_model_type
)
+
'Supported types are {}.'
.
format
(
supported_types
)))
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
def
mobilenet_v2
(
channel_means
,
channel_stds
,
bgr_ordering
):
...
...
research/object_detection/models/center_net_resnet_feature_extractor.py
View file @
fd6987fa
...
...
@@ -123,16 +123,15 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
def
out_stride
(
self
):
return
4
@
property
def
supported_sub_model_types
(
self
):
return
[
'classification'
]
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'classification'
:
return
self
.
_base_model
else
:
supported_types
=
[
'classification'
]
raise
ValueError
(
(
'Sub model {} is not defined for ResNet.'
.
format
(
sub_model_type
)
+
'Supported types are {}.'
.
format
(
supported_types
)
+
'Use the script convert_keras_models.py to create your own '
+
'classification checkpoints.'
))
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
def
resnet_v2_101
(
channel_means
,
channel_stds
,
bgr_ordering
):
...
...
research/object_detection/models/center_net_resnet_v1_fpn_feature_extractor.py
View file @
fd6987fa
...
...
@@ -159,16 +159,15 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
def
out_stride
(
self
):
return
4
@
property
def
supported_sub_model_types
(
self
):
return
[
'classification'
]
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'classification'
:
return
self
.
_base_model
else
:
supported_types
=
[
'classification'
]
raise
ValueError
(
(
'Sub model {} is not defined for ResNet FPN.'
.
format
(
sub_model_type
)
+
'Supported types are {}.'
.
format
(
supported_types
))
+
'Use the script convert_keras_models.py to create your own '
+
'classification checkpoints.'
)
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
def
resnet_v1_101_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment