"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "1ac38ec89c5899f44f84e44ee461c714544c6af0"
Commit fa45b626 authored by Derek Chow

Change model.restore_fn to return a variable map instead of init_fn.

parent a57a00f6
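With this change, callers build their own tf.train.Saver from the returned variable map instead of receiving a ready-made session callable. A minimal caller-side sketch of the new pattern (TF 1.x; `detection_model` and `fine_tune_checkpoint` stand in for objects the caller already has, mirroring the trainer.py hunk below):

import tensorflow as tf

# restore_map() now returns {variable name in the checkpoint: variable in the graph}.
var_map = detection_model.restore_map(from_detection_checkpoint=True)

# The caller decides how to consume the map, e.g. by building a Saver and
# restoring inside a session (this is what trainer.py and the tests do).
saver = tf.train.Saver(var_map)

def init_fn(sess):
  saver.restore(sess, fine_tune_checkpoint)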
@@ -228,25 +228,24 @@ class DetectionModel(object):
           fields.BoxListFields.keypoints] = groundtruth_keypoints_list
 
   @abstractmethod
-  def restore_fn(self, checkpoint_path, from_detection_checkpoint=True):
-    """Return callable for loading a foreign checkpoint into tensorflow graph.
-
-    Loads variables from a different tensorflow graph (typically feature
-    extractor variables). This enables the model to initialize based on weights
-    from another task. For example, the feature extractor variables from a
+  def restore_map(self, from_detection_checkpoint=True):
+    """Returns a map of variables to load from a foreign checkpoint.
+
+    Returns a map of variable names to load from a checkpoint to variables in
+    the model graph. This enables the model to initialize based on weights from
+    another task. For example, the feature extractor variables from a
     classification model can be used to bootstrap training of an object
     detector. When loading from an object detection model, the checkpoint model
     should have the same parameters as this detection model with exception of
     the num_classes parameter.
 
     Args:
-      checkpoint_path: path to checkpoint to restore.
       from_detection_checkpoint: whether to restore from a full detection
         checkpoint (with compatible variable names) or to restore from a
         classification checkpoint for initialization prior to training.
 
     Returns:
-      a callable which takes a tf.Session as input and loads a checkpoint when
-      run.
+      A dict mapping variable names (to load from a checkpoint) to variables in
+      the model graph.
     """
     pass
@@ -54,7 +54,7 @@ class FakeModel(model.DetectionModel):
         np.arange(32).reshape([2, 4, 4]), tf.float32)
     return postprocessed_tensors
 
-  def restore_fn(self, checkpoint_path, from_detection_checkpoint):
+  def restore_map(self, checkpoint_path, from_detection_checkpoint):
     pass
 
   def loss(self, prediction_dict):
......
@@ -56,7 +56,6 @@ py_library(
         "//tensorflow_models/object_detection/core:standard_fields",
         "//tensorflow_models/object_detection/core:target_assigner",
         "//tensorflow_models/object_detection/utils:ops",
-        "//tensorflow_models/object_detection/utils:variables_helper",
     ],
 )
......
@@ -80,7 +80,6 @@ from object_detection.core import post_processing
 from object_detection.core import standard_fields as fields
 from object_detection.core import target_assigner
 from object_detection.utils import ops
-from object_detection.utils import variables_helper
 
 slim = tf.contrib.slim
@@ -159,21 +158,19 @@ class FasterRCNNFeatureExtractor(object):
   def restore_from_classification_checkpoint_fn(
       self,
-      checkpoint_path,
       first_stage_feature_extractor_scope,
       second_stage_feature_extractor_scope):
-    """Returns callable for loading a checkpoint into the tensorflow graph.
+    """Returns a map of variables to load from a foreign checkpoint.
 
     Args:
-      checkpoint_path: path to checkpoint to restore.
       first_stage_feature_extractor_scope: A scope name for the first stage
         feature extractor.
       second_stage_feature_extractor_scope: A scope name for the second stage
         feature extractor.
 
     Returns:
-      a callable which takes a tf.Session as input and loads a checkpoint when
-      run.
+      A dict mapping variable names (to load from a checkpoint) to variables in
+      the model graph.
     """
     variables_to_restore = {}
     for variable in tf.global_variables():
@@ -182,13 +179,7 @@ class FasterRCNNFeatureExtractor(object):
       if variable.op.name.startswith(scope_name):
         var_name = variable.op.name.replace(scope_name + '/', '')
         variables_to_restore[var_name] = variable
-    variables_to_restore = (
-        variables_helper.get_variables_available_in_checkpoint(
-            variables_to_restore, checkpoint_path))
-    saver = tf.train.Saver(variables_to_restore)
-    def restore(sess):
-      saver.restore(sess, checkpoint_path)
-    return restore
+    return variables_to_restore
 
 
 class FasterRCNNMetaArch(model.DetectionModel):
@@ -1413,25 +1404,22 @@ class FasterRCNNMetaArch(model.DetectionModel):
         cls_losses=tf.expand_dims(single_image_cls_loss, 0),
         decoded_boxlist_list=[proposal_boxlist])
 
-  def restore_fn(self, checkpoint_path, from_detection_checkpoint=True):
-    """Returns callable for loading a checkpoint into the tensorflow graph.
+  def restore_map(self, from_detection_checkpoint=True):
+    """Returns a map of variables to load from a foreign checkpoint.
+
+    See parent class for details.
 
     Args:
-      checkpoint_path: path to checkpoint to restore.
-      from_detection_checkpoint: whether to restore from a detection checkpoint
-        (with compatible variable names) or to restore from a classification
-        checkpoint for initialization prior to training. Note that when
-        from_detection_checkpoint=True, the current implementation only
-        supports restoration from an (exactly) identical model (with exception
-        of the num_classes parameter).
+      from_detection_checkpoint: whether to restore from a full detection
+        checkpoint (with compatible variable names) or to restore from a
+        classification checkpoint for initialization prior to training.
 
     Returns:
-      a callable which takes a tf.Session as input and loads a checkpoint when
-      run.
+      A dict mapping variable names (to load from a checkpoint) to variables in
+      the model graph.
     """
     if not from_detection_checkpoint:
       return self._feature_extractor.restore_from_classification_checkpoint_fn(
-          checkpoint_path,
           self.first_stage_feature_extractor_scope,
           self.second_stage_feature_extractor_scope)
@@ -1439,13 +1427,9 @@ class FasterRCNNMetaArch(model.DetectionModel):
     variables_to_restore.append(slim.get_or_create_global_step())
     # Only load feature extractor variables to be consistent with loading from
    # a classification checkpoint.
-    first_stage_variables = tf.contrib.framework.filter_variables(
+    feature_extractor_variables = tf.contrib.framework.filter_variables(
         variables_to_restore,
         include_patterns=[self.first_stage_feature_extractor_scope,
                           self.second_stage_feature_extractor_scope])
-
-    saver = tf.train.Saver(first_stage_variables)
-    def restore(sess):
-      saver.restore(sess, checkpoint_path)
-    return restore
+    return {var.op.name: var for var in feature_extractor_variables}
@@ -957,7 +957,7 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
                         exp_loc_loss)
     self.assertAllClose(loss_dict_out['second_stage_classification_loss'], 0)
 
-  def test_restore_fn_classification(self):
+  def test_restore_map_for_classification_ckpt(self):
     # Define mock tensorflow classification graph and save variables.
     test_graph_classification = tf.Graph()
     with test_graph_classification.as_default():
@@ -986,12 +986,17 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
       preprocessed_inputs = model.preprocess(inputs)
       prediction_dict = model.predict(preprocessed_inputs)
       model.postprocess(prediction_dict)
-      restore_fn = model.restore_fn(saved_model_path,
-                                    from_detection_checkpoint=False)
+      var_map = model.restore_map(from_detection_checkpoint=False)
+      self.assertIsInstance(var_map, dict)
+      saver = tf.train.Saver(var_map)
       with self.test_session() as sess:
-        restore_fn(sess)
+        saver.restore(sess, saved_model_path)
+        for var in sess.run(tf.report_uninitialized_variables()):
+          self.assertNotIn(model.first_stage_feature_extractor_scope, var.name)
+          self.assertNotIn(model.second_stage_feature_extractor_scope,
+                           var.name)
 
-  def test_restore_fn_detection(self):
+  def test_restore_map_for_detection_ckpt(self):
     # Define first detection graph and save variables.
     test_graph_detection1 = tf.Graph()
     with test_graph_detection1.as_default():
@@ -1022,10 +1027,11 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
      preprocessed_inputs2 = model2.preprocess(inputs2)
      prediction_dict2 = model2.predict(preprocessed_inputs2)
      model2.postprocess(prediction_dict2)
-      restore_fn = model2.restore_fn(saved_model_path,
-                                     from_detection_checkpoint=True)
+      var_map = model2.restore_map(from_detection_checkpoint=True)
+      self.assertIsInstance(var_map, dict)
+      saver = tf.train.Saver(var_map)
       with self.test_session() as sess:
-        restore_fn(sess)
+        saver.restore(sess, saved_model_path)
         for var in sess.run(tf.report_uninitialized_variables()):
           self.assertNotIn(model2.first_stage_feature_extractor_scope, var.name)
           self.assertNotIn(model2.second_stage_feature_extractor_scope,
......
@@ -29,7 +29,6 @@ from object_detection.core import box_predictor as bpredictor
 from object_detection.core import model
 from object_detection.core import standard_fields as fields
 from object_detection.core import target_assigner
-from object_detection.utils import variables_helper
 
 slim = tf.contrib.slim
@@ -562,33 +561,26 @@ class SSDMetaArch(model.DetectionModel):
         decoded_boxlist_list=decoded_boxlist_list,
         match_list=match_list)
 
-  def restore_fn(self, checkpoint_path, from_detection_checkpoint=True):
-    """Return callable for loading a checkpoint into the tensorflow graph.
+  def restore_map(self, from_detection_checkpoint=True):
+    """Returns a map of variables to load from a foreign checkpoint.
+
+    See parent class for details.
 
     Args:
-      checkpoint_path: path to checkpoint to restore.
       from_detection_checkpoint: whether to restore from a full detection
         checkpoint (with compatible variable names) or to restore from a
         classification checkpoint for initialization prior to training.
 
     Returns:
-      a callable which takes a tf.Session as input and loads a checkpoint when
-      run.
+      A dict mapping variable names (to load from a checkpoint) to variables in
+      the model graph.
     """
     variables_to_restore = {}
     for variable in tf.all_variables():
       if variable.op.name.startswith(self._extract_features_scope):
         var_name = variable.op.name
         if not from_detection_checkpoint:
-          var_name = (
-              re.split('^' + self._extract_features_scope + '/', var_name)[-1])
+          var_name = (re.split('^' + self._extract_features_scope + '/',
+                               var_name)[-1])
         variables_to_restore[var_name] = variable
-    # TODO: Load variables selectively using scopes.
-    variables_to_restore = (
-        variables_helper.get_variables_available_in_checkpoint(
-            variables_to_restore, checkpoint_path))
-    saver = tf.train.Saver(variables_to_restore)
-    def restore(sess):
-      saver.restore(sess, checkpoint_path)
-    return restore
+    return variables_to_restore
@@ -207,20 +207,21 @@ class SsdMetaArchTest(tf.test.TestCase):
     self.assertAllClose(losses_out['classification_loss'],
                         expected_classification_loss)
 
-  def test_restore_fn_detection(self):
+  def test_restore_map_for_detection_ckpt(self):
     init_op = tf.global_variables_initializer()
     saver = tf_saver.Saver()
     save_path = self.get_temp_dir()
     with self.test_session() as sess:
       sess.run(init_op)
       saved_model_path = saver.save(sess, save_path)
-      restore_fn = self._model.restore_fn(saved_model_path,
-                                          from_detection_checkpoint=True)
-      restore_fn(sess)
+      var_map = self._model.restore_map(from_detection_checkpoint=True)
+      self.assertIsInstance(var_map, dict)
+      saver = tf.train.Saver(var_map)
+      saver.restore(sess, saved_model_path)
       for var in sess.run(tf.report_uninitialized_variables()):
         self.assertNotIn('FeatureExtractor', var.name)
 
-  def test_restore_fn_classification(self):
+  def test_restore_map_for_classification_ckpt(self):
     # Define mock tensorflow classification graph and save variables.
     test_graph_classification = tf.Graph()
     with test_graph_classification.as_default():
@@ -246,10 +247,11 @@ class SsdMetaArchTest(tf.test.TestCase):
       preprocessed_inputs = self._model.preprocess(inputs)
       prediction_dict = self._model.predict(preprocessed_inputs)
      self._model.postprocess(prediction_dict)
-      restore_fn = self._model.restore_fn(saved_model_path,
-                                          from_detection_checkpoint=False)
+      var_map = self._model.restore_map(from_detection_checkpoint=False)
+      self.assertIsInstance(var_map, dict)
+      saver = tf.train.Saver(var_map)
       with self.test_session() as sess:
-        restore_fn(sess)
+        saver.restore(sess, saved_model_path)
         for var in sess.run(tf.report_uninitialized_variables()):
           self.assertNotIn('FeatureExtractor', var.name)
......
@@ -94,7 +94,6 @@ py_library(
     deps = [
         "//tensorflow",
         "//tensorflow_models/object_detection/meta_architectures:faster_rcnn_meta_arch",
-        "//tensorflow_models/object_detection/utils:variables_helper",
        "//tensorflow_models/slim:inception_resnet_v2",
     ],
 )
......
@@ -25,7 +25,6 @@ Huang et al. (https://arxiv.org/abs/1611.10012)
 import tensorflow as tf
 
 from object_detection.meta_architectures import faster_rcnn_meta_arch
-from object_detection.utils import variables_helper
 from nets import inception_resnet_v2
 
 slim = tf.contrib.slim
@@ -168,30 +167,30 @@ class FasterRCNNInceptionResnetV2FeatureExtractor(
   def restore_from_classification_checkpoint_fn(
       self,
-      checkpoint_path,
       first_stage_feature_extractor_scope,
       second_stage_feature_extractor_scope):
-    """Returns callable for loading a checkpoint into the tensorflow graph.
+    """Returns a map of variables to load from a foreign checkpoint.
 
     Note that this overrides the default implementation in
     faster_rcnn_meta_arch.FasterRCNNFeatureExtractor which does not work for
     InceptionResnetV2 checkpoints.
 
-    TODO: revisit whether it's possible to force the `Repeat` namescope as
-    created in `_extract_box_classifier_features` to start counting at 2 (e.g.
-    `Repeat_2`) so that the default restore_fn can be used.
+    TODO(jonathanhuang,rathodv): revisit whether it's possible to force the
+    `Repeat` namescope as created in `_extract_box_classifier_features` to
+    start counting at 2 (e.g. `Repeat_2`) so that the default restore_fn can
+    be used.
 
     Args:
-      checkpoint_path: Path to checkpoint to restore.
       first_stage_feature_extractor_scope: A scope name for the first stage
         feature extractor.
       second_stage_feature_extractor_scope: A scope name for the second stage
         feature extractor.
 
     Returns:
-      a callable which takes a tf.Session as input and loads a checkpoint when
-      run.
+      A dict mapping variable names (to load from a checkpoint) to variables in
+      the model graph.
     """
     variables_to_restore = {}
     for variable in tf.global_variables():
       if variable.op.name.startswith(
@@ -207,10 +206,4 @@ class FasterRCNNInceptionResnetV2FeatureExtractor(
         var_name = var_name.replace(
             second_stage_feature_extractor_scope + '/', '')
         variables_to_restore[var_name] = variable
-    variables_to_restore = (
-        variables_helper.get_variables_available_in_checkpoint(
-            variables_to_restore, checkpoint_path))
-    saver = tf.train.Saver(variables_to_restore)
-    def restore(sess):
-      saver.restore(sess, checkpoint_path)
-    return restore
+    return variables_to_restore
@@ -211,9 +211,14 @@ def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
     # Create ops required to initialize the model from a given checkpoint.
     init_fn = None
     if train_config.fine_tune_checkpoint:
-      init_fn = detection_model.restore_fn(
-          train_config.fine_tune_checkpoint,
+      var_map = detection_model.restore_map(
           from_detection_checkpoint=train_config.from_detection_checkpoint)
+      var_map = variables_helper.get_variables_available_in_checkpoint(
+          var_map, train_config.fine_tune_checkpoint)
+      saver = tf.train.Saver(var_map)
+      def initializer_fn(sess):
+        saver.restore(sess, train_config.fine_tune_checkpoint)
+      init_fn = initializer_fn
 
     with tf.device(deploy_config.optimizer_device()):
       total_loss, grads_and_vars = model_deploy.optimize_clones(
......
@@ -139,21 +139,18 @@ class FakeDetectionModel(model.DetectionModel):
     }
     return loss_dict
 
-  def restore_fn(self, checkpoint_path, from_detection_checkpoint=True):
-    """Return callable for loading a checkpoint into the tensorflow graph.
+  def restore_map(self, from_detection_checkpoint=True):
+    """Returns a map of variables to load from a foreign checkpoint.
 
     Args:
-      checkpoint_path: path to checkpoint to restore.
       from_detection_checkpoint: whether to restore from a full detection
         checkpoint (with compatible variable names) or to restore from a
        classification checkpoint for initialization prior to training.
 
     Returns:
-      a callable which takes a tf.Session and does nothing.
+      A dict mapping variable names to variables.
     """
-    def restore(unused_sess):
-      return
-    return restore
+    return {var.op.name: var for var in tf.global_variables()}
 
 
 class TrainerTest(tf.test.TestCase):
......