Commit fa45b626 authored by Derek Chow
Browse files

Change model.restore_fn to return a variable map instead of init_fn.

parent a57a00f6
......@@ -228,25 +228,24 @@ class DetectionModel(object):
fields.BoxListFields.keypoints] = groundtruth_keypoints_list
@abstractmethod
def restore_fn(self, checkpoint_path, from_detection_checkpoint=True):
"""Return callable for loading a foreign checkpoint into tensorflow graph.
def restore_map(self, from_detection_checkpoint=True):
"""Returns a map of variables to load from a foreign checkpoint.
Loads variables from a different tensorflow graph (typically feature
extractor variables). This enables the model to initialize based on weights
from another task. For example, the feature extractor variables from a
Returns a map of variable names to load from a checkpoint to variables in
the model graph. This enables the model to initialize based on weights from
another task. For example, the feature extractor variables from a
classification model can be used to bootstrap training of an object
detector. When loading from an object detection model, the checkpoint model
should have the same parameters as this detection model with exception of
the num_classes parameter.
Args:
checkpoint_path: path to checkpoint to restore.
from_detection_checkpoint: whether to restore from a full detection
checkpoint (with compatible variable names) or to restore from a
classification checkpoint for initialization prior to training.
Returns:
a callable which takes a tf.Session as input and loads a checkpoint when
run.
A dict mapping variable names (to load from a checkpoint) to variables in
the model graph.
"""
pass
......@@ -54,7 +54,7 @@ class FakeModel(model.DetectionModel):
np.arange(32).reshape([2, 4, 4]), tf.float32)
return postprocessed_tensors
def restore_fn(self, checkpoint_path, from_detection_checkpoint):
def restore_map(self, checkpoint_path=None, from_detection_checkpoint=True):
  """Fake stand-in for DetectionModel.restore_map; restores nothing.

  The abstract interface this fake implements is
  `restore_map(self, from_detection_checkpoint=True)`. Both parameters are
  given defaults so the fake can be called through that interface, while
  callers that still pass `checkpoint_path` positionally keep working.

  Args:
    checkpoint_path: ignored. Retained only for backward compatibility with
      callers written against the previous signature.
    from_detection_checkpoint: ignored by this fake.

  Returns:
    None. This fake defines no variables to restore.
  """
  pass
def loss(self, prediction_dict):
......
......@@ -56,7 +56,6 @@ py_library(
"//tensorflow_models/object_detection/core:standard_fields",
"//tensorflow_models/object_detection/core:target_assigner",
"//tensorflow_models/object_detection/utils:ops",
"//tensorflow_models/object_detection/utils:variables_helper",
],
)
......
......@@ -80,7 +80,6 @@ from object_detection.core import post_processing
from object_detection.core import standard_fields as fields
from object_detection.core import target_assigner
from object_detection.utils import ops
from object_detection.utils import variables_helper
slim = tf.contrib.slim
......@@ -159,21 +158,19 @@ class FasterRCNNFeatureExtractor(object):
def restore_from_classification_checkpoint_fn(
self,
checkpoint_path,
first_stage_feature_extractor_scope,
second_stage_feature_extractor_scope):
"""Returns callable for loading a checkpoint into the tensorflow graph.
"""Returns a map of variables to load from a foreign checkpoint.
Args:
checkpoint_path: path to checkpoint to restore.
first_stage_feature_extractor_scope: A scope name for the first stage
feature extractor.
second_stage_feature_extractor_scope: A scope name for the second stage
feature extractor.
Returns:
a callable which takes a tf.Session as input and loads a checkpoint when
run.
A dict mapping variable names (to load from a checkpoint) to variables in
the model graph.
"""
variables_to_restore = {}
for variable in tf.global_variables():
......@@ -182,13 +179,7 @@ class FasterRCNNFeatureExtractor(object):
if variable.op.name.startswith(scope_name):
var_name = variable.op.name.replace(scope_name + '/', '')
variables_to_restore[var_name] = variable
variables_to_restore = (
variables_helper.get_variables_available_in_checkpoint(
variables_to_restore, checkpoint_path))
saver = tf.train.Saver(variables_to_restore)
def restore(sess):
saver.restore(sess, checkpoint_path)
return restore
return variables_to_restore
class FasterRCNNMetaArch(model.DetectionModel):
......@@ -1413,25 +1404,22 @@ class FasterRCNNMetaArch(model.DetectionModel):
cls_losses=tf.expand_dims(single_image_cls_loss, 0),
decoded_boxlist_list=[proposal_boxlist])
def restore_fn(self, checkpoint_path, from_detection_checkpoint=True):
"""Returns callable for loading a checkpoint into the tensorflow graph.
def restore_map(self, from_detection_checkpoint=True):
"""Returns a map of variables to load from a foreign checkpoint.
See parent class for details.
Args:
checkpoint_path: path to checkpoint to restore.
from_detection_checkpoint: whether to restore from a detection checkpoint
(with compatible variable names) or to restore from a classification
checkpoint for initialization prior to training. Note that when
from_detection_checkpoint=True, the current implementation only
supports restoration from an (exactly) identical model (with exception
of the num_classes parameter).
from_detection_checkpoint: whether to restore from a full detection
checkpoint (with compatible variable names) or to restore from a
classification checkpoint for initialization prior to training.
Returns:
a callable which takes a tf.Session as input and loads a checkpoint when
run.
A dict mapping variable names (to load from a checkpoint) to variables in
the model graph.
"""
if not from_detection_checkpoint:
return self._feature_extractor.restore_from_classification_checkpoint_fn(
checkpoint_path,
self.first_stage_feature_extractor_scope,
self.second_stage_feature_extractor_scope)
......@@ -1439,13 +1427,9 @@ class FasterRCNNMetaArch(model.DetectionModel):
variables_to_restore.append(slim.get_or_create_global_step())
# Only load feature extractor variables to be consistent with loading from
# a classification checkpoint.
first_stage_variables = tf.contrib.framework.filter_variables(
feature_extractor_variables = tf.contrib.framework.filter_variables(
variables_to_restore,
include_patterns=[self.first_stage_feature_extractor_scope,
self.second_stage_feature_extractor_scope])
return {var.op.name: var for var in feature_extractor_variables}
saver = tf.train.Saver(first_stage_variables)
def restore(sess):
saver.restore(sess, checkpoint_path)
return restore
......@@ -957,7 +957,7 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
exp_loc_loss)
self.assertAllClose(loss_dict_out['second_stage_classification_loss'], 0)
def test_restore_fn_classification(self):
def test_restore_map_for_classification_ckpt(self):
# Define mock tensorflow classification graph and save variables.
test_graph_classification = tf.Graph()
with test_graph_classification.as_default():
......@@ -986,12 +986,17 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
preprocessed_inputs = model.preprocess(inputs)
prediction_dict = model.predict(preprocessed_inputs)
model.postprocess(prediction_dict)
restore_fn = model.restore_fn(saved_model_path,
from_detection_checkpoint=False)
var_map = model.restore_map(from_detection_checkpoint=False)
self.assertIsInstance(var_map, dict)
saver = tf.train.Saver(var_map)
with self.test_session() as sess:
restore_fn(sess)
saver.restore(sess, saved_model_path)
for var in sess.run(tf.report_uninitialized_variables()):
self.assertNotIn(model.first_stage_feature_extractor_scope, var.name)
self.assertNotIn(model.second_stage_feature_extractor_scope,
var.name)
def test_restore_fn_detection(self):
def test_restore_map_for_detection_ckpt(self):
# Define first detection graph and save variables.
test_graph_detection1 = tf.Graph()
with test_graph_detection1.as_default():
......@@ -1022,10 +1027,11 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
preprocessed_inputs2 = model2.preprocess(inputs2)
prediction_dict2 = model2.predict(preprocessed_inputs2)
model2.postprocess(prediction_dict2)
restore_fn = model2.restore_fn(saved_model_path,
from_detection_checkpoint=True)
var_map = model2.restore_map(from_detection_checkpoint=True)
self.assertIsInstance(var_map, dict)
saver = tf.train.Saver(var_map)
with self.test_session() as sess:
restore_fn(sess)
saver.restore(sess, saved_model_path)
for var in sess.run(tf.report_uninitialized_variables()):
self.assertNotIn(model2.first_stage_feature_extractor_scope, var.name)
self.assertNotIn(model2.second_stage_feature_extractor_scope,
......
......@@ -29,7 +29,6 @@ from object_detection.core import box_predictor as bpredictor
from object_detection.core import model
from object_detection.core import standard_fields as fields
from object_detection.core import target_assigner
from object_detection.utils import variables_helper
slim = tf.contrib.slim
......@@ -562,33 +561,26 @@ class SSDMetaArch(model.DetectionModel):
decoded_boxlist_list=decoded_boxlist_list,
match_list=match_list)
def restore_fn(self, checkpoint_path, from_detection_checkpoint=True):
"""Return callable for loading a checkpoint into the tensorflow graph.
def restore_map(self, from_detection_checkpoint=True):
"""Returns a map of variables to load from a foreign checkpoint.
See parent class for details.
Args:
checkpoint_path: path to checkpoint to restore.
from_detection_checkpoint: whether to restore from a full detection
checkpoint (with compatible variable names) or to restore from a
classification checkpoint for initialization prior to training.
Returns:
a callable which takes a tf.Session as input and loads a checkpoint when
run.
A dict mapping variable names (to load from a checkpoint) to variables in
the model graph.
"""
variables_to_restore = {}
for variable in tf.all_variables():
if variable.op.name.startswith(self._extract_features_scope):
var_name = variable.op.name
if not from_detection_checkpoint:
var_name = (
re.split('^' + self._extract_features_scope + '/', var_name)[-1])
var_name = (re.split('^' + self._extract_features_scope + '/',
var_name)[-1])
variables_to_restore[var_name] = variable
# TODO: Load variables selectively using scopes.
variables_to_restore = (
variables_helper.get_variables_available_in_checkpoint(
variables_to_restore, checkpoint_path))
saver = tf.train.Saver(variables_to_restore)
def restore(sess):
saver.restore(sess, checkpoint_path)
return restore
return variables_to_restore
......@@ -207,20 +207,21 @@ class SsdMetaArchTest(tf.test.TestCase):
self.assertAllClose(losses_out['classification_loss'],
expected_classification_loss)
def test_restore_fn_detection(self):
def test_restore_map_for_detection_ckpt(self):
init_op = tf.global_variables_initializer()
saver = tf_saver.Saver()
save_path = self.get_temp_dir()
with self.test_session() as sess:
sess.run(init_op)
saved_model_path = saver.save(sess, save_path)
restore_fn = self._model.restore_fn(saved_model_path,
from_detection_checkpoint=True)
restore_fn(sess)
var_map = self._model.restore_map(from_detection_checkpoint=True)
self.assertIsInstance(var_map, dict)
saver = tf.train.Saver(var_map)
saver.restore(sess, saved_model_path)
for var in sess.run(tf.report_uninitialized_variables()):
self.assertNotIn('FeatureExtractor', var.name)
def test_restore_fn_classification(self):
def test_restore_map_for_classification_ckpt(self):
# Define mock tensorflow classification graph and save variables.
test_graph_classification = tf.Graph()
with test_graph_classification.as_default():
......@@ -246,10 +247,11 @@ class SsdMetaArchTest(tf.test.TestCase):
preprocessed_inputs = self._model.preprocess(inputs)
prediction_dict = self._model.predict(preprocessed_inputs)
self._model.postprocess(prediction_dict)
restore_fn = self._model.restore_fn(saved_model_path,
from_detection_checkpoint=False)
var_map = self._model.restore_map(from_detection_checkpoint=False)
self.assertIsInstance(var_map, dict)
saver = tf.train.Saver(var_map)
with self.test_session() as sess:
restore_fn(sess)
saver.restore(sess, saved_model_path)
for var in sess.run(tf.report_uninitialized_variables()):
self.assertNotIn('FeatureExtractor', var.name)
......
......@@ -94,7 +94,6 @@ py_library(
deps = [
"//tensorflow",
"//tensorflow_models/object_detection/meta_architectures:faster_rcnn_meta_arch",
"//tensorflow_models/object_detection/utils:variables_helper",
"//tensorflow_models/slim:inception_resnet_v2",
],
)
......
......@@ -25,7 +25,6 @@ Huang et al. (https://arxiv.org/abs/1611.10012)
import tensorflow as tf
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.utils import variables_helper
from nets import inception_resnet_v2
slim = tf.contrib.slim
......@@ -168,30 +167,30 @@ class FasterRCNNInceptionResnetV2FeatureExtractor(
def restore_from_classification_checkpoint_fn(
self,
checkpoint_path,
first_stage_feature_extractor_scope,
second_stage_feature_extractor_scope):
"""Returns callable for loading a checkpoint into the tensorflow graph.
"""Returns a map of variables to load from a foreign checkpoint.
Note that this overrides the default implementation in
faster_rcnn_meta_arch.FasterRCNNFeatureExtractor which does not work for
InceptionResnetV2 checkpoints.
TODO: revisit whether it's possible to force the `Repeat` namescope as
created in `_extract_box_classifier_features` to start counting at 2 (e.g.
`Repeat_2`) so that the default restore_fn can be used.
TODO(jonathanhuang,rathodv): revisit whether it's possible to force the
`Repeat` namescope as created in `_extract_box_classifier_features` to
start counting at 2 (e.g. `Repeat_2`) so that the default restore_fn can
be used.
Args:
checkpoint_path: Path to checkpoint to restore.
first_stage_feature_extractor_scope: A scope name for the first stage
feature extractor.
second_stage_feature_extractor_scope: A scope name for the second stage
feature extractor.
Returns:
a callable which takes a tf.Session as input and loads a checkpoint when
run.
A dict mapping variable names (to load from a checkpoint) to variables in
the model graph.
"""
variables_to_restore = {}
for variable in tf.global_variables():
if variable.op.name.startswith(
......@@ -207,10 +206,4 @@ class FasterRCNNInceptionResnetV2FeatureExtractor(
var_name = var_name.replace(
second_stage_feature_extractor_scope + '/', '')
variables_to_restore[var_name] = variable
variables_to_restore = (
variables_helper.get_variables_available_in_checkpoint(
variables_to_restore, checkpoint_path))
saver = tf.train.Saver(variables_to_restore)
def restore(sess):
saver.restore(sess, checkpoint_path)
return restore
return variables_to_restore
......@@ -211,9 +211,14 @@ def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
# Create ops required to initialize the model from a given checkpoint.
init_fn = None
if train_config.fine_tune_checkpoint:
init_fn = detection_model.restore_fn(
train_config.fine_tune_checkpoint,
var_map = detection_model.restore_map(
from_detection_checkpoint=train_config.from_detection_checkpoint)
var_map = variables_helper.get_variables_available_in_checkpoint(
var_map, train_config.fine_tune_checkpoint)
saver = tf.train.Saver(var_map)
def initializer_fn(sess):
saver.restore(sess, train_config.fine_tune_checkpoint)
init_fn = initializer_fn
with tf.device(deploy_config.optimizer_device()):
total_loss, grads_and_vars = model_deploy.optimize_clones(
......
......@@ -139,21 +139,18 @@ class FakeDetectionModel(model.DetectionModel):
}
return loss_dict
def restore_fn(self, checkpoint_path, from_detection_checkpoint=True):
"""Return callable for loading a checkpoint into the tensorflow graph.
def restore_map(self, from_detection_checkpoint=True):
"""Returns a map of variables to load from a foreign checkpoint.
Args:
checkpoint_path: path to checkpoint to restore.
from_detection_checkpoint: whether to restore from a full detection
checkpoint (with compatible variable names) or to restore from a
classification checkpoint for initialization prior to training.
Returns:
a callable which takes a tf.Session and does nothing.
A dict mapping variable names to variables.
"""
def restore(unused_sess):
return
return restore
return {var.op.name: var for var in tf.global_variables()}
class TrainerTest(tf.test.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment