Commit d9d10fbb authored by Derek Chow

Add capability to export as SavedModel in exporter script.

parent 5196648b
@@ -16,7 +16,8 @@
 r"""Tool to export an object detection model for inference.
 
 Prepares an object detection tensorflow graph for inference using model
-configuration and an optional trained checkpoint.
+configuration and an optional trained checkpoint. Outputs either an inference
+graph or a SavedModel (https://tensorflow.github.io/serving/serving_basic.html).
 
 The inference graph contains one of three input nodes depending on the user
 specified option.
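
As an illustrative sketch (not part of this diff), the exporter API extended here can be driven from Python roughly as follows; the config and checkpoint paths are hypothetical:

# Illustrative only: the paths below are placeholders.
import tensorflow as tf
from google.protobuf import text_format
from object_detection import exporter
from object_detection.protos import pipeline_pb2

pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
with tf.gfile.GFile('/tmp/pipeline.config', 'r') as f:
  text_format.Merge(f.read(), pipeline_config)

# export_as_saved_model=False (the default) writes a frozen inference graph;
# True writes a SavedModel at inference_graph_path instead.
exporter.export_inference_graph('tf_example', pipeline_config,
                                '/tmp/model.ckpt', '/tmp/exported_model',
                                export_as_saved_model=True)
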
@@ -77,6 +78,8 @@ flags.DEFINE_string('checkpoint_path', '', 'Optional path to checkpoint file. '
                     'the graph.')
 flags.DEFINE_string('inference_graph_path', '', 'Path to write the output '
                     'inference graph.')
+flags.DEFINE_bool('export_as_saved_model', False, 'Whether the exported graph '
+                  'should be saved as a SavedModel')
 
 FLAGS = flags.FLAGS
@@ -90,7 +93,8 @@ def main(_):
     text_format.Merge(f.read(), pipeline_config)
   exporter.export_inference_graph(FLAGS.input_type, pipeline_config,
                                   FLAGS.checkpoint_path,
-                                  FLAGS.inference_graph_path)
+                                  FLAGS.inference_graph_path,
+                                  FLAGS.export_as_saved_model)
 
 
 if __name__ == '__main__':
@@ -22,6 +22,7 @@ from tensorflow.python.client import session
 from tensorflow.python.framework import graph_util
 from tensorflow.python.framework import importer
 from tensorflow.python.platform import gfile
+from tensorflow.python.saved_model import signature_constants
 from tensorflow.python.training import saver as saver_lib
 from object_detection.builders import model_builder
 from object_detection.core import standard_fields as fields
@@ -39,7 +40,6 @@ def freeze_graph_with_def_protos(
     output_node_names,
     restore_op_name,
     filename_tensor_name,
-    output_graph,
     clear_devices,
     initializer_nodes,
     variable_names_blacklist=''):
@@ -92,9 +92,30 @@ def freeze_graph_with_def_protos(
       output_node_names.split(','),
       variable_names_blacklist=variable_names_blacklist)
-  with gfile.GFile(output_graph, 'wb') as f:
-    f.write(output_graph_def.SerializeToString())
-  logging.info('%d ops in the final graph.', len(output_graph_def.node))
+  return output_graph_def
+
+
+def get_frozen_graph_def(inference_graph_def, use_moving_averages,
+                         input_checkpoint, output_node_names):
+  """Freezes all variables in a graph definition."""
+  saver = None
+  if use_moving_averages:
+    variable_averages = tf.train.ExponentialMovingAverage(0.0)
+    variables_to_restore = variable_averages.variables_to_restore()
+    saver = tf.train.Saver(variables_to_restore)
+  else:
+    saver = tf.train.Saver()
+  frozen_graph_def = freeze_graph_with_def_protos(
+      input_graph_def=inference_graph_def,
+      input_saver_def=saver.as_saver_def(),
+      input_checkpoint=input_checkpoint,
+      output_node_names=output_node_names,
+      restore_op_name='save/restore_all',
+      filename_tensor_name='save/Const:0',
+      clear_devices=True,
+      initializer_nodes='')
+  return frozen_graph_def
 
 
 # TODO: Support batch tf example inputs.
@@ -153,6 +174,9 @@ def _add_output_tensor_nodes(postprocessed_tensors):
       'detection_masks': [batch, max_detections, mask_height, mask_width]
         (optional).
       'num_detections': [batch]
+
+  Returns:
+    A tensor dict containing the added output tensor nodes.
   """
   label_id_offset = 1
   boxes = postprocessed_tensors.get('detection_boxes')
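
A self-contained sketch (illustrative only, not part of this diff) of the tf.identity naming pattern this function relies on: each output is re-published under a stable node name, and the dict keys double as the node names later joined into output_node_names for freezing or looked up by name after import:

# Minimal demonstration of naming outputs with tf.identity.
import numpy as np
import tensorflow as tf

with tf.Graph().as_default() as g:
  raw_boxes = tf.constant(np.zeros((1, 2, 4), dtype=np.float32))
  outputs = {'detection_boxes': tf.identity(raw_boxes, name='detection_boxes')}
  # Keys double as graph node names, so they can be joined for freezing
  # or resolved from an imported graph.
  print(','.join(outputs.keys()))                   # detection_boxes
  print(g.get_tensor_by_name('detection_boxes:0'))  # the same tensor
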
@@ -160,12 +184,14 @@ def _add_output_tensor_nodes(postprocessed_tensors):
   classes = postprocessed_tensors.get('detection_classes') + label_id_offset
   masks = postprocessed_tensors.get('detection_masks')
   num_detections = postprocessed_tensors.get('num_detections')
-  tf.identity(boxes, name='detection_boxes')
-  tf.identity(scores, name='detection_scores')
-  tf.identity(classes, name='detection_classes')
-  tf.identity(num_detections, name='num_detections')
+  outputs = {}
+  outputs['detection_boxes'] = tf.identity(boxes, name='detection_boxes')
+  outputs['detection_scores'] = tf.identity(scores, name='detection_scores')
+  outputs['detection_classes'] = tf.identity(classes, name='detection_classes')
+  outputs['num_detections'] = tf.identity(num_detections, name='num_detections')
   if masks is not None:
-    tf.identity(masks, name='detection_masks')
+    outputs['detection_masks'] = tf.identity(masks, name='detection_masks')
+  return outputs
 
 
 def _write_inference_graph(inference_graph_path,
@@ -192,23 +218,17 @@ def _write_inference_graph(inference_graph_path,
   """
   inference_graph_def = tf.get_default_graph().as_graph_def()
   if checkpoint_path:
-    saver = None
-    if use_moving_averages:
-      variable_averages = tf.train.ExponentialMovingAverage(0.0)
-      variables_to_restore = variable_averages.variables_to_restore()
-      saver = tf.train.Saver(variables_to_restore)
-    else:
-      saver = tf.train.Saver()
-    freeze_graph_with_def_protos(
-        input_graph_def=inference_graph_def,
-        input_saver_def=saver.as_saver_def(),
+    output_graph_def = get_frozen_graph_def(
+        inference_graph_def=inference_graph_def,
+        use_moving_averages=use_moving_averages,
         input_checkpoint=checkpoint_path,
         output_node_names=output_node_names,
-        restore_op_name='save/restore_all',
-        filename_tensor_name='save/Const:0',
-        output_graph=inference_graph_path,
-        clear_devices=True,
-        initializer_nodes='')
+    )
+
+    with gfile.GFile(inference_graph_path, 'wb') as f:
+      f.write(output_graph_def.SerializeToString())
+    logging.info('%d ops in the final graph.', len(output_graph_def.node))
     return
   tf.train.write_graph(inference_graph_def,
                        os.path.dirname(inference_graph_path),
@@ -216,11 +236,70 @@
                        as_text=False)
 
 
+def _write_saved_model(inference_graph_path, inputs, outputs,
+                       checkpoint_path=None, use_moving_averages=False):
+  """Writes SavedModel to disk.
+
+  If checkpoint_path is not None bakes the weights into the graph thereby
+  eliminating the need of checkpoint files during inference. If the model
+  was trained with moving averages, setting use_moving_averages to true
+  restores the moving averages, otherwise the original set of variables
+  is restored.
+
+  Args:
+    inference_graph_path: Path to write inference graph.
+    inputs: The input image tensor to use for detection.
+    outputs: A tensor dictionary containing the outputs of a DetectionModel.
+    checkpoint_path: Optional path to the checkpoint file.
+    use_moving_averages: Whether to export the original or the moving averages
+      of the trainable variables from the checkpoint.
+  """
+  inference_graph_def = tf.get_default_graph().as_graph_def()
+  checkpoint_graph_def = None
+  if checkpoint_path:
+    output_node_names = ','.join(outputs.keys())
+    checkpoint_graph_def = get_frozen_graph_def(
+        inference_graph_def=inference_graph_def,
+        use_moving_averages=use_moving_averages,
+        input_checkpoint=checkpoint_path,
+        output_node_names=output_node_names
+    )
+
+  with tf.Graph().as_default():
+    with session.Session() as sess:
+      tf.import_graph_def(checkpoint_graph_def)
+
+      builder = tf.saved_model.builder.SavedModelBuilder(inference_graph_path)
+
+      tensor_info_inputs = {
+          'inputs': tf.saved_model.utils.build_tensor_info(inputs)}
+      tensor_info_outputs = {}
+      for k, v in outputs.items():
+        tensor_info_outputs[k] = tf.saved_model.utils.build_tensor_info(v)
+
+      detection_signature = (
+          tf.saved_model.signature_def_utils.build_signature_def(
+              inputs=tensor_info_inputs,
+              outputs=tensor_info_outputs,
+              method_name=signature_constants.PREDICT_METHOD_NAME))
+
+      builder.add_meta_graph_and_variables(
+          sess, [tf.saved_model.tag_constants.SERVING],
+          signature_def_map={
+              'signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY':
+                  detection_signature,
+          },
+      )
+      builder.save()
+
+
 def _export_inference_graph(input_type,
                             detection_model,
                             use_moving_averages,
                             checkpoint_path,
-                            inference_graph_path):
+                            inference_graph_path,
+                            export_as_saved_model=False):
   """Export helper."""
   if input_type not in input_placeholder_fn_map:
     raise ValueError('Unknown input type: {}'.format(input_type))
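
For reference, a minimal TF 1.x sketch (illustrative only, not part of this diff) of the build_tensor_info and build_signature_def helpers that _write_saved_model uses above; the placeholder and output tensor here are stand-ins for the real detection inputs and outputs:

# Stand-alone demonstration of TensorInfo/SignatureDef construction.
import tensorflow as tf
from tensorflow.python.saved_model import signature_constants

image = tf.placeholder(tf.uint8, shape=[None, None, None, 3], name='image')
boxes = tf.identity(tf.zeros([1, 2, 4]), name='detection_boxes')

tensor_info_inputs = {'inputs': tf.saved_model.utils.build_tensor_info(image)}
tensor_info_outputs = {
    'detection_boxes': tf.saved_model.utils.build_tensor_info(boxes)}
signature = tf.saved_model.signature_def_utils.build_signature_def(
    inputs=tensor_info_inputs,
    outputs=tensor_info_outputs,
    method_name=signature_constants.PREDICT_METHOD_NAME)
# Each TensorInfo records the tensor name, dtype and shape, e.g.:
print(signature.inputs['inputs'].name)  # 'image:0'
print(signature.method_name)            # 'tensorflow/serving/predict'
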
@@ -228,18 +307,19 @@ def _export_inference_graph(input_type,
   preprocessed_inputs = detection_model.preprocess(inputs)
   output_tensors = detection_model.predict(preprocessed_inputs)
   postprocessed_tensors = detection_model.postprocess(output_tensors)
-  _add_output_tensor_nodes(postprocessed_tensors)
-  out_node_names = ['num_detections', 'detection_scores,'
-                    'detection_boxes', 'detection_classes']
-  if 'detection_masks' in postprocessed_tensors:
-    out_node_names.append('detection_masks')
-  _write_inference_graph(inference_graph_path, checkpoint_path,
-                         use_moving_averages,
-                         output_node_names=','.join(out_node_names))
+  outputs = _add_output_tensor_nodes(postprocessed_tensors)
+  out_node_names = list(outputs.keys())
+  if export_as_saved_model:
+    _write_saved_model(inference_graph_path, inputs, outputs, checkpoint_path,
+                       use_moving_averages)
+  else:
+    _write_inference_graph(inference_graph_path, checkpoint_path,
+                           use_moving_averages,
+                           output_node_names=','.join(out_node_names))
 
 
 def export_inference_graph(input_type, pipeline_config, checkpoint_path,
-                           inference_graph_path):
+                           inference_graph_path, export_as_saved_model=False):
   """Exports inference graph for the model specified in the pipeline config.
 
   Args:
@@ -248,9 +328,12 @@ def export_inference_graph(input_type, pipeline_config, checkpoint_path,
     pipeline_config: pipeline_pb2.TrainAndEvalPipelineConfig proto.
     checkpoint_path: Path to the checkpoint file to freeze.
     inference_graph_path: Path to write inference graph to.
+    export_as_saved_model: If the model should be exported as a SavedModel. If
+      false, it is saved as an inference graph.
   """
   detection_model = model_builder.build(pipeline_config.model,
                                         is_training=False)
   _export_inference_graph(input_type, detection_model,
                           pipeline_config.eval_config.use_moving_averages,
-                          checkpoint_path, inference_graph_path)
+                          checkpoint_path, inference_graph_path,
+                          export_as_saved_model)
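
When export_as_saved_model is left False, the exporter still writes a frozen GraphDef as before. A hedged sketch of consuming such a file at inference time follows; the path and the 'image_tensor' input name are assumptions not confirmed by this diff (the test below uses the 'tf_example' input instead):

# Illustrative only: path and input placeholder name are assumed.
import numpy as np
import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.GFile('/tmp/inference_graph.pb', 'rb') as f:  # hypothetical path
  graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
  tf.import_graph_def(graph_def, name='')
  image = graph.get_tensor_by_name('image_tensor:0')  # depends on input_type
  boxes = graph.get_tensor_by_name('detection_boxes:0')
  scores = graph.get_tensor_by_name('detection_scores:0')
  with tf.Session(graph=graph) as sess:
    out = sess.run([boxes, scores],
                   feed_dict={image: np.zeros((1, 4, 4, 3), dtype=np.uint8)})
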
@@ -15,14 +15,19 @@
 """Tests for object_detection.export_inference_graph."""
 import os
-import mock
 import numpy as np
+import six
 import tensorflow as tf
 
 from object_detection import exporter
 from object_detection.builders import model_builder
 from object_detection.core import model
 from object_detection.protos import pipeline_pb2
 
+if six.PY2:
+  import mock  # pylint: disable=g-import-not-at-top
+else:
+  from unittest import mock  # pylint: disable=g-import-not-at-top
+
 
 class FakeModel(model.DetectionModel):
@@ -348,6 +353,45 @@ class ExportInferenceGraphTest(tf.test.TestCase):
         self.assertAllClose(masks, np.arange(32).reshape([2, 4, 4]))
         self.assertAllClose(num_detections, [2])
 
+  def test_export_saved_model_and_run_inference(self):
+    checkpoint_path = os.path.join(self.get_temp_dir(), 'model-ckpt')
+    self._save_checkpoint_from_mock_model(checkpoint_path,
+                                          use_moving_averages=False)
+    inference_graph_path = os.path.join(self.get_temp_dir(),
+                                        'saved_model')
+
+    with mock.patch.object(
+        model_builder, 'build', autospec=True) as mock_builder:
+      mock_builder.return_value = FakeModel(add_detection_masks=True)
+      pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
+      pipeline_config.eval_config.use_moving_averages = False
+      exporter.export_inference_graph(
+          input_type='tf_example',
+          pipeline_config=pipeline_config,
+          checkpoint_path=checkpoint_path,
+          inference_graph_path=inference_graph_path,
+          export_as_saved_model=True)
+
+    with tf.Graph().as_default() as od_graph:
+      with self.test_session(graph=od_graph) as sess:
+        tf.saved_model.loader.load(
+            sess, [tf.saved_model.tag_constants.SERVING], inference_graph_path)
+        tf_example = od_graph.get_tensor_by_name('import/tf_example:0')
+        boxes = od_graph.get_tensor_by_name('import/detection_boxes:0')
+        scores = od_graph.get_tensor_by_name('import/detection_scores:0')
+        classes = od_graph.get_tensor_by_name('import/detection_classes:0')
+        masks = od_graph.get_tensor_by_name('import/detection_masks:0')
+        num_detections = od_graph.get_tensor_by_name('import/num_detections:0')
+        (boxes, scores, classes, masks, num_detections) = sess.run(
+            [boxes, scores, classes, masks, num_detections],
+            feed_dict={tf_example: self._create_tf_example(
+                np.ones((4, 4, 3)).astype(np.uint8))})
+        self.assertAllClose(boxes, [[0.0, 0.0, 0.5, 0.5],
+                                    [0.5, 0.5, 0.8, 0.8]])
+        self.assertAllClose(scores, [[0.7, 0.6]])
+        self.assertAllClose(classes, [[1, 2]])
+        self.assertAllClose(masks, np.arange(32).reshape([2, 4, 4]))
+        self.assertAllClose(num_detections, [2])
+
 
 if __name__ == '__main__':
   tf.test.main()