Commit f33ffcc2 authored by Vivek Rathod's avatar Vivek Rathod
Browse files

Add option to export graph with input node that accepts encoded jpeg or png string

parent 977160a9
...@@ -18,21 +18,27 @@ r"""Tool to export an object detection model for inference. ...@@ -18,21 +18,27 @@ r"""Tool to export an object detection model for inference.
Prepares an object detection tensorflow graph for inference using model Prepares an object detection tensorflow graph for inference using model
configuration and an optional trained checkpoint. configuration and an optional trained checkpoint.
The inference graph contains one of two input nodes depending on the user The inference graph contains one of three input nodes depending on the user
specified option. specified option.
* `image_tensor`: Accepts a uint8 4-D tensor of shape [1, None, None, 3] * `image_tensor`: Accepts a uint8 4-D tensor of shape [1, None, None, 3]
* `encoded_image_string_tensor`: Accepts a scalar string tensor of encoded PNG
or JPEG image.
* `tf_example`: Accepts a serialized TFExample proto. The batch size in this * `tf_example`: Accepts a serialized TFExample proto. The batch size in this
case is always 1. case is always 1.
and the following output nodes: and the following output nodes returned by the model.postprocess(..):
* `num_detections` : Outputs float32 tensors of the form [batch] * `num_detections`: Outputs float32 tensors of the form [batch]
that specifies the number of valid boxes per image in the batch. that specifies the number of valid boxes per image in the batch.
* `detection_boxes` : Outputs float32 tensors of the form * `detection_boxes`: Outputs float32 tensors of the form
[batch, num_boxes, 4] containing detected boxes. [batch, num_boxes, 4] containing detected boxes.
* `detection_scores` : Outputs float32 tensors of the form * `detection_scores`: Outputs float32 tensors of the form
[batch, num_boxes] containing class scores for the detections. [batch, num_boxes] containing class scores for the detections.
* `detection_classes`: Outputs float32 tensors of the form * `detection_classes`: Outputs float32 tensors of the form
[batch, num_boxes] containing classes for the detections. [batch, num_boxes] containing classes for the detections.
* `detection_masks`: Outputs float32 tensors of the form
[batch, num_boxes, mask_height, mask_width] containing predicted instance
masks for each box if its present in the dictionary of postprocessed
tensors returned by the model.
Note that currently `batch` is always 1, but we will support `batch` > 1 in Note that currently `batch` is always 1, but we will support `batch` > 1 in
the future. the future.
...@@ -61,7 +67,8 @@ slim = tf.contrib.slim ...@@ -61,7 +67,8 @@ slim = tf.contrib.slim
flags = tf.app.flags flags = tf.app.flags
flags.DEFINE_string('input_type', 'image_tensor', 'Type of input node. Can be ' flags.DEFINE_string('input_type', 'image_tensor', 'Type of input node. Can be '
'one of [`image_tensor` `tf_example_proto`]') 'one of [`image_tensor`, `encoded_image_string_tensor`, '
'`tf_example`]')
flags.DEFINE_string('pipeline_config_path', '', flags.DEFINE_string('pipeline_config_path', '',
'Path to a pipeline_pb2.TrainEvalPipelineConfig config ' 'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
'file.') 'file.')
......
...@@ -30,8 +30,8 @@ from object_detection.data_decoders import tf_example_decoder ...@@ -30,8 +30,8 @@ from object_detection.data_decoders import tf_example_decoder
slim = tf.contrib.slim slim = tf.contrib.slim
# TODO: Replace with freeze_graph.freeze_graph_with_def_protos when newer # TODO: Replace with freeze_graph.freeze_graph_with_def_protos when
# version of Tensorflow becomes more common. # newer version of Tensorflow becomes more common.
def freeze_graph_with_def_protos( def freeze_graph_with_def_protos(
input_graph_def, input_graph_def,
input_saver_def, input_saver_def,
...@@ -48,12 +48,12 @@ def freeze_graph_with_def_protos( ...@@ -48,12 +48,12 @@ def freeze_graph_with_def_protos(
# 'input_checkpoint' may be a prefix if we're using Saver V2 format # 'input_checkpoint' may be a prefix if we're using Saver V2 format
if not saver_lib.checkpoint_exists(input_checkpoint): if not saver_lib.checkpoint_exists(input_checkpoint):
logging.info('Input checkpoint "' + input_checkpoint + '" does not exist!') raise ValueError(
return -1 'Input checkpoint "' + input_checkpoint + '" does not exist!')
if not output_node_names: if not output_node_names:
logging.info('You must supply the name of a node to --output_node_names.') raise ValueError(
return -1 'You must supply the name of a node to --output_node_names.')
# Remove all the explicit device specifications for this node. This helps to # Remove all the explicit device specifications for this node. This helps to
# make the graph more portable. # make the graph more portable.
...@@ -101,7 +101,7 @@ def freeze_graph_with_def_protos( ...@@ -101,7 +101,7 @@ def freeze_graph_with_def_protos(
def _tf_example_input_placeholder(): def _tf_example_input_placeholder():
tf_example_placeholder = tf.placeholder( tf_example_placeholder = tf.placeholder(
tf.string, shape=[], name='tf_example') tf.string, shape=[], name='tf_example')
tensor_dict = tf_example_decoder.TfExampleDecoder().Decode( tensor_dict = tf_example_decoder.TfExampleDecoder().decode(
tf_example_placeholder) tf_example_placeholder)
image = tensor_dict[fields.InputDataFields.image] image = tensor_dict[fields.InputDataFields.image]
return tf.expand_dims(image, axis=0) return tf.expand_dims(image, axis=0)
...@@ -112,9 +112,21 @@ def _image_tensor_input_placeholder(): ...@@ -112,9 +112,21 @@ def _image_tensor_input_placeholder():
shape=(1, None, None, 3), shape=(1, None, None, 3),
name='image_tensor') name='image_tensor')
def _encoded_image_string_tensor_input_placeholder():
  """Returns an input placeholder for a single encoded image string.

  The placeholder (named 'encoded_image_string_tensor') accepts one
  PNG- or JPEG-encoded image as a scalar string; the image is decoded
  to 3 channels and a leading batch dimension of size 1 is added.
  """
  encoded_image_str = tf.placeholder(
      dtype=tf.string, shape=[], name='encoded_image_string_tensor')
  decoded = tf.image.decode_image(encoded_image_str, channels=3)
  # decode_image leaves the static shape unknown; pin it to a single
  # 3-channel image so downstream ops see a rank-3 tensor.
  decoded.set_shape((None, None, 3))
  return tf.expand_dims(decoded, axis=0)
input_placeholder_fn_map = { input_placeholder_fn_map = {
'image_tensor': _image_tensor_input_placeholder,
'encoded_image_string_tensor':
_encoded_image_string_tensor_input_placeholder,
'tf_example': _tf_example_input_placeholder, 'tf_example': _tf_example_input_placeholder,
'image_tensor': _image_tensor_input_placeholder
} }
...@@ -129,23 +141,31 @@ def _add_output_tensor_nodes(postprocessed_tensors): ...@@ -129,23 +141,31 @@ def _add_output_tensor_nodes(postprocessed_tensors):
containing scores for the detected boxes. containing scores for the detected boxes.
* detection_classes: float32 tensor of shape [batch_size, num_boxes] * detection_classes: float32 tensor of shape [batch_size, num_boxes]
containing class predictions for the detected boxes. containing class predictions for the detected boxes.
* detection_masks: (Optional) float32 tensor of shape
[batch_size, num_boxes, mask_height, mask_width] containing masks for each
detection box.
Args: Args:
postprocessed_tensors: a dictionary containing the following fields postprocessed_tensors: a dictionary containing the following fields
'detection_boxes': [batch, max_detections, 4] 'detection_boxes': [batch, max_detections, 4]
'detection_scores': [batch, max_detections] 'detection_scores': [batch, max_detections]
'detection_classes': [batch, max_detections] 'detection_classes': [batch, max_detections]
'detection_masks': [batch, max_detections, mask_height, mask_width]
(optional).
'num_detections': [batch] 'num_detections': [batch]
""" """
label_id_offset = 1 label_id_offset = 1
boxes = postprocessed_tensors.get('detection_boxes') boxes = postprocessed_tensors.get('detection_boxes')
scores = postprocessed_tensors.get('detection_scores') scores = postprocessed_tensors.get('detection_scores')
classes = postprocessed_tensors.get('detection_classes') + label_id_offset classes = postprocessed_tensors.get('detection_classes') + label_id_offset
masks = postprocessed_tensors.get('detection_masks')
num_detections = postprocessed_tensors.get('num_detections') num_detections = postprocessed_tensors.get('num_detections')
tf.identity(boxes, name='detection_boxes') tf.identity(boxes, name='detection_boxes')
tf.identity(scores, name='detection_scores') tf.identity(scores, name='detection_scores')
tf.identity(classes, name='detection_classes') tf.identity(classes, name='detection_classes')
tf.identity(num_detections, name='num_detections') tf.identity(num_detections, name='num_detections')
if masks is not None:
tf.identity(masks, name='detection_masks')
def _write_inference_graph(inference_graph_path, def _write_inference_graph(inference_graph_path,
...@@ -201,6 +221,7 @@ def _export_inference_graph(input_type, ...@@ -201,6 +221,7 @@ def _export_inference_graph(input_type,
use_moving_averages, use_moving_averages,
checkpoint_path, checkpoint_path,
inference_graph_path): inference_graph_path):
"""Export helper."""
if input_type not in input_placeholder_fn_map: if input_type not in input_placeholder_fn_map:
raise ValueError('Unknown input type: {}'.format(input_type)) raise ValueError('Unknown input type: {}'.format(input_type))
inputs = tf.to_float(input_placeholder_fn_map[input_type]()) inputs = tf.to_float(input_placeholder_fn_map[input_type]())
...@@ -208,8 +229,13 @@ def _export_inference_graph(input_type, ...@@ -208,8 +229,13 @@ def _export_inference_graph(input_type,
output_tensors = detection_model.predict(preprocessed_inputs) output_tensors = detection_model.predict(preprocessed_inputs)
postprocessed_tensors = detection_model.postprocess(output_tensors) postprocessed_tensors = detection_model.postprocess(output_tensors)
_add_output_tensor_nodes(postprocessed_tensors) _add_output_tensor_nodes(postprocessed_tensors)
out_node_names = ['num_detections', 'detection_scores',
'detection_boxes', 'detection_classes']
if 'detection_masks' in postprocessed_tensors:
out_node_names.append('detection_masks')
_write_inference_graph(inference_graph_path, checkpoint_path, _write_inference_graph(inference_graph_path, checkpoint_path,
use_moving_averages) use_moving_averages,
output_node_names=','.join(out_node_names))
def export_inference_graph(input_type, pipeline_config, checkpoint_path, def export_inference_graph(input_type, pipeline_config, checkpoint_path,
......
...@@ -26,24 +26,28 @@ from object_detection.protos import pipeline_pb2 ...@@ -26,24 +26,28 @@ from object_detection.protos import pipeline_pb2
class FakeModel(model.DetectionModel): class FakeModel(model.DetectionModel):
def __init__(self, add_detection_masks=False):
  """Creates the fake detection model.

  Args:
    add_detection_masks: if True, postprocess() also includes a
      'detection_masks' entry in its output dictionary.
  """
  # NOTE(review): super().__init__() is not called here; confirm the
  # DetectionModel base class does not require constructor arguments.
  self._add_detection_masks = add_detection_masks
def preprocess(self, inputs): def preprocess(self, inputs):
return (tf.identity(inputs) * return tf.identity(inputs)
tf.get_variable('dummy', shape=(),
initializer=tf.constant_initializer(2),
dtype=tf.float32))
def predict(self, preprocessed_inputs): def predict(self, preprocessed_inputs):
return {'image': tf.identity(preprocessed_inputs)} return {'image': tf.layers.conv2d(preprocessed_inputs, 3, 1)}
def postprocess(self, prediction_dict): def postprocess(self, prediction_dict):
with tf.control_dependencies(prediction_dict.values()): with tf.control_dependencies(prediction_dict.values()):
return { postprocessed_tensors = {
'detection_boxes': tf.constant([[0.0, 0.0, 0.5, 0.5], 'detection_boxes': tf.constant([[0.0, 0.0, 0.5, 0.5],
[0.5, 0.5, 0.8, 0.8]], tf.float32), [0.5, 0.5, 0.8, 0.8]], tf.float32),
'detection_scores': tf.constant([[0.7, 0.6]], tf.float32), 'detection_scores': tf.constant([[0.7, 0.6]], tf.float32),
'detection_classes': tf.constant([[0, 1]], tf.float32), 'detection_classes': tf.constant([[0, 1]], tf.float32),
'num_detections': tf.constant([2], tf.float32) 'num_detections': tf.constant([2], tf.float32)
} }
if self._add_detection_masks:
postprocessed_tensors['detection_masks'] = tf.constant(
np.arange(32).reshape([2, 4, 4]), tf.float32)
return postprocessed_tensors
def restore_fn(self, checkpoint_path, from_detection_checkpoint): def restore_fn(self, checkpoint_path, from_detection_checkpoint):
pass pass
...@@ -58,8 +62,11 @@ class ExportInferenceGraphTest(tf.test.TestCase): ...@@ -58,8 +62,11 @@ class ExportInferenceGraphTest(tf.test.TestCase):
use_moving_averages): use_moving_averages):
g = tf.Graph() g = tf.Graph()
with g.as_default(): with g.as_default():
mock_model = FakeModel(num_classes=1) mock_model = FakeModel()
mock_model.preprocess(tf.constant([1, 3, 4, 3], tf.float32)) preprocessed_inputs = mock_model.preprocess(
tf.ones([1, 3, 4, 3], tf.float32))
predictions = mock_model.predict(preprocessed_inputs)
mock_model.postprocess(predictions)
if use_moving_averages: if use_moving_averages:
tf.train.ExponentialMovingAverage(0.0).apply() tf.train.ExponentialMovingAverage(0.0).apply()
saver = tf.train.Saver() saver = tf.train.Saver()
...@@ -93,7 +100,7 @@ class ExportInferenceGraphTest(tf.test.TestCase): ...@@ -93,7 +100,7 @@ class ExportInferenceGraphTest(tf.test.TestCase):
def test_export_graph_with_image_tensor_input(self): def test_export_graph_with_image_tensor_input(self):
with mock.patch.object( with mock.patch.object(
model_builder, 'build', autospec=True) as mock_builder: model_builder, 'build', autospec=True) as mock_builder:
mock_builder.return_value = FakeModel(num_classes=1) mock_builder.return_value = FakeModel()
inference_graph_path = os.path.join(self.get_temp_dir(), inference_graph_path = os.path.join(self.get_temp_dir(),
'exported_graph.pbtxt') 'exported_graph.pbtxt')
...@@ -108,7 +115,7 @@ class ExportInferenceGraphTest(tf.test.TestCase): ...@@ -108,7 +115,7 @@ class ExportInferenceGraphTest(tf.test.TestCase):
def test_export_graph_with_tf_example_input(self): def test_export_graph_with_tf_example_input(self):
with mock.patch.object( with mock.patch.object(
model_builder, 'build', autospec=True) as mock_builder: model_builder, 'build', autospec=True) as mock_builder:
mock_builder.return_value = FakeModel(num_classes=1) mock_builder.return_value = FakeModel()
inference_graph_path = os.path.join(self.get_temp_dir(), inference_graph_path = os.path.join(self.get_temp_dir(),
'exported_graph.pbtxt') 'exported_graph.pbtxt')
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
...@@ -119,6 +126,20 @@ class ExportInferenceGraphTest(tf.test.TestCase): ...@@ -119,6 +126,20 @@ class ExportInferenceGraphTest(tf.test.TestCase):
checkpoint_path=None, checkpoint_path=None,
inference_graph_path=inference_graph_path) inference_graph_path=inference_graph_path)
def test_export_graph_with_encoded_image_string_input(self):
  """Exports a graph whose input node is an encoded image string."""
  output_path = os.path.join(self.get_temp_dir(), 'exported_graph.pbtxt')
  config = pipeline_pb2.TrainEvalPipelineConfig()
  config.eval_config.use_moving_averages = False
  with mock.patch.object(
      model_builder, 'build', autospec=True) as mock_builder:
    mock_builder.return_value = FakeModel()
    exporter.export_inference_graph(
        input_type='encoded_image_string_tensor',
        pipeline_config=config,
        checkpoint_path=None,
        inference_graph_path=output_path)
def test_export_frozen_graph(self): def test_export_frozen_graph(self):
checkpoint_path = os.path.join(self.get_temp_dir(), 'model-ckpt') checkpoint_path = os.path.join(self.get_temp_dir(), 'model-ckpt')
self._save_checkpoint_from_mock_model(checkpoint_path, self._save_checkpoint_from_mock_model(checkpoint_path,
...@@ -127,7 +148,7 @@ class ExportInferenceGraphTest(tf.test.TestCase): ...@@ -127,7 +148,7 @@ class ExportInferenceGraphTest(tf.test.TestCase):
'exported_graph.pb') 'exported_graph.pb')
with mock.patch.object( with mock.patch.object(
model_builder, 'build', autospec=True) as mock_builder: model_builder, 'build', autospec=True) as mock_builder:
mock_builder.return_value = FakeModel(num_classes=1) mock_builder.return_value = FakeModel()
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
pipeline_config.eval_config.use_moving_averages = False pipeline_config.eval_config.use_moving_averages = False
exporter.export_inference_graph( exporter.export_inference_graph(
...@@ -144,7 +165,7 @@ class ExportInferenceGraphTest(tf.test.TestCase): ...@@ -144,7 +165,7 @@ class ExportInferenceGraphTest(tf.test.TestCase):
'exported_graph.pb') 'exported_graph.pb')
with mock.patch.object( with mock.patch.object(
model_builder, 'build', autospec=True) as mock_builder: model_builder, 'build', autospec=True) as mock_builder:
mock_builder.return_value = FakeModel(num_classes=1) mock_builder.return_value = FakeModel()
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
pipeline_config.eval_config.use_moving_averages = True pipeline_config.eval_config.use_moving_averages = True
exporter.export_inference_graph( exporter.export_inference_graph(
...@@ -153,6 +174,55 @@ class ExportInferenceGraphTest(tf.test.TestCase): ...@@ -153,6 +174,55 @@ class ExportInferenceGraphTest(tf.test.TestCase):
checkpoint_path=checkpoint_path, checkpoint_path=checkpoint_path,
inference_graph_path=inference_graph_path) inference_graph_path=inference_graph_path)
def test_export_model_with_all_output_nodes(self):
  """Verifies every output tensor, including masks, is in the exported graph."""
  ckpt_path = os.path.join(self.get_temp_dir(), 'model-ckpt')
  self._save_checkpoint_from_mock_model(ckpt_path,
                                        use_moving_averages=False)
  graph_path = os.path.join(self.get_temp_dir(), 'exported_graph.pb')
  with mock.patch.object(
      model_builder, 'build', autospec=True) as mock_builder:
    mock_builder.return_value = FakeModel(add_detection_masks=True)
    exporter.export_inference_graph(
        input_type='image_tensor',
        pipeline_config=pipeline_pb2.TrainEvalPipelineConfig(),
        checkpoint_path=ckpt_path,
        inference_graph_path=graph_path)
  graph = self._load_inference_graph(graph_path)
  with self.test_session(graph=graph):
    # get_tensor_by_name raises KeyError if the node is missing.
    for tensor_name in ['image_tensor:0', 'detection_boxes:0',
                        'detection_scores:0', 'detection_classes:0',
                        'detection_masks:0', 'num_detections:0']:
      graph.get_tensor_by_name(tensor_name)
def test_export_model_with_detection_only_nodes(self):
  """Verifies detection tensors exist and masks are absent when disabled."""
  ckpt_path = os.path.join(self.get_temp_dir(), 'model-ckpt')
  self._save_checkpoint_from_mock_model(ckpt_path,
                                        use_moving_averages=False)
  graph_path = os.path.join(self.get_temp_dir(), 'exported_graph.pb')
  with mock.patch.object(
      model_builder, 'build', autospec=True) as mock_builder:
    mock_builder.return_value = FakeModel(add_detection_masks=False)
    exporter.export_inference_graph(
        input_type='image_tensor',
        pipeline_config=pipeline_pb2.TrainEvalPipelineConfig(),
        checkpoint_path=ckpt_path,
        inference_graph_path=graph_path)
  graph = self._load_inference_graph(graph_path)
  with self.test_session(graph=graph):
    for tensor_name in ['image_tensor:0', 'detection_boxes:0',
                        'detection_scores:0', 'detection_classes:0',
                        'num_detections:0']:
      graph.get_tensor_by_name(tensor_name)
    # The model was built without masks, so the node must not exist.
    with self.assertRaises(KeyError):
      graph.get_tensor_by_name('detection_masks:0')
def test_export_and_run_inference_with_image_tensor(self): def test_export_and_run_inference_with_image_tensor(self):
checkpoint_path = os.path.join(self.get_temp_dir(), 'model-ckpt') checkpoint_path = os.path.join(self.get_temp_dir(), 'model-ckpt')
self._save_checkpoint_from_mock_model(checkpoint_path, self._save_checkpoint_from_mock_model(checkpoint_path,
...@@ -161,7 +231,7 @@ class ExportInferenceGraphTest(tf.test.TestCase): ...@@ -161,7 +231,7 @@ class ExportInferenceGraphTest(tf.test.TestCase):
'exported_graph.pb') 'exported_graph.pb')
with mock.patch.object( with mock.patch.object(
model_builder, 'build', autospec=True) as mock_builder: model_builder, 'build', autospec=True) as mock_builder:
mock_builder.return_value = FakeModel(num_classes=1) mock_builder.return_value = FakeModel(add_detection_masks=True)
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
pipeline_config.eval_config.use_moving_averages = False pipeline_config.eval_config.use_moving_averages = False
exporter.export_inference_graph( exporter.export_inference_graph(
...@@ -176,16 +246,72 @@ class ExportInferenceGraphTest(tf.test.TestCase): ...@@ -176,16 +246,72 @@ class ExportInferenceGraphTest(tf.test.TestCase):
boxes = inference_graph.get_tensor_by_name('detection_boxes:0') boxes = inference_graph.get_tensor_by_name('detection_boxes:0')
scores = inference_graph.get_tensor_by_name('detection_scores:0') scores = inference_graph.get_tensor_by_name('detection_scores:0')
classes = inference_graph.get_tensor_by_name('detection_classes:0') classes = inference_graph.get_tensor_by_name('detection_classes:0')
masks = inference_graph.get_tensor_by_name('detection_masks:0')
num_detections = inference_graph.get_tensor_by_name('num_detections:0') num_detections = inference_graph.get_tensor_by_name('num_detections:0')
(boxes, scores, classes, num_detections) = sess.run( (boxes, scores, classes, masks, num_detections) = sess.run(
[boxes, scores, classes, num_detections], [boxes, scores, classes, masks, num_detections],
feed_dict={image_tensor: np.ones((1, 4, 4, 3)).astype(np.uint8)}) feed_dict={image_tensor: np.ones((1, 4, 4, 3)).astype(np.uint8)})
self.assertAllClose(boxes, [[0.0, 0.0, 0.5, 0.5], self.assertAllClose(boxes, [[0.0, 0.0, 0.5, 0.5],
[0.5, 0.5, 0.8, 0.8]]) [0.5, 0.5, 0.8, 0.8]])
self.assertAllClose(scores, [[0.7, 0.6]]) self.assertAllClose(scores, [[0.7, 0.6]])
self.assertAllClose(classes, [[1, 2]]) self.assertAllClose(classes, [[1, 2]])
self.assertAllClose(masks, np.arange(32).reshape([2, 4, 4]))
self.assertAllClose(num_detections, [2]) self.assertAllClose(num_detections, [2])
def _create_encoded_image_string(self, image_array_np, encoding_format):
  """Encodes a numpy image array as a JPEG or PNG string.

  Args:
    image_array_np: uint8 numpy array of shape [height, width, 3].
    encoding_format: either 'jpg' or 'png'.

  Returns:
    The encoded image bytes.

  Raises:
    ValueError: if encoding_format is not 'jpg' or 'png'.
  """
  encoders = {'jpg': tf.image.encode_jpeg, 'png': tf.image.encode_png}
  if encoding_format not in encoders:
    raise ValueError('Supports only the following formats: `jpg`, `png`')
  encode_graph = tf.Graph()
  with encode_graph.as_default():
    encoded_string = encoders[encoding_format](image_array_np)
  with self.test_session(graph=encode_graph):
    return encoded_string.eval()
def test_export_and_run_inference_with_encoded_image_string_tensor(self):
  """Exports a graph and runs inference with JPEG and PNG string inputs."""
  ckpt_path = os.path.join(self.get_temp_dir(), 'model-ckpt')
  self._save_checkpoint_from_mock_model(ckpt_path,
                                        use_moving_averages=False)
  graph_path = os.path.join(self.get_temp_dir(), 'exported_graph.pb')
  config = pipeline_pb2.TrainEvalPipelineConfig()
  config.eval_config.use_moving_averages = False
  with mock.patch.object(
      model_builder, 'build', autospec=True) as mock_builder:
    mock_builder.return_value = FakeModel(add_detection_masks=True)
    exporter.export_inference_graph(
        input_type='encoded_image_string_tensor',
        pipeline_config=config,
        checkpoint_path=ckpt_path,
        inference_graph_path=graph_path)
  graph = self._load_inference_graph(graph_path)
  image = np.ones((4, 4, 3)).astype(np.uint8)
  # Exercise both supported encodings through the same input node.
  encoded_images = [self._create_encoded_image_string(image, fmt)
                    for fmt in ('jpg', 'png')]
  with self.test_session(graph=graph) as sess:
    input_tensor = graph.get_tensor_by_name('encoded_image_string_tensor:0')
    fetches = [graph.get_tensor_by_name(name)
               for name in ['detection_boxes:0', 'detection_scores:0',
                            'detection_classes:0', 'detection_masks:0',
                            'num_detections:0']]
    for encoded in encoded_images:
      (boxes_np, scores_np, classes_np, masks_np,
       num_detections_np) = sess.run(
           fetches, feed_dict={input_tensor: encoded})
      self.assertAllClose(boxes_np, [[0.0, 0.0, 0.5, 0.5],
                                     [0.5, 0.5, 0.8, 0.8]])
      self.assertAllClose(scores_np, [[0.7, 0.6]])
      self.assertAllClose(classes_np, [[1, 2]])
      self.assertAllClose(masks_np, np.arange(32).reshape([2, 4, 4]))
      self.assertAllClose(num_detections_np, [2])
def test_export_and_run_inference_with_tf_example(self): def test_export_and_run_inference_with_tf_example(self):
checkpoint_path = os.path.join(self.get_temp_dir(), 'model-ckpt') checkpoint_path = os.path.join(self.get_temp_dir(), 'model-ckpt')
self._save_checkpoint_from_mock_model(checkpoint_path, self._save_checkpoint_from_mock_model(checkpoint_path,
...@@ -194,7 +320,7 @@ class ExportInferenceGraphTest(tf.test.TestCase): ...@@ -194,7 +320,7 @@ class ExportInferenceGraphTest(tf.test.TestCase):
'exported_graph.pb') 'exported_graph.pb')
with mock.patch.object( with mock.patch.object(
model_builder, 'build', autospec=True) as mock_builder: model_builder, 'build', autospec=True) as mock_builder:
mock_builder.return_value = FakeModel(num_classes=1) mock_builder.return_value = FakeModel(add_detection_masks=True)
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
pipeline_config.eval_config.use_moving_averages = False pipeline_config.eval_config.use_moving_averages = False
exporter.export_inference_graph( exporter.export_inference_graph(
...@@ -209,15 +335,17 @@ class ExportInferenceGraphTest(tf.test.TestCase): ...@@ -209,15 +335,17 @@ class ExportInferenceGraphTest(tf.test.TestCase):
boxes = inference_graph.get_tensor_by_name('detection_boxes:0') boxes = inference_graph.get_tensor_by_name('detection_boxes:0')
scores = inference_graph.get_tensor_by_name('detection_scores:0') scores = inference_graph.get_tensor_by_name('detection_scores:0')
classes = inference_graph.get_tensor_by_name('detection_classes:0') classes = inference_graph.get_tensor_by_name('detection_classes:0')
masks = inference_graph.get_tensor_by_name('detection_masks:0')
num_detections = inference_graph.get_tensor_by_name('num_detections:0') num_detections = inference_graph.get_tensor_by_name('num_detections:0')
(boxes, scores, classes, num_detections) = sess.run( (boxes, scores, classes, masks, num_detections) = sess.run(
[boxes, scores, classes, num_detections], [boxes, scores, classes, masks, num_detections],
feed_dict={tf_example: self._create_tf_example( feed_dict={tf_example: self._create_tf_example(
np.ones((4, 4, 3)).astype(np.uint8))}) np.ones((4, 4, 3)).astype(np.uint8))})
self.assertAllClose(boxes, [[0.0, 0.0, 0.5, 0.5], self.assertAllClose(boxes, [[0.0, 0.0, 0.5, 0.5],
[0.5, 0.5, 0.8, 0.8]]) [0.5, 0.5, 0.8, 0.8]])
self.assertAllClose(scores, [[0.7, 0.6]]) self.assertAllClose(scores, [[0.7, 0.6]])
self.assertAllClose(classes, [[1, 2]]) self.assertAllClose(classes, [[1, 2]])
self.assertAllClose(masks, np.arange(32).reshape([2, 4, 4]))
self.assertAllClose(num_detections, [2]) self.assertAllClose(num_detections, [2])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment