Commit 78f0e355 authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Add checkpoint type 'full' to SSD meta arch.

PiperOrigin-RevId: 345620862
parent fd6b24c1
...@@ -1308,10 +1308,17 @@ class SSDMetaArch(model.DetectionModel): ...@@ -1308,10 +1308,17 @@ class SSDMetaArch(model.DetectionModel):
to be used to restore Slim-based models when running Tensorflow 1.x. to be used to restore Slim-based models when running Tensorflow 1.x.
Args: Args:
fine_tune_checkpoint_type: whether to restore from a full detection fine_tune_checkpoint_type: A string indicating the subset of variables
checkpoint (with compatible variable names) or to restore from a to load. Valid values: `detection`, `classification`, `full`. Default
classification checkpoint for initialization prior to training. `detection`.
Valid values: `detection`, `classification`. Default 'detection'. An SSD checkpoint has three parts:
1) Classification Network (like ResNet)
2) DeConv layers (for FPN)
3) Box/Class prediction parameters
The parameters will be loaded using the following strategy:
`classification` - will load #1
`detection` - will load #1, #2
`full` - will load #1, #2, #3
Returns: Returns:
A dict mapping keys to Trackable objects (tf.Module or Checkpoint). A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
...@@ -1325,6 +1332,10 @@ class SSDMetaArch(model.DetectionModel): ...@@ -1325,6 +1332,10 @@ class SSDMetaArch(model.DetectionModel):
fake_model = tf.train.Checkpoint( fake_model = tf.train.Checkpoint(
_feature_extractor=self._feature_extractor) _feature_extractor=self._feature_extractor)
return {'model': fake_model} return {'model': fake_model}
elif fine_tune_checkpoint_type == 'full':
return {'model': self}
else: else:
raise ValueError('Not supported fine_tune_checkpoint_type: {}'.format( raise ValueError('Not supported fine_tune_checkpoint_type: {}'.format(
fine_tune_checkpoint_type)) fine_tune_checkpoint_type))
......
...@@ -615,7 +615,6 @@ class SsdMetaArchTest(ssd_meta_arch_test_lib.SSDMetaArchTestBase, ...@@ -615,7 +615,6 @@ class SsdMetaArchTest(ssd_meta_arch_test_lib.SSDMetaArchTestBase,
self.assertNotIn(six.ensure_binary('FeatureExtractor'), var) self.assertNotIn(six.ensure_binary('FeatureExtractor'), var)
def test_load_all_det_checkpoint_vars(self): def test_load_all_det_checkpoint_vars(self):
# TODO(rathodv): Support TF2.X
if self.is_tf2(): return if self.is_tf2(): return
test_graph_detection = tf.Graph() test_graph_detection = tf.Graph()
with test_graph_detection.as_default(): with test_graph_detection.as_default():
...@@ -634,6 +633,39 @@ class SsdMetaArchTest(ssd_meta_arch_test_lib.SSDMetaArchTestBase, ...@@ -634,6 +633,39 @@ class SsdMetaArchTest(ssd_meta_arch_test_lib.SSDMetaArchTestBase,
self.assertIsInstance(var_map, dict) self.assertIsInstance(var_map, dict)
self.assertIn('another_variable', var_map) self.assertIn('another_variable', var_map)
def test_load_checkpoint_vars_tf2(self):
if not self.is_tf2():
self.skipTest('Not running TF2 checkpoint test with TF1.')
model, _, _, _ = self._create_model()
inputs_shape = [2, 2, 2, 3]
inputs = tf.cast(
tf.random_uniform(inputs_shape, minval=0, maxval=255, dtype=tf.int32),
dtype=tf.float32)
model(inputs)
detection_var_names = sorted([
var.name for var in model.restore_from_objects('detection')[
'model']._feature_extractor.weights
])
expected_detection_names = [
'ssd_meta_arch/fake_ssd_keras_feature_extractor/mock_model/layer1/bias:0',
'ssd_meta_arch/fake_ssd_keras_feature_extractor/mock_model/layer1/kernel:0'
]
self.assertEqual(detection_var_names, expected_detection_names)
full_var_names = sorted([
var.name for var in
model.restore_from_objects('full')['model'].weights
])
exepcted_full_names = ['box_predictor_var:0'] + expected_detection_names
self.assertEqual(exepcted_full_names, full_var_names)
# TODO(vighneshb) Add similar test for classification checkpoint type.
# TODO(vighneshb) Test loading a checkpoint from disk to verify that
# checkpoints are loaded correctly.
def test_loss_results_are_correct_with_random_example_sampling(self): def test_loss_results_are_correct_with_random_example_sampling(self):
with test_utils.GraphContextOrNone() as g: with test_utils.GraphContextOrNone() as g:
model, num_classes, _, _ = self._create_model( model, num_classes, _, _ = self._create_model(
......
...@@ -40,9 +40,12 @@ message TrainConfig { ...@@ -40,9 +40,12 @@ message TrainConfig {
// extractor variables trained outside of object detection. // extractor variables trained outside of object detection.
optional string fine_tune_checkpoint = 7 [default=""]; optional string fine_tune_checkpoint = 7 [default=""];
// Type of checkpoint to restore variables from, e.g. 'classification' or // Type of checkpoint to restore variables from, e.g. 'classification'
// 'detection'. Provides extensibility to from_detection_checkpoint. // 'detection', `fine_tune`, `full`. Controls which variables are restored
// Typically used to load feature extractor variables from trained models. // from the pre-trained checkpoint. For meta architecture specific valid
// values of this parameter, see the restore_map (TF1) or
// restore_from_objects (TF2) function documentation in the
// /meta_architectures/*meta_arch.py files
optional string fine_tune_checkpoint_type = 22 [default=""]; optional string fine_tune_checkpoint_type = 22 [default=""];
// Either "v1" or "v2". If v1, restores the checkpoint using the tensorflow // Either "v1" or "v2". If v1, restores the checkpoint using the tensorflow
...@@ -60,7 +63,7 @@ message TrainConfig { ...@@ -60,7 +63,7 @@ message TrainConfig {
// Whether to load all checkpoint vars that match model variable names and // Whether to load all checkpoint vars that match model variable names and
// sizes. This option is only available if `from_detection_checkpoint` is // sizes. This option is only available if `from_detection_checkpoint` is
// True. This option is *not* supported for TF2 --- setting it to true // True. This option is *not* supported for TF2 --- setting it to true
// will raise an error. // will raise an error. Instead, set fine_tune_checkpoint_type: 'full'.
optional bool load_all_detection_checkpoint_vars = 19 [default = false]; optional bool load_all_detection_checkpoint_vars = 19 [default = false];
// Number of steps to train the DetectionModel for. If 0, will train the model // Number of steps to train the DetectionModel for. If 0, will train the model
......
...@@ -101,6 +101,10 @@ class MockKerasBoxPredictor(box_predictor.KerasBoxPredictor): ...@@ -101,6 +101,10 @@ class MockKerasBoxPredictor(box_predictor.KerasBoxPredictor):
is_training, num_classes, False, False) is_training, num_classes, False, False)
self._add_background_class = add_background_class self._add_background_class = add_background_class
# Dummy variable so that box predictor registers some variables.
self._dummy_var = tf.Variable(0.0, trainable=True,
name='box_predictor_var')
def _predict(self, image_features, **kwargs): def _predict(self, image_features, **kwargs):
image_feature = image_features[0] image_feature = image_features[0]
combined_feature_shape = shape_utils.combined_static_and_dynamic_shape( combined_feature_shape = shape_utils.combined_static_and_dynamic_shape(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment