Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
fd6987fa
Commit
fd6987fa
authored
Aug 13, 2020
by
Vighnesh Birodkar
Committed by
TF Object Detection Team
Aug 13, 2020
Browse files
Support fine_tune in all CenterNet feature extractors.
PiperOrigin-RevId: 326528933
parent
f41f14e6
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
63 additions
and
31 deletions
+63
-31
research/object_detection/meta_architectures/center_net_meta_arch.py
...ject_detection/meta_architectures/center_net_meta_arch.py
+40
-9
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
...ction/meta_architectures/center_net_meta_arch_tf2_test.py
+3
-2
research/object_detection/models/center_net_hourglass_feature_extractor.py
...etection/models/center_net_hourglass_feature_extractor.py
+5
-4
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py
...ction/models/center_net_mobilenet_v2_feature_extractor.py
+5
-4
research/object_detection/models/center_net_resnet_feature_extractor.py
...t_detection/models/center_net_resnet_feature_extractor.py
+5
-6
research/object_detection/models/center_net_resnet_v1_fpn_feature_extractor.py
...tion/models/center_net_resnet_v1_fpn_feature_extractor.py
+5
-6
No files found.
research/object_detection/meta_architectures/center_net_meta_arch.py
View file @
fd6987fa
...
@@ -118,6 +118,12 @@ class CenterNetFeatureExtractor(tf.keras.Model):
...
@@ -118,6 +118,12 @@ class CenterNetFeatureExtractor(tf.keras.Model):
"""Ther number of feature outputs returned by the feature extractor."""
"""Ther number of feature outputs returned by the feature extractor."""
pass
pass
@
property
@
abc
.
abstractmethod
def
supported_sub_model_types
(
self
):
"""Valid sub model types supported by the get_sub_model function."""
pass
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
get_sub_model
(
self
,
sub_model_type
):
def
get_sub_model
(
self
,
sub_model_type
):
"""Returns the underlying keras model for the given sub_model_type.
"""Returns the underlying keras model for the given sub_model_type.
...
@@ -2974,22 +2980,47 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -2974,22 +2980,47 @@ class CenterNetMetaArch(model.DetectionModel):
fine_tune_checkpoint_type: whether to restore from a full detection
fine_tune_checkpoint_type: whether to restore from a full detection
checkpoint (with compatible variable names) or to restore from a
checkpoint (with compatible variable names) or to restore from a
classification checkpoint for initialization prior to training.
classification checkpoint for initialization prior to training.
Valid values: `detection`, `classification`. Default 'detection'.
Valid values: `detection`, `classification`, `fine_tune`.
'detection': used when loading in the Hourglass model pre-trained on
Default 'detection'.
other detection task.
'detection': used when loading models pre-trained on other detection
'classification': used when loading in the ResNet model pre-trained on
tasks. With this checkpoint type the weights of the feature extractor
image classification task. Note that only the image feature encoding
are expected under the attribute 'feature_extractor'.
part is loaded but not those upsampling layers.
'classification': used when loading models pre-trained on an image
classification task. Note that only the encoder section of the network
is loaded and not the upsampling layers. With this checkpoint type,
the weights of only the encoder section are expected under the
attribute 'feature_extractor'.
'fine_tune': used when loading the entire CenterNet feature extractor
'fine_tune': used when loading the entire CenterNet feature extractor
pre-trained on other tasks. The checkpoints saved during CenterNet
pre-trained on other tasks. The checkpoints saved during CenterNet
model training can be directly loaded using this mode.
model training can be directly loaded using this type. With this
checkpoint type, the weights of the feature extractor are expected
under the attribute 'model._feature_extractor'.
For more details, see the tensorflow section on Loading mechanics.
https://www.tensorflow.org/guide/checkpoint#loading_mechanics
Returns:
Returns:
A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
"""
"""
sub_model
=
self
.
_feature_extractor
.
get_sub_model
(
fine_tune_checkpoint_type
)
supported_types
=
self
.
_feature_extractor
.
supported_sub_model_types
return
{
'feature_extractor'
:
sub_model
}
supported_types
+=
[
'fine_tune'
]
if
fine_tune_checkpoint_type
not
in
supported_types
:
message
=
(
'Checkpoint type "{}" not supported for {}. '
'Supported types are {}'
)
raise
ValueError
(
message
.
format
(
fine_tune_checkpoint_type
,
self
.
_feature_extractor
.
__class__
.
__name__
,
supported_types
))
elif
fine_tune_checkpoint_type
==
'fine_tune'
:
feature_extractor_model
=
tf
.
train
.
Checkpoint
(
_feature_extractor
=
self
.
_feature_extractor
)
return
{
'model'
:
feature_extractor_model
}
else
:
return
{
'feature_extractor'
:
self
.
_feature_extractor
.
get_sub_model
(
fine_tune_checkpoint_type
)}
def
updates
(
self
):
def
updates
(
self
):
raise
RuntimeError
(
'This model is intended to be used with model_lib_v2 '
raise
RuntimeError
(
'This model is intended to be used with model_lib_v2 '
...
...
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
View file @
fd6987fa
...
@@ -1917,8 +1917,9 @@ class CenterNetMetaArchRestoreTest(test_case.TestCase):
...
@@ -1917,8 +1917,9 @@ class CenterNetMetaArchRestoreTest(test_case.TestCase):
"""Test that restoring unsupported checkpoint type raises an error."""
"""Test that restoring unsupported checkpoint type raises an error."""
model
=
build_center_net_meta_arch
(
build_resnet
=
True
)
model
=
build_center_net_meta_arch
(
build_resnet
=
True
)
msg
=
(
"Sub model detection is not defined for ResNet."
msg
=
(
"Checkpoint type
\"
detection
\"
not supported for "
"Supported types are ['classification']."
)
"CenterNetResnetFeatureExtractor. Supported types are "
"['classification', 'fine_tune']"
)
with
self
.
assertRaisesRegex
(
ValueError
,
re
.
escape
(
msg
)):
with
self
.
assertRaisesRegex
(
ValueError
,
re
.
escape
(
msg
)):
model
.
restore_from_objects
(
'detection'
)
model
.
restore_from_objects
(
'detection'
)
...
...
research/object_detection/models/center_net_hourglass_feature_extractor.py
View file @
fd6987fa
...
@@ -62,14 +62,15 @@ class CenterNetHourglassFeatureExtractor(
...
@@ -62,14 +62,15 @@ class CenterNetHourglassFeatureExtractor(
"""Ther number of feature outputs returned by the feature extractor."""
"""Ther number of feature outputs returned by the feature extractor."""
return
self
.
_network
.
num_hourglasses
return
self
.
_network
.
num_hourglasses
@
property
def
supported_sub_model_types
(
self
):
return
[
'detection'
]
def
get_sub_model
(
self
,
sub_model_type
):
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'detection'
:
if
sub_model_type
==
'detection'
:
return
self
.
_network
return
self
.
_network
else
:
else
:
supported_types
=
[
'detection'
]
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
raise
ValueError
(
(
'Sub model {} is not defined for Hourglass.'
.
format
(
sub_model_type
)
+
'Supported types are {}.'
.
format
(
supported_types
)))
def
hourglass_104
(
channel_means
,
channel_stds
,
bgr_ordering
):
def
hourglass_104
(
channel_means
,
channel_stds
,
bgr_ordering
):
...
...
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py
View file @
fd6987fa
...
@@ -101,14 +101,15 @@ class CenterNetMobileNetV2FeatureExtractor(
...
@@ -101,14 +101,15 @@ class CenterNetMobileNetV2FeatureExtractor(
"""The number of feature outputs returned by the feature extractor."""
"""The number of feature outputs returned by the feature extractor."""
return
1
return
1
@
property
def
supported_sub_model_types
(
self
):
return
[
'detection'
]
def
get_sub_model
(
self
,
sub_model_type
):
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'detection'
:
if
sub_model_type
==
'detection'
:
return
self
.
_network
return
self
.
_network
else
:
else
:
supported_types
=
[
'detection'
]
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
raise
ValueError
(
(
'Sub model {} is not defined for MobileNet.'
.
format
(
sub_model_type
)
+
'Supported types are {}.'
.
format
(
supported_types
)))
def
mobilenet_v2
(
channel_means
,
channel_stds
,
bgr_ordering
):
def
mobilenet_v2
(
channel_means
,
channel_stds
,
bgr_ordering
):
...
...
research/object_detection/models/center_net_resnet_feature_extractor.py
View file @
fd6987fa
...
@@ -123,16 +123,15 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
...
@@ -123,16 +123,15 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
def
out_stride
(
self
):
def
out_stride
(
self
):
return
4
return
4
@
property
def
supported_sub_model_types
(
self
):
return
[
'classification'
]
def
get_sub_model
(
self
,
sub_model_type
):
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'classification'
:
if
sub_model_type
==
'classification'
:
return
self
.
_base_model
return
self
.
_base_model
else
:
else
:
supported_types
=
[
'classification'
]
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
raise
ValueError
(
(
'Sub model {} is not defined for ResNet.'
.
format
(
sub_model_type
)
+
'Supported types are {}.'
.
format
(
supported_types
)
+
'Use the script convert_keras_models.py to create your own '
+
'classification checkpoints.'
))
def
resnet_v2_101
(
channel_means
,
channel_stds
,
bgr_ordering
):
def
resnet_v2_101
(
channel_means
,
channel_stds
,
bgr_ordering
):
...
...
research/object_detection/models/center_net_resnet_v1_fpn_feature_extractor.py
View file @
fd6987fa
...
@@ -159,16 +159,15 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
...
@@ -159,16 +159,15 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
def
out_stride
(
self
):
def
out_stride
(
self
):
return
4
return
4
@
property
def
supported_sub_model_types
(
self
):
return
[
'classification'
]
def
get_sub_model
(
self
,
sub_model_type
):
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'classification'
:
if
sub_model_type
==
'classification'
:
return
self
.
_base_model
return
self
.
_base_model
else
:
else
:
supported_types
=
[
'classification'
]
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
raise
ValueError
(
(
'Sub model {} is not defined for ResNet FPN.'
.
format
(
sub_model_type
)
+
'Supported types are {}.'
.
format
(
supported_types
))
+
'Use the script convert_keras_models.py to create your own '
+
'classification checkpoints.'
)
def
resnet_v1_101_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
):
def
resnet_v1_101_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment