"...git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "a02a5490baab3b4745844b0f0752fe746a0cb7bc"
Commit e76ee699 authored by Derek Chow's avatar Derek Chow
Browse files

Exporter updates.

parent e5de97b4
...@@ -113,14 +113,19 @@ def freeze_graph_with_def_protos( ...@@ -113,14 +113,19 @@ def freeze_graph_with_def_protos(
def _image_tensor_input_placeholder(): def _image_tensor_input_placeholder():
"""Returns input node that accepts a batch of uint8 images.""" """Returns placeholder and input node that accepts a batch of uint8 images."""
return tf.placeholder(dtype=tf.uint8, input_tensor = tf.placeholder(dtype=tf.uint8,
shape=(None, None, None, 3), shape=(None, None, None, 3),
name='image_tensor') name='image_tensor')
return input_tensor, input_tensor
def _tf_example_input_placeholder(): def _tf_example_input_placeholder():
"""Returns input node that accepts a batch of strings with tf examples.""" """Returns input that accepts a batch of strings with tf examples.
Returns:
a tuple of placeholder and input nodes that output decoded images.
"""
batch_tf_example_placeholder = tf.placeholder( batch_tf_example_placeholder = tf.placeholder(
tf.string, shape=[None], name='tf_example') tf.string, shape=[None], name='tf_example')
def decode(tf_example_string_tensor): def decode(tf_example_string_tensor):
...@@ -128,15 +133,20 @@ def _tf_example_input_placeholder(): ...@@ -128,15 +133,20 @@ def _tf_example_input_placeholder():
tf_example_string_tensor) tf_example_string_tensor)
image_tensor = tensor_dict[fields.InputDataFields.image] image_tensor = tensor_dict[fields.InputDataFields.image]
return image_tensor return image_tensor
return tf.map_fn(decode, return (batch_tf_example_placeholder,
elems=batch_tf_example_placeholder, tf.map_fn(decode,
dtype=tf.uint8, elems=batch_tf_example_placeholder,
parallel_iterations=32, dtype=tf.uint8,
back_prop=False) parallel_iterations=32,
back_prop=False))
def _encoded_image_string_tensor_input_placeholder(): def _encoded_image_string_tensor_input_placeholder():
"""Returns input node that accepts a batch of PNG or JPEG strings.""" """Returns input that accepts a batch of PNG or JPEG strings.
Returns:
a tuple of placeholder and input nodes that output decoded images.
"""
batch_image_str_placeholder = tf.placeholder( batch_image_str_placeholder = tf.placeholder(
dtype=tf.string, dtype=tf.string,
shape=[None], shape=[None],
...@@ -146,11 +156,13 @@ def _encoded_image_string_tensor_input_placeholder(): ...@@ -146,11 +156,13 @@ def _encoded_image_string_tensor_input_placeholder():
channels=3) channels=3)
image_tensor.set_shape((None, None, 3)) image_tensor.set_shape((None, None, 3))
return image_tensor return image_tensor
return tf.map_fn(decode, return (batch_image_str_placeholder,
elems=batch_image_str_placeholder, tf.map_fn(
dtype=tf.uint8, decode,
parallel_iterations=32, elems=batch_image_str_placeholder,
back_prop=False) dtype=tf.uint8,
parallel_iterations=32,
back_prop=False))
input_placeholder_fn_map = { input_placeholder_fn_map = {
...@@ -262,7 +274,7 @@ def _write_saved_model(saved_model_path, ...@@ -262,7 +274,7 @@ def _write_saved_model(saved_model_path,
builder.add_meta_graph_and_variables( builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING], sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={ signature_def_map={
'signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY': signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
detection_signature, detection_signature,
}, },
) )
...@@ -300,7 +312,8 @@ def _export_inference_graph(input_type, ...@@ -300,7 +312,8 @@ def _export_inference_graph(input_type,
if input_type not in input_placeholder_fn_map: if input_type not in input_placeholder_fn_map:
raise ValueError('Unknown input type: {}'.format(input_type)) raise ValueError('Unknown input type: {}'.format(input_type))
inputs = tf.to_float(input_placeholder_fn_map[input_type]()) placeholder_tensor, input_tensors = input_placeholder_fn_map[input_type]()
inputs = tf.to_float(input_tensors)
preprocessed_inputs = detection_model.preprocess(inputs) preprocessed_inputs = detection_model.preprocess(inputs)
output_tensors = detection_model.predict(preprocessed_inputs) output_tensors = detection_model.predict(preprocessed_inputs)
postprocessed_tensors = detection_model.postprocess(output_tensors) postprocessed_tensors = detection_model.postprocess(output_tensors)
...@@ -333,7 +346,8 @@ def _export_inference_graph(input_type, ...@@ -333,7 +346,8 @@ def _export_inference_graph(input_type,
optimize_graph=optimize_graph, optimize_graph=optimize_graph,
initializer_nodes='') initializer_nodes='')
_write_frozen_graph(frozen_graph_path, frozen_graph_def) _write_frozen_graph(frozen_graph_path, frozen_graph_def)
_write_saved_model(saved_model_path, frozen_graph_def, inputs, outputs) _write_saved_model(saved_model_path, frozen_graph_def, placeholder_tensor,
outputs)
def export_inference_graph(input_type, def export_inference_graph(input_type,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment