Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
78f0e355
Commit
78f0e355
authored
Dec 04, 2020
by
Vighnesh Birodkar
Committed by
TF Object Detection Team
Dec 04, 2020
Browse files
Add checkpoint type 'full' to SSD meta arch.
PiperOrigin-RevId: 345620862
parent
fd6b24c1
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
59 additions
and
9 deletions
+59
-9
research/object_detection/meta_architectures/ssd_meta_arch.py
...arch/object_detection/meta_architectures/ssd_meta_arch.py
+15
-4
research/object_detection/meta_architectures/ssd_meta_arch_test.py
...object_detection/meta_architectures/ssd_meta_arch_test.py
+33
-1
research/object_detection/protos/train.proto
research/object_detection/protos/train.proto
+7
-4
research/object_detection/utils/test_utils.py
research/object_detection/utils/test_utils.py
+4
-0
No files found.
research/object_detection/meta_architectures/ssd_meta_arch.py
View file @
78f0e355
...
...
@@ -1308,10 +1308,17 @@ class SSDMetaArch(model.DetectionModel):
to be used to restore Slim-based models when running Tensorflow 1.x.
Args:
fine_tune_checkpoint_type: whether to restore from a full detection
checkpoint (with compatible variable names) or to restore from a
classification checkpoint for initialization prior to training.
Valid values: `detection`, `classification`. Default 'detection'.
fine_tune_checkpoint_type: A string inidicating the subset of variables
to load. Valid values: `detection`, `classification`, `full`. Default
`detection`.
An SSD checkpoint has three parts:
1) Classification Network (like ResNet)
2) DeConv layers (for FPN)
3) Box/Class prediction parameters
The parameters will be loaded using the following strategy:
`classification` - will load #1
`detection` - will load #1, #2
`full` - will load #1, #2, #3
Returns:
A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
...
...
@@ -1325,6 +1332,10 @@ class SSDMetaArch(model.DetectionModel):
fake_model
=
tf
.
train
.
Checkpoint
(
_feature_extractor
=
self
.
_feature_extractor
)
return
{
'model'
:
fake_model
}
elif
fine_tune_checkpoint_type
==
'full'
:
return
{
'model'
:
self
}
else
:
raise
ValueError
(
'Not supported fine_tune_checkpoint_type: {}'
.
format
(
fine_tune_checkpoint_type
))
...
...
research/object_detection/meta_architectures/ssd_meta_arch_test.py
View file @
78f0e355
...
...
@@ -615,7 +615,6 @@ class SsdMetaArchTest(ssd_meta_arch_test_lib.SSDMetaArchTestBase,
self
.
assertNotIn
(
six
.
ensure_binary
(
'FeatureExtractor'
),
var
)
def
test_load_all_det_checkpoint_vars
(
self
):
# TODO(rathodv): Support TF2.X
if
self
.
is_tf2
():
return
test_graph_detection
=
tf
.
Graph
()
with
test_graph_detection
.
as_default
():
...
...
@@ -634,6 +633,39 @@ class SsdMetaArchTest(ssd_meta_arch_test_lib.SSDMetaArchTestBase,
self
.
assertIsInstance
(
var_map
,
dict
)
self
.
assertIn
(
'another_variable'
,
var_map
)
def
test_load_checkpoint_vars_tf2
(
self
):
if
not
self
.
is_tf2
():
self
.
skipTest
(
'Not running TF2 checkpoint test with TF1.'
)
model
,
_
,
_
,
_
=
self
.
_create_model
()
inputs_shape
=
[
2
,
2
,
2
,
3
]
inputs
=
tf
.
cast
(
tf
.
random_uniform
(
inputs_shape
,
minval
=
0
,
maxval
=
255
,
dtype
=
tf
.
int32
),
dtype
=
tf
.
float32
)
model
(
inputs
)
detection_var_names
=
sorted
([
var
.
name
for
var
in
model
.
restore_from_objects
(
'detection'
)[
'model'
].
_feature_extractor
.
weights
])
expected_detection_names
=
[
'ssd_meta_arch/fake_ssd_keras_feature_extractor/mock_model/layer1/bias:0'
,
'ssd_meta_arch/fake_ssd_keras_feature_extractor/mock_model/layer1/kernel:0'
]
self
.
assertEqual
(
detection_var_names
,
expected_detection_names
)
full_var_names
=
sorted
([
var
.
name
for
var
in
model
.
restore_from_objects
(
'full'
)[
'model'
].
weights
])
exepcted_full_names
=
[
'box_predictor_var:0'
]
+
expected_detection_names
self
.
assertEqual
(
exepcted_full_names
,
full_var_names
)
# TODO(vighneshb) Add similar test for classification checkpoint type.
# TODO(vighneshb) Test loading a checkpoint from disk to verify that
# checkpoints are loaded correctly.
def
test_loss_results_are_correct_with_random_example_sampling
(
self
):
with
test_utils
.
GraphContextOrNone
()
as
g
:
model
,
num_classes
,
_
,
_
=
self
.
_create_model
(
...
...
research/object_detection/protos/train.proto
View file @
78f0e355
...
...
@@ -40,9 +40,12 @@ message TrainConfig {
// extractor variables trained outside of object detection.
optional
string
fine_tune_checkpoint
=
7
[
default
=
""
];
// Type of checkpoint to restore variables from, e.g. 'classification' or
// 'detection'. Provides extensibility to from_detection_checkpoint.
// Typically used to load feature extractor variables from trained models.
// Type of checkpoint to restore variables from, e.g. 'classification'
// 'detection', `fine_tune`, `full`. Controls which variables are restored
// from the pre-trained checkpoint. For meta architecture specific valid
// values of this parameter, see the restore_map (TF1) or
// restore_from_object (TF2) function documentation in the
// /meta_architectures/*meta_arch.py files
optional
string
fine_tune_checkpoint_type
=
22
[
default
=
""
];
// Either "v1" or "v2". If v1, restores the checkpoint using the tensorflow
...
...
@@ -60,7 +63,7 @@ message TrainConfig {
// Whether to load all checkpoint vars that match model variable names and
// sizes. This option is only available if `from_detection_checkpoint` is
// True. This option is *not* supported for TF2 --- setting it to true
// will raise an error.
// will raise an error.
Instead, set fine_tune_checkpoint_type: 'full'.
optional
bool
load_all_detection_checkpoint_vars
=
19
[
default
=
false
];
// Number of steps to train the DetectionModel for. If 0, will train the model
...
...
research/object_detection/utils/test_utils.py
View file @
78f0e355
...
...
@@ -101,6 +101,10 @@ class MockKerasBoxPredictor(box_predictor.KerasBoxPredictor):
is_training
,
num_classes
,
False
,
False
)
self
.
_add_background_class
=
add_background_class
# Dummy variable so that box predictor registers some variables.
self
.
_dummy_var
=
tf
.
Variable
(
0.0
,
trainable
=
True
,
name
=
'box_predictor_var'
)
def
_predict
(
self
,
image_features
,
**
kwargs
):
image_feature
=
image_features
[
0
]
combined_feature_shape
=
shape_utils
.
combined_static_and_dynamic_shape
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment