Unverified Commit 7cc688ae authored by Yukun Zhu's avatar Yukun Zhu Committed by GitHub
Browse files

Merge pull request #5781 from linuxpolska/master

Improved performance of deeplab inference
parents a6494752 0958967d
...@@ -126,9 +126,10 @@ def main(unused_argv): ...@@ -126,9 +126,10 @@ def main(unused_argv):
eval_scales=FLAGS.inference_scales, eval_scales=FLAGS.inference_scales,
add_flipped_images=FLAGS.add_flipped_images) add_flipped_images=FLAGS.add_flipped_images)
predictions = tf.cast(predictions[common.OUTPUT_TYPE], tf.float32)
# Crop the valid regions from the predictions. # Crop the valid regions from the predictions.
semantic_predictions = tf.slice( semantic_predictions = tf.slice(
predictions[common.OUTPUT_TYPE], predictions,
[0, 0, 0], [0, 0, 0],
[1, resized_image_size[0], resized_image_size[1]]) [1, resized_image_size[0], resized_image_size[1]])
# Resize back the prediction to the original image size. # Resize back the prediction to the original image size.
...@@ -140,7 +141,7 @@ def main(unused_argv): ...@@ -140,7 +141,7 @@ def main(unused_argv):
label_size, label_size,
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
align_corners=True) align_corners=True)
return tf.squeeze(resized_label, 3) return tf.cast(tf.squeeze(resized_label, 3), tf.int32)
semantic_predictions = _resize_label(semantic_predictions, image_size) semantic_predictions = _resize_label(semantic_predictions, image_size)
semantic_predictions = tf.identity(semantic_predictions, name=_OUTPUT_NAME) semantic_predictions = tf.identity(semantic_predictions, name=_OUTPUT_NAME)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment