Unverified Commit c46caa56 authored by Jonathan Huang's avatar Jonathan Huang Committed by GitHub
Browse files

Merge pull request #2624 from tombstone/exporter_update

Updates to exporter modules.
parents 9a811d95 78cf0ae0
...@@ -77,6 +77,13 @@ flags = tf.app.flags ...@@ -77,6 +77,13 @@ flags = tf.app.flags
flags.DEFINE_string('input_type', 'image_tensor', 'Type of input node. Can be ' flags.DEFINE_string('input_type', 'image_tensor', 'Type of input node. Can be '
'one of [`image_tensor`, `encoded_image_string_tensor`, ' 'one of [`image_tensor`, `encoded_image_string_tensor`, '
'`tf_example`]') '`tf_example`]')
# Optional fixed shape for the `image_tensor` input, e.g. --input_shape=1,320,320,3.
# Fix: the help text had an unclosed backtick ("an `image_tensor,").
flags.DEFINE_list('input_shape', None,
                  'If input_type is `image_tensor`, this can explicitly set '
                  'the shape of this input tensor to a fixed size. The '
                  'dimensions are to be provided as a comma-separated list of '
                  'integers. A value of -1 can be used for unknown dimensions. '
                  'If not specified, for an `image_tensor`, the default shape '
                  'will be partially specified as `[None, None, None, 3]`.')
flags.DEFINE_string('pipeline_config_path', None, flags.DEFINE_string('pipeline_config_path', None,
'Path to a pipeline_pb2.TrainEvalPipelineConfig config ' 'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
'file.') 'file.')
...@@ -85,21 +92,25 @@ flags.DEFINE_string('trained_checkpoint_prefix', None, ...@@ -85,21 +92,25 @@ flags.DEFINE_string('trained_checkpoint_prefix', None,
'path/to/model.ckpt') 'path/to/model.ckpt')
flags.DEFINE_string('output_directory', None, 'Path to write outputs.') flags.DEFINE_string('output_directory', None, 'Path to write outputs.')
# These flags have no usable defaults; fail fast at startup if any is missing.
for _required_flag in ('pipeline_config_path', 'trained_checkpoint_prefix',
                       'output_directory'):
  tf.app.flags.mark_flag_as_required(_required_flag)
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
def main(_):
  """Exports an inference graph for the flag-specified pipeline config.

  Reads a pipeline_pb2.TrainEvalPipelineConfig text proto from
  FLAGS.pipeline_config_path, optionally converts FLAGS.input_shape (a
  comma-separated list where '-1' marks an unknown dimension) into a shape
  list, and delegates to exporter.export_inference_graph.

  NOTE(review): this span was reconstructed from a two-column diff rendering;
  the original lines contained old and new text fused together.
  """
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
    text_format.Merge(f.read(), pipeline_config)
  if FLAGS.input_shape:
    # '-1' entries become None, i.e. an unknown dimension in the tensor shape.
    input_shape = [
        int(dim) if dim != '-1' else None for dim in FLAGS.input_shape
    ]
  else:
    input_shape = None
  exporter.export_inference_graph(FLAGS.input_type, pipeline_config,
                                  FLAGS.trained_checkpoint_prefix,
                                  FLAGS.output_directory, input_shape)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -16,12 +16,12 @@ ...@@ -16,12 +16,12 @@
"""Functions to export object detection inference graph.""" """Functions to export object detection inference graph."""
import logging import logging
import os import os
import tempfile
import tensorflow as tf import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.framework import graph_util from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import saver as saver_lib
...@@ -43,7 +43,7 @@ def freeze_graph_with_def_protos( ...@@ -43,7 +43,7 @@ def freeze_graph_with_def_protos(
filename_tensor_name, filename_tensor_name,
clear_devices, clear_devices,
initializer_nodes, initializer_nodes,
optimize_graph=False, optimize_graph=True,
variable_names_blacklist=''): variable_names_blacklist=''):
"""Converts all variables in a graph and checkpoint into constants.""" """Converts all variables in a graph and checkpoint into constants."""
del restore_op_name, filename_tensor_name # Unused by updated loading code. del restore_op_name, filename_tensor_name # Unused by updated loading code.
...@@ -111,12 +111,37 @@ def freeze_graph_with_def_protos( ...@@ -111,12 +111,37 @@ def freeze_graph_with_def_protos(
return output_graph_def return output_graph_def
def replace_variable_values_with_moving_averages(graph,
                                                 current_checkpoint_file,
                                                 new_checkpoint_file):
  """Replaces variable values in the checkpoint with their moving averages.

  If the current checkpoint has shadow variables maintaining moving averages of
  the variables defined in the graph, this function generates a new checkpoint
  where the variables contain the values of their moving averages.

  NOTE(review): reconstructed from a two-column diff rendering; the docstring
  lines were fused with old-column text in the original span.

  Args:
    graph: a tf.Graph object.
    current_checkpoint_file: a checkpoint containing both original variables and
      their moving averages.
    new_checkpoint_file: file path to write a new checkpoint.
  """
  with graph.as_default():
    variable_averages = tf.train.ExponentialMovingAverage(0.0)
    # variables_to_restore() maps each variable to its EMA shadow name, so
    # restoring loads the averaged values into the plain variables.
    ema_variables_to_restore = variable_averages.variables_to_restore()
    with tf.Session() as sess:
      read_saver = tf.train.Saver(ema_variables_to_restore)
      read_saver.restore(sess, current_checkpoint_file)
      # Save with a default Saver so the new checkpoint uses plain names.
      write_saver = tf.train.Saver()
      write_saver.save(sess, new_checkpoint_file)
def _image_tensor_input_placeholder(input_shape=None):
  """Returns input placeholder and a 4-D uint8 image tensor.

  Args:
    input_shape: optional fixed shape for the placeholder; defaults to
      [None, None, None, 3], i.e. a batch of images of unknown size.

  Returns:
    A (placeholder, tensor) tuple; for this input type the placeholder itself
    is also the image tensor, so the same node is returned twice.
  """
  if input_shape is None:
    input_shape = (None, None, None, 3)
  input_tensor = tf.placeholder(
      dtype=tf.uint8, shape=input_shape, name='image_tensor')
  return input_tensor, input_tensor
...@@ -124,7 +149,7 @@ def _tf_example_input_placeholder(): ...@@ -124,7 +149,7 @@ def _tf_example_input_placeholder():
"""Returns input that accepts a batch of strings with tf examples. """Returns input that accepts a batch of strings with tf examples.
Returns: Returns:
a tuple of placeholder and input nodes that output decoded images. a tuple of input placeholder and the output decoded images.
""" """
batch_tf_example_placeholder = tf.placeholder( batch_tf_example_placeholder = tf.placeholder(
tf.string, shape=[None], name='tf_example') tf.string, shape=[None], name='tf_example')
...@@ -145,7 +170,7 @@ def _encoded_image_string_tensor_input_placeholder(): ...@@ -145,7 +170,7 @@ def _encoded_image_string_tensor_input_placeholder():
"""Returns input that accepts a batch of PNG or JPEG strings. """Returns input that accepts a batch of PNG or JPEG strings.
Returns: Returns:
a tuple of placeholder and input nodes that output decoded images. a tuple of input placeholder and the output decoded images.
""" """
batch_image_str_placeholder = tf.placeholder( batch_image_str_placeholder = tf.placeholder(
dtype=tf.string, dtype=tf.string,
...@@ -301,7 +326,9 @@ def _export_inference_graph(input_type, ...@@ -301,7 +326,9 @@ def _export_inference_graph(input_type,
use_moving_averages, use_moving_averages,
trained_checkpoint_prefix, trained_checkpoint_prefix,
output_directory, output_directory,
optimize_graph=False, additional_output_tensor_names=None,
input_shape=None,
optimize_graph=True,
output_collection_name='inference_op'): output_collection_name='inference_op'):
"""Export helper.""" """Export helper."""
tf.gfile.MakeDirs(output_directory) tf.gfile.MakeDirs(output_directory)
...@@ -312,50 +339,69 @@ def _export_inference_graph(input_type, ...@@ -312,50 +339,69 @@ def _export_inference_graph(input_type,
if input_type not in input_placeholder_fn_map: if input_type not in input_placeholder_fn_map:
raise ValueError('Unknown input type: {}'.format(input_type)) raise ValueError('Unknown input type: {}'.format(input_type))
placeholder_tensor, input_tensors = input_placeholder_fn_map[input_type]() placeholder_args = {}
if input_shape is not None:
if input_type != 'image_tensor':
raise ValueError('Can only specify input shape for `image_tensor` '
'inputs.')
placeholder_args['input_shape'] = input_shape
placeholder_tensor, input_tensors = input_placeholder_fn_map[input_type](
**placeholder_args)
inputs = tf.to_float(input_tensors) inputs = tf.to_float(input_tensors)
preprocessed_inputs = detection_model.preprocess(inputs) preprocessed_inputs = detection_model.preprocess(inputs)
output_tensors = detection_model.predict(preprocessed_inputs) output_tensors = detection_model.predict(preprocessed_inputs)
postprocessed_tensors = detection_model.postprocess(output_tensors) postprocessed_tensors = detection_model.postprocess(output_tensors)
outputs = _add_output_tensor_nodes(postprocessed_tensors, outputs = _add_output_tensor_nodes(postprocessed_tensors,
output_collection_name) output_collection_name)
# Add global step to the graph.
slim.get_or_create_global_step()
saver = None
if use_moving_averages: if use_moving_averages:
variable_averages = tf.train.ExponentialMovingAverage(0.0) temp_checkpoint_file = tempfile.NamedTemporaryFile()
variables_to_restore = variable_averages.variables_to_restore() replace_variable_values_with_moving_averages(
saver = tf.train.Saver(variables_to_restore) tf.get_default_graph(), trained_checkpoint_prefix,
temp_checkpoint_file.name)
checkpoint_to_use = temp_checkpoint_file.name
else: else:
saver = tf.train.Saver() checkpoint_to_use = trained_checkpoint_prefix
saver = tf.train.Saver()
input_saver_def = saver.as_saver_def() input_saver_def = saver.as_saver_def()
_write_graph_and_checkpoint( _write_graph_and_checkpoint(
inference_graph_def=tf.get_default_graph().as_graph_def(), inference_graph_def=tf.get_default_graph().as_graph_def(),
model_path=model_path, model_path=model_path,
input_saver_def=input_saver_def, input_saver_def=input_saver_def,
trained_checkpoint_prefix=trained_checkpoint_prefix) trained_checkpoint_prefix=checkpoint_to_use)
if additional_output_tensor_names is not None:
output_node_names = ','.join(outputs.keys()+additional_output_tensor_names)
else:
output_node_names = ','.join(outputs.keys())
frozen_graph_def = freeze_graph_with_def_protos( frozen_graph_def = freeze_graph_with_def_protos(
input_graph_def=tf.get_default_graph().as_graph_def(), input_graph_def=tf.get_default_graph().as_graph_def(),
input_saver_def=input_saver_def, input_saver_def=input_saver_def,
input_checkpoint=trained_checkpoint_prefix, input_checkpoint=checkpoint_to_use,
output_node_names=','.join(outputs.keys()), output_node_names=output_node_names,
restore_op_name='save/restore_all', restore_op_name='save/restore_all',
filename_tensor_name='save/Const:0', filename_tensor_name='save/Const:0',
clear_devices=True, clear_devices=True,
optimize_graph=optimize_graph, optimize_graph=optimize_graph,
initializer_nodes='') initializer_nodes='')
_write_frozen_graph(frozen_graph_path, frozen_graph_def) _write_frozen_graph(frozen_graph_path, frozen_graph_def)
_write_saved_model(saved_model_path, frozen_graph_def, placeholder_tensor, _write_saved_model(saved_model_path, frozen_graph_def,
outputs) placeholder_tensor, outputs)
def export_inference_graph(input_type, def export_inference_graph(input_type,
pipeline_config, pipeline_config,
trained_checkpoint_prefix, trained_checkpoint_prefix,
output_directory, output_directory,
optimize_graph=False, input_shape=None,
output_collection_name='inference_op'): optimize_graph=True,
output_collection_name='inference_op',
additional_output_tensor_names=None):
"""Exports inference graph for the model specified in the pipeline config. """Exports inference graph for the model specified in the pipeline config.
Args: Args:
...@@ -364,13 +410,18 @@ def export_inference_graph(input_type, ...@@ -364,13 +410,18 @@ def export_inference_graph(input_type,
pipeline_config: pipeline_pb2.TrainAndEvalPipelineConfig proto. pipeline_config: pipeline_pb2.TrainAndEvalPipelineConfig proto.
trained_checkpoint_prefix: Path to the trained checkpoint file. trained_checkpoint_prefix: Path to the trained checkpoint file.
output_directory: Path to write outputs. output_directory: Path to write outputs.
input_shape: Sets a fixed shape for an `image_tensor` input. If not
specified, will default to [None, None, None, 3].
optimize_graph: Whether to optimize graph using Grappler. optimize_graph: Whether to optimize graph using Grappler.
output_collection_name: Name of collection to add output tensors to. output_collection_name: Name of collection to add output tensors to.
If None, does not add output tensors to a collection. If None, does not add output tensors to a collection.
additional_output_tensor_names: list of additional output
tensors to include in the frozen graph.
""" """
detection_model = model_builder.build(pipeline_config.model, detection_model = model_builder.build(pipeline_config.model,
is_training=False) is_training=False)
_export_inference_graph(input_type, detection_model, _export_inference_graph(input_type, detection_model,
pipeline_config.eval_config.use_moving_averages, pipeline_config.eval_config.use_moving_averages,
trained_checkpoint_prefix, output_directory, trained_checkpoint_prefix,
optimize_graph, output_collection_name) output_directory, additional_output_tensor_names,
input_shape, optimize_graph, output_collection_name)
...@@ -28,6 +28,8 @@ if six.PY2: ...@@ -28,6 +28,8 @@ if six.PY2:
else: else:
from unittest import mock # pylint: disable=g-import-not-at-top from unittest import mock # pylint: disable=g-import-not-at-top
slim = tf.contrib.slim
class FakeModel(model.DetectionModel): class FakeModel(model.DetectionModel):
...@@ -78,6 +80,7 @@ class ExportInferenceGraphTest(tf.test.TestCase): ...@@ -78,6 +80,7 @@ class ExportInferenceGraphTest(tf.test.TestCase):
mock_model.postprocess(predictions) mock_model.postprocess(predictions)
if use_moving_averages: if use_moving_averages:
tf.train.ExponentialMovingAverage(0.0).apply() tf.train.ExponentialMovingAverage(0.0).apply()
slim.get_or_create_global_step()
saver = tf.train.Saver() saver = tf.train.Saver()
init = tf.global_variables_initializer() init = tf.global_variables_initializer()
with self.test_session() as sess: with self.test_session() as sess:
...@@ -122,6 +125,41 @@ class ExportInferenceGraphTest(tf.test.TestCase): ...@@ -122,6 +125,41 @@ class ExportInferenceGraphTest(tf.test.TestCase):
pipeline_config=pipeline_config, pipeline_config=pipeline_config,
trained_checkpoint_prefix=trained_checkpoint_prefix, trained_checkpoint_prefix=trained_checkpoint_prefix,
output_directory=output_directory) output_directory=output_directory)
self.assertTrue(os.path.exists(os.path.join(
output_directory, 'saved_model', 'saved_model.pb')))
def test_export_graph_with_fixed_size_image_tensor_input(self):
  """Exports with a fixed input shape and verifies it on the SavedModel."""
  fixed_shape = [1, 320, 320, 3]
  tmp_dir = self.get_temp_dir()
  ckpt_prefix = os.path.join(tmp_dir, 'model.ckpt')
  self._save_checkpoint_from_mock_model(
      ckpt_prefix, use_moving_averages=False)
  out_dir = os.path.join(tmp_dir, 'output')
  with mock.patch.object(
      model_builder, 'build', autospec=True) as mock_builder:
    mock_builder.return_value = FakeModel()
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.eval_config.use_moving_averages = False
    exporter.export_inference_graph(
        input_type='image_tensor',
        pipeline_config=pipeline_config,
        trained_checkpoint_prefix=ckpt_prefix,
        output_directory=out_dir,
        input_shape=fixed_shape)
  saved_model_path = os.path.join(out_dir, 'saved_model')
  self.assertTrue(
      os.path.exists(os.path.join(saved_model_path, 'saved_model.pb')))
  # Reload the SavedModel and confirm the input tensor carries the fixed shape.
  with tf.Graph().as_default() as od_graph:
    with self.test_session(graph=od_graph) as sess:
      meta_graph = tf.saved_model.loader.load(
          sess, [tf.saved_model.tag_constants.SERVING], saved_model_path)
      signature = meta_graph.signature_def['serving_default']
      input_tensor_name = signature.inputs['inputs'].name
      image_tensor = od_graph.get_tensor_by_name(input_tensor_name)
      self.assertSequenceEqual(image_tensor.get_shape().as_list(),
                               fixed_shape)
def test_export_graph_with_tf_example_input(self): def test_export_graph_with_tf_example_input(self):
tmp_dir = self.get_temp_dir() tmp_dir = self.get_temp_dir()
...@@ -139,6 +177,8 @@ class ExportInferenceGraphTest(tf.test.TestCase): ...@@ -139,6 +177,8 @@ class ExportInferenceGraphTest(tf.test.TestCase):
pipeline_config=pipeline_config, pipeline_config=pipeline_config,
trained_checkpoint_prefix=trained_checkpoint_prefix, trained_checkpoint_prefix=trained_checkpoint_prefix,
output_directory=output_directory) output_directory=output_directory)
self.assertTrue(os.path.exists(os.path.join(
output_directory, 'saved_model', 'saved_model.pb')))
def test_export_graph_with_encoded_image_string_input(self): def test_export_graph_with_encoded_image_string_input(self):
tmp_dir = self.get_temp_dir() tmp_dir = self.get_temp_dir()
...@@ -156,6 +196,44 @@ class ExportInferenceGraphTest(tf.test.TestCase): ...@@ -156,6 +196,44 @@ class ExportInferenceGraphTest(tf.test.TestCase):
pipeline_config=pipeline_config, pipeline_config=pipeline_config,
trained_checkpoint_prefix=trained_checkpoint_prefix, trained_checkpoint_prefix=trained_checkpoint_prefix,
output_directory=output_directory) output_directory=output_directory)
self.assertTrue(os.path.exists(os.path.join(
output_directory, 'saved_model', 'saved_model.pb')))
def _get_variables_in_checkpoint(self, checkpoint_file):
  """Returns the set of variable names stored in `checkpoint_file`."""
  return {name for name, _ in tf.train.list_variables(checkpoint_file)}
def test_replace_variable_values_with_moving_averages(self):
  """Rewriting a checkpoint folds EMA shadows into the plain variables."""
  tmp_dir = self.get_temp_dir()
  trained_prefix = os.path.join(tmp_dir, 'model.ckpt')
  new_prefix = os.path.join(tmp_dir, 'new.ckpt')
  self._save_checkpoint_from_mock_model(trained_prefix,
                                        use_moving_averages=True)
  graph = tf.Graph()
  with graph.as_default():
    fake_model = FakeModel()
    preprocessed = fake_model.preprocess(
        tf.placeholder(dtype=tf.float32, shape=[None, None, None, 3]))
    fake_model.postprocess(fake_model.predict(preprocessed))
    exporter.replace_variable_values_with_moving_averages(
        graph, trained_prefix, new_prefix)
  old_vars = self._get_variables_in_checkpoint(trained_prefix)
  new_vars = self._get_variables_in_checkpoint(new_prefix)
  # Original checkpoint holds both raw variables and their EMA shadows.
  self.assertIn('conv2d/bias/ExponentialMovingAverage', old_vars)
  self.assertIn('conv2d/kernel/ExponentialMovingAverage', old_vars)
  # The rewritten checkpoint keeps plain names and drops the shadow copies.
  self.assertTrue(set(['conv2d/bias', 'conv2d/kernel']).issubset(new_vars))
  self.assertNotIn('conv2d/bias/ExponentialMovingAverage', new_vars)
  self.assertNotIn('conv2d/kernel/ExponentialMovingAverage', new_vars)
def test_export_graph_with_moving_averages(self): def test_export_graph_with_moving_averages(self):
tmp_dir = self.get_temp_dir() tmp_dir = self.get_temp_dir()
...@@ -173,6 +251,12 @@ class ExportInferenceGraphTest(tf.test.TestCase): ...@@ -173,6 +251,12 @@ class ExportInferenceGraphTest(tf.test.TestCase):
pipeline_config=pipeline_config, pipeline_config=pipeline_config,
trained_checkpoint_prefix=trained_checkpoint_prefix, trained_checkpoint_prefix=trained_checkpoint_prefix,
output_directory=output_directory) output_directory=output_directory)
self.assertTrue(os.path.exists(os.path.join(
output_directory, 'saved_model', 'saved_model.pb')))
expected_variables = set(['conv2d/bias', 'conv2d/kernel', 'global_step'])
actual_variables = set(
[var_name for var_name, _ in tf.train.list_variables(output_directory)])
self.assertTrue(expected_variables.issubset(actual_variables))
def test_export_model_with_all_output_nodes(self): def test_export_model_with_all_output_nodes(self):
tmp_dir = self.get_temp_dir() tmp_dir = self.get_temp_dir()
...@@ -434,14 +518,24 @@ class ExportInferenceGraphTest(tf.test.TestCase): ...@@ -434,14 +518,24 @@ class ExportInferenceGraphTest(tf.test.TestCase):
np.ones((4, 4, 3)).astype(np.uint8))] * 2) np.ones((4, 4, 3)).astype(np.uint8))] * 2)
with tf.Graph().as_default() as od_graph: with tf.Graph().as_default() as od_graph:
with self.test_session(graph=od_graph) as sess: with self.test_session(graph=od_graph) as sess:
tf.saved_model.loader.load( meta_graph = tf.saved_model.loader.load(
sess, [tf.saved_model.tag_constants.SERVING], saved_model_path) sess, [tf.saved_model.tag_constants.SERVING], saved_model_path)
tf_example = od_graph.get_tensor_by_name('tf_example:0')
boxes = od_graph.get_tensor_by_name('detection_boxes:0') signature = meta_graph.signature_def['serving_default']
scores = od_graph.get_tensor_by_name('detection_scores:0') input_tensor_name = signature.inputs['inputs'].name
classes = od_graph.get_tensor_by_name('detection_classes:0') tf_example = od_graph.get_tensor_by_name(input_tensor_name)
masks = od_graph.get_tensor_by_name('detection_masks:0')
num_detections = od_graph.get_tensor_by_name('num_detections:0') boxes = od_graph.get_tensor_by_name(
signature.outputs['detection_boxes'].name)
scores = od_graph.get_tensor_by_name(
signature.outputs['detection_scores'].name)
classes = od_graph.get_tensor_by_name(
signature.outputs['detection_classes'].name)
masks = od_graph.get_tensor_by_name(
signature.outputs['detection_masks'].name)
num_detections = od_graph.get_tensor_by_name(
signature.outputs['num_detections'].name)
(boxes_np, scores_np, classes_np, masks_np, (boxes_np, scores_np, classes_np, masks_np,
num_detections_np) = sess.run( num_detections_np) = sess.run(
[boxes, scores, classes, masks, num_detections], [boxes, scores, classes, masks, num_detections],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment