"vscode:/vscode.git/clone" did not exist on "c578113aa111c70a3126de5ac579278d0865283d"
Commit 78f0e355 authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Add checkpoint type 'full' to SSD meta arch.

PiperOrigin-RevId: 345620862
parent fd6b24c1
......@@ -1308,10 +1308,17 @@ class SSDMetaArch(model.DetectionModel):
to be used to restore Slim-based models when running Tensorflow 1.x.
Args:
fine_tune_checkpoint_type: whether to restore from a full detection
checkpoint (with compatible variable names) or to restore from a
classification checkpoint for initialization prior to training.
Valid values: `detection`, `classification`. Default 'detection'.
fine_tune_checkpoint_type: A string indicating the subset of variables
to load. Valid values: `detection`, `classification`, `full`. Default
`detection`.
An SSD checkpoint has three parts:
1) Classification Network (like ResNet)
2) DeConv layers (for FPN)
3) Box/Class prediction parameters
The parameters will be loaded using the following strategy:
`classification` - will load #1
`detection` - will load #1, #2
`full` - will load #1, #2, #3
Returns:
A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
......@@ -1325,6 +1332,10 @@ class SSDMetaArch(model.DetectionModel):
fake_model = tf.train.Checkpoint(
_feature_extractor=self._feature_extractor)
return {'model': fake_model}
elif fine_tune_checkpoint_type == 'full':
return {'model': self}
else:
raise ValueError('Not supported fine_tune_checkpoint_type: {}'.format(
fine_tune_checkpoint_type))
......
......@@ -615,7 +615,6 @@ class SsdMetaArchTest(ssd_meta_arch_test_lib.SSDMetaArchTestBase,
self.assertNotIn(six.ensure_binary('FeatureExtractor'), var)
def test_load_all_det_checkpoint_vars(self):
# TODO(rathodv): Support TF2.X
if self.is_tf2(): return
test_graph_detection = tf.Graph()
with test_graph_detection.as_default():
......@@ -634,6 +633,39 @@ class SsdMetaArchTest(ssd_meta_arch_test_lib.SSDMetaArchTestBase,
self.assertIsInstance(var_map, dict)
self.assertIn('another_variable', var_map)
def test_load_checkpoint_vars_tf2(self):
  """Verifies which variables restore_from_objects exposes under TF2.

  Builds the model with a forward pass, then checks that the
  'detection' restore target exposes only the feature-extractor
  weights while the 'full' target additionally exposes the box
  predictor's variables.
  """
  if not self.is_tf2():
    self.skipTest('Not running TF2 checkpoint test with TF1.')

  model, _, _, _ = self._create_model()
  # Run one forward pass so every layer builds its variables.
  batch = tf.cast(
      tf.random_uniform([2, 2, 2, 3], minval=0, maxval=255, dtype=tf.int32),
      dtype=tf.float32)
  model(batch)

  detection_target = model.restore_from_objects('detection')['model']
  detection_var_names = sorted(
      var.name for var in detection_target._feature_extractor.weights)
  expected_detection_names = [
      'ssd_meta_arch/fake_ssd_keras_feature_extractor/mock_model/layer1/bias:0',
      'ssd_meta_arch/fake_ssd_keras_feature_extractor/mock_model/layer1/kernel:0'
  ]
  self.assertEqual(detection_var_names, expected_detection_names)

  full_target = model.restore_from_objects('full')['model']
  full_var_names = sorted(var.name for var in full_target.weights)
  # 'full' adds the box predictor's dummy variable on top of the
  # feature-extractor weights.
  expected_full_names = ['box_predictor_var:0'] + expected_detection_names
  self.assertEqual(expected_full_names, full_var_names)

  # TODO(vighneshb) Add similar test for classification checkpoint type.
  # TODO(vighneshb) Test loading a checkpoint from disk to verify that
  # checkpoints are loaded correctly.
def test_loss_results_are_correct_with_random_example_sampling(self):
with test_utils.GraphContextOrNone() as g:
model, num_classes, _, _ = self._create_model(
......
......@@ -40,9 +40,12 @@ message TrainConfig {
// extractor variables trained outside of object detection.
optional string fine_tune_checkpoint = 7 [default=""];
// Type of checkpoint to restore variables from, e.g. 'classification' or
// 'detection'. Provides extensibility to from_detection_checkpoint.
// Typically used to load feature extractor variables from trained models.
// Type of checkpoint to restore variables from, e.g. 'classification',
// 'detection', `fine_tune`, or `full`. Controls which variables are
// restored from the pre-trained checkpoint. For the meta-architecture-specific
// valid values of this parameter, see the restore_map (TF1) or
// restore_from_objects (TF2) function documentation in the
// /meta_architectures/*meta_arch.py files.
optional string fine_tune_checkpoint_type = 22 [default=""];
// Either "v1" or "v2". If v1, restores the checkpoint using the tensorflow
......@@ -60,7 +63,7 @@ message TrainConfig {
// Whether to load all checkpoint vars that match model variable names and
// sizes. This option is only available if `from_detection_checkpoint` is
// True. This option is *not* supported for TF2 --- setting it to true
// will raise an error.
// will raise an error. Instead, set fine_tune_checkpoint_type: 'full'.
optional bool load_all_detection_checkpoint_vars = 19 [default = false];
// Number of steps to train the DetectionModel for. If 0, will train the model
......
......@@ -101,6 +101,10 @@ class MockKerasBoxPredictor(box_predictor.KerasBoxPredictor):
is_training, num_classes, False, False)
self._add_background_class = add_background_class
# Dummy variable so that box predictor registers some variables.
self._dummy_var = tf.Variable(0.0, trainable=True,
name='box_predictor_var')
def _predict(self, image_features, **kwargs):
image_feature = image_features[0]
combined_feature_shape = shape_utils.combined_static_and_dynamic_shape(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment