Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
aa3e639f
Commit
aa3e639f
authored
May 06, 2021
by
Vighnesh Birodkar
Committed by
TF Object Detection Team
May 06, 2021
Browse files
Standardize fine tune checkpoints across all TF2 models.
PiperOrigin-RevId: 372426423
parent
f006521b
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
109 additions
and
81 deletions
+109
-81
research/object_detection/meta_architectures/center_net_meta_arch.py
...ject_detection/meta_architectures/center_net_meta_arch.py
+20
-31
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
...ction/meta_architectures/center_net_meta_arch_tf2_test.py
+6
-8
research/object_detection/meta_architectures/faster_rcnn_meta_arch.py
...ect_detection/meta_architectures/faster_rcnn_meta_arch.py
+2
-0
research/object_detection/model_lib_v2.py
research/object_detection/model_lib_v2.py
+4
-0
research/object_detection/models/center_net_hourglass_feature_extractor.py
...etection/models/center_net_hourglass_feature_extractor.py
+0
-10
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py
...ction/models/center_net_mobilenet_v2_feature_extractor.py
+2
-11
research/object_detection/models/center_net_resnet_feature_extractor.py
...t_detection/models/center_net_resnet_feature_extractor.py
+2
-8
research/object_detection/models/center_net_resnet_v1_fpn_feature_extractor.py
...tion/models/center_net_resnet_v1_fpn_feature_extractor.py
+2
-8
research/object_detection/protos/train.proto
research/object_detection/protos/train.proto
+19
-5
research/object_detection/utils/variables_helper.py
research/object_detection/utils/variables_helper.py
+52
-0
No files found.
research/object_detection/meta_architectures/center_net_meta_arch.py
View file @
aa3e639f
...
...
@@ -117,23 +117,9 @@ class CenterNetFeatureExtractor(tf.keras.Model):
pass
@
property
@
abc
.
abstractmethod
def
supported_sub_model_types
(
self
):
"""Valid sub model types supported by the get_sub_model function."""
pass
@
abc
.
abstractmethod
def
get_sub_model
(
self
,
sub_model_type
):
"""Returns the underlying keras model for the given sub_model_type.
This function is useful when we only want to get a subset of weights to
be restored from a checkpoint.
Args:
sub_model_type: string, the type of sub model. Currently, CenterNet
feature extractors support 'detection' and 'classification'.
"""
pass
def
classification_backbone
(
self
):
raise
NotImplementedError
(
'Classification backbone not supported for {}'
.
format
(
type
(
self
)))
def
make_prediction_net
(
num_out_channels
,
kernel_sizes
=
(
3
),
num_filters
=
(
256
),
...
...
@@ -4200,25 +4186,28 @@ class CenterNetMetaArch(model.DetectionModel):
A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
"""
supported_types
=
self
.
_feature_extractor
.
supported_sub_model_types
supported_types
+=
[
'fine_tune'
]
if
fine_tune_checkpoint_type
not
in
supported_types
:
message
=
(
'Checkpoint type "{}" not supported for {}. '
'Supported types are {}'
)
raise
ValueError
(
message
.
format
(
fine_tune_checkpoint_type
,
self
.
_feature_extractor
.
__class__
.
__name__
,
supported_types
))
elif
fine_tune_checkpoint_type
==
'fine_tune'
:
if
fine_tune_checkpoint_type
==
'detection'
:
feature_extractor_model
=
tf
.
train
.
Checkpoint
(
_feature_extractor
=
self
.
_feature_extractor
)
return
{
'model'
:
feature_extractor_model
}
elif
fine_tune_checkpoint_type
==
'classification'
:
return
{
'feature_extractor'
:
self
.
_feature_extractor
.
classification_backbone
}
elif
fine_tune_checkpoint_type
==
'full'
:
return
{
'model'
:
self
}
elif
fine_tune_checkpoint_type
==
'fine_tune'
:
raise
ValueError
((
'"fine_tune" is no longer supported for CenterNet. '
'Please set fine_tune_checkpoint_type to "detection"'
' which has the same functionality. If you are using'
' the ExtremeNet checkpoint, download the new version'
' from the model zoo.'
))
else
:
r
eturn
{
'feature_extractor'
:
self
.
_feature_extractor
.
get_sub_model
(
fine_tune_checkpoint_type
)
}
r
aise
ValueError
(
'Unknown fine tune checkpoint type {}'
.
format
(
fine_tune_checkpoint_type
)
)
def
updates
(
self
):
if
tf_version
.
is_tf2
():
...
...
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
View file @
aa3e639f
...
...
@@ -17,7 +17,6 @@
from
__future__
import
division
import
functools
import
re
import
unittest
from
absl.testing
import
parameterized
...
...
@@ -2887,15 +2886,14 @@ class CenterNetMetaArchRestoreTest(test_case.TestCase):
self
.
assertIsInstance
(
restore_from_objects_map
[
'feature_extractor'
],
tf
.
keras
.
Model
)
def
test_retore_map_
error
(
self
):
"""Test that
restoring unsupported checkpoint type raise
s an
error
."""
def
test_retore_map_
detection
(
self
):
"""Test that
detection checkpoint
s
c
an
be restored
."""
model
=
build_center_net_meta_arch
(
build_resnet
=
True
)
msg
=
(
"Checkpoint type
\"
detection
\"
not supported for "
"CenterNetResnetFeatureExtractor. Supported types are "
"['classification', 'fine_tune']"
)
with
self
.
assertRaisesRegex
(
ValueError
,
re
.
escape
(
msg
)):
model
.
restore_from_objects
(
'detection'
)
restore_from_objects_map
=
model
.
restore_from_objects
(
'detection'
)
self
.
assertIsInstance
(
restore_from_objects_map
[
'model'
].
_feature_extractor
,
tf
.
keras
.
Model
)
class
DummyFeatureExtractor
(
cnma
.
CenterNetFeatureExtractor
):
...
...
research/object_detection/meta_architectures/faster_rcnn_meta_arch.py
View file @
aa3e639f
...
...
@@ -2896,6 +2896,8 @@ class FasterRCNNMetaArch(model.DetectionModel):
_feature_extractor_for_proposal_features
=
self
.
_feature_extractor_for_proposal_features
)
return
{
'model'
:
fake_model
}
elif
fine_tune_checkpoint_type
==
'full'
:
return
{
'model'
:
self
}
else
:
raise
ValueError
(
'Not supported fine_tune_checkpoint_type: {}'
.
format
(
fine_tune_checkpoint_type
))
...
...
research/object_detection/model_lib_v2.py
View file @
aa3e639f
...
...
@@ -35,6 +35,7 @@ from object_detection.protos import train_pb2
from
object_detection.utils
import
config_util
from
object_detection.utils
import
label_map_util
from
object_detection.utils
import
ops
from
object_detection.utils
import
variables_helper
from
object_detection.utils
import
visualization_utils
as
vutils
...
...
@@ -587,6 +588,9 @@ def train_loop(
lambda
:
global_step
%
num_steps_per_iteration
==
0
):
# Load a fine-tuning checkpoint.
if
train_config
.
fine_tune_checkpoint
:
variables_helper
.
ensure_checkpoint_supported
(
train_config
.
fine_tune_checkpoint
,
fine_tune_checkpoint_type
,
model_dir
)
load_fine_tune_checkpoint
(
detection_model
,
train_config
.
fine_tune_checkpoint
,
fine_tune_checkpoint_type
,
fine_tune_checkpoint_version
,
...
...
research/object_detection/models/center_net_hourglass_feature_extractor.py
View file @
aa3e639f
...
...
@@ -62,16 +62,6 @@ class CenterNetHourglassFeatureExtractor(
"""Ther number of feature outputs returned by the feature extractor."""
return
self
.
_network
.
num_hourglasses
@
property
def
supported_sub_model_types
(
self
):
return
[
'detection'
]
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'detection'
:
return
self
.
_network
else
:
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
def
hourglass_10
(
channel_means
,
channel_stds
,
bgr_ordering
,
**
kwargs
):
"""The Hourglass-10 backbone for CenterNet."""
...
...
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py
View file @
aa3e639f
...
...
@@ -83,9 +83,6 @@ class CenterNetMobileNetV2FeatureExtractor(
def
load_feature_extractor_weights
(
self
,
path
):
self
.
_network
.
load_weights
(
path
)
def
get_base_model
(
self
):
return
self
.
_network
def
call
(
self
,
inputs
):
return
[
self
.
_network
(
inputs
)]
...
...
@@ -100,14 +97,8 @@ class CenterNetMobileNetV2FeatureExtractor(
return
1
@
property
def
supported_sub_model_types
(
self
):
return
[
'detection'
]
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'detection'
:
return
self
.
_network
else
:
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
def
classification_backbone
(
self
):
return
self
.
_network
def
mobilenet_v2
(
channel_means
,
channel_stds
,
bgr_ordering
,
...
...
research/object_detection/models/center_net_resnet_feature_extractor.py
View file @
aa3e639f
...
...
@@ -126,14 +126,8 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
return
4
@
property
def
supported_sub_model_types
(
self
):
return
[
'classification'
]
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'classification'
:
return
self
.
_base_model
else
:
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
def
classification_backbone
(
self
):
return
self
.
_base_model
def
resnet_v2_101
(
channel_means
,
channel_stds
,
bgr_ordering
,
**
kwargs
):
...
...
research/object_detection/models/center_net_resnet_v1_fpn_feature_extractor.py
View file @
aa3e639f
...
...
@@ -162,14 +162,8 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
return
4
@
property
def
supported_sub_model_types
(
self
):
return
[
'classification'
]
def
get_sub_model
(
self
,
sub_model_type
):
if
sub_model_type
==
'classification'
:
return
self
.
_base_model
else
:
ValueError
(
'Sub model type "{}" not supported.'
.
format
(
sub_model_type
))
def
classification_backbone
(
self
):
return
self
.
_base_model
def
resnet_v1_101_fpn
(
channel_means
,
channel_stds
,
bgr_ordering
,
**
kwargs
):
...
...
research/object_detection/protos/train.proto
View file @
aa3e639f
...
...
@@ -40,12 +40,26 @@ message TrainConfig {
// extractor variables trained outside of object detection.
optional
string
fine_tune_checkpoint
=
7
[
default
=
""
];
// Type of checkpoint to restore variables from, e.g. 'classification'
// 'detection', `fine_tune`, `full`. Controls which variables are restored
// from the pre-trained checkpoint. For meta architecture specific valid
// values of this parameter, see the restore_map (TF1) or
// This option controls how variables are restored from the (pre-trained)
// fine_tune_checkpoint. For TF2 models, 3 different types are supported:
// 1. "classification": Restores only the classification backbone part of
// the feature extractor. This option is typically used when you want
// to train a detection model starting from a pre-trained image
// classification model, e.g. a ResNet model pre-trained on ImageNet.
// 2. "detection": Restores the entire feature extractor. The only parts
// of the full detection model that are not restored are the box and
// class prediction heads. This option is typically used when you want
// to use a pre-trained detection model and train on a new dataset or
// task which requires different box and class prediction heads.
// 3. "full": Restores the entire detection model, including the
// feature extractor, its classification backbone, and the prediction
// heads. This option should only be used when the pre-training and
// fine-tuning tasks are the same. Otherwise, the model's parameters
// may have incompatible shapes, which will cause errors when
// attempting to restore the checkpoint.
// For more details about this parameter, see the restore_map (TF1) or
// restore_from_object (TF2) function documentation in the
// /meta_architectures/*meta_arch.py files
// /meta_architectures/*meta_arch.py files
.
optional
string
fine_tune_checkpoint_type
=
22
[
default
=
""
];
// Either "v1" or "v2". If v1, restores the checkpoint using the tensorflow
...
...
research/object_detection/utils/variables_helper.py
View file @
aa3e639f
...
...
@@ -21,6 +21,7 @@ from __future__ import division
from
__future__
import
print_function
import
logging
import
os
import
re
import
tensorflow.compat.v1
as
tf
...
...
@@ -29,6 +30,19 @@ import tf_slim as slim
from
tensorflow.python.ops
import
variables
as
tf_variables
# Maps checkpoint types to variable name prefixes that are no longer
# supported
DETECTION_FEATURE_EXTRACTOR_MSG
=
"""
\
The checkpoint type 'detection' is not supported when it contains variable
names with 'feature_extractor'. Please download the new checkpoint file
from model zoo.
"""
DEPRECATED_CHECKPOINT_MAP
=
{
'detection'
:
(
'feature_extractor'
,
DETECTION_FEATURE_EXTRACTOR_MSG
)
}
# TODO(derekjchow): Consider replacing with tf.contrib.filter_variables in
# tensorflow/contrib/framework/python/ops/variables.py
def
filter_variables
(
variables
,
filter_regex_list
,
invert
=
False
):
...
...
@@ -176,3 +190,41 @@ def get_global_variables_safely():
"executing eagerly. Use a Keras model's `.variables` "
"attribute instead."
)
return
tf
.
global_variables
()
def
ensure_checkpoint_supported
(
checkpoint_path
,
checkpoint_type
,
model_dir
):
"""Ensures that the given checkpoint can be properly loaded.
Performs the following checks
1. Raises an error if checkpoint_path and model_dir are same.
2. Checks that checkpoint_path does not contain a deprecated checkpoint file
by inspecting its variables.
Args:
checkpoint_path: str, path to checkpoint.
checkpoint_type: str, denotes the type of checkpoint.
model_dir: The model directory to store intermediate training checkpoints.
Raises:
RuntimeError: If
1. We detect an deprecated checkpoint file.
2. model_dir and checkpoint_path are in the same directory.
"""
variables
=
tf
.
train
.
list_variables
(
checkpoint_path
)
if
checkpoint_type
in
DEPRECATED_CHECKPOINT_MAP
:
blocked_prefix
,
msg
=
DEPRECATED_CHECKPOINT_MAP
[
checkpoint_type
]
for
var_name
,
_
in
variables
:
if
var_name
.
startswith
(
blocked_prefix
):
tf
.
logging
.
error
(
'Found variable name - %s with prefix %s'
,
var_name
,
blocked_prefix
)
raise
RuntimeError
(
msg
)
checkpoint_path_dir
=
os
.
path
.
abspath
(
os
.
path
.
dirname
(
checkpoint_path
))
model_dir
=
os
.
path
.
abspath
(
model_dir
)
if
model_dir
==
checkpoint_path_dir
:
raise
RuntimeError
(
(
'Checkpoint dir ({}) and model_dir ({}) cannot be same.'
.
format
(
checkpoint_path_dir
,
model_dir
)
+
(
' Please set model_dir to a different path.'
)))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment