Commit a6df5573 authored by derekjchow's avatar derekjchow Committed by GitHub
Browse files

Merge pull request #1758 from bmabey/py3-compat-object_detection-eval

updates object_detection/eval.py to be python3 compatible
parents dac4fbab 65f0fa2e
...@@ -154,7 +154,7 @@ def evaluate(create_input_dict_fn, create_model_fn, eval_config, categories, ...@@ -154,7 +154,7 @@ def evaluate(create_input_dict_fn, create_model_fn, eval_config, categories,
""" """
if batch_index >= eval_config.num_visualizations: if batch_index >= eval_config.num_visualizations:
if 'original_image' in tensor_dict: if 'original_image' in tensor_dict:
tensor_dict = {k: v for (k, v) in tensor_dict.iteritems() tensor_dict = {k: v for (k, v) in tensor_dict.items()
if k != 'original_image'} if k != 'original_image'}
try: try:
(result_dict, _) = sess.run([tensor_dict, update_op]) (result_dict, _) = sess.run([tensor_dict, update_op])
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
from __future__ import division from __future__ import division
import numpy as np import numpy as np
from six import moves
def compute_precision_recall(scores, labels, num_gt): def compute_precision_recall(scores, labels, num_gt):
...@@ -103,7 +104,7 @@ def compute_average_precision(precision, recall): ...@@ -103,7 +104,7 @@ def compute_average_precision(precision, recall):
raise ValueError("Precision must be in the range of [0, 1].") raise ValueError("Precision must be in the range of [0, 1].")
if np.amin(recall) < 0 or np.amax(recall) > 1: if np.amin(recall) < 0 or np.amax(recall) > 1:
raise ValueError("recall must be in the range of [0, 1].") raise ValueError("recall must be in the range of [0, 1].")
if not all(recall[i] <= recall[i + 1] for i in xrange(len(recall) - 1)): if not all(recall[i] <= recall[i + 1] for i in moves.range(len(recall) - 1)):
raise ValueError("recall must be a non-decreasing array") raise ValueError("recall must be a non-decreasing array")
recall = np.concatenate([[0], recall, [1]]) recall = np.concatenate([[0], recall, [1]])
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
"""Numpy BoxList classes and functions.""" """Numpy BoxList classes and functions."""
import numpy as np import numpy as np
from six import moves
class BoxList(object): class BoxList(object):
...@@ -127,7 +128,7 @@ class BoxList(object): ...@@ -127,7 +128,7 @@ class BoxList(object):
ymin, and all xmax of boxes are equal or greater than xmin. ymin, and all xmax of boxes are equal or greater than xmin.
""" """
if data.shape[0] > 0: if data.shape[0] > 0:
for i in xrange(data.shape[0]): for i in moves.range(data.shape[0]):
if data[i, 0] > data[i, 2] or data[i, 1] > data[i, 3]: if data[i, 0] > data[i, 2] or data[i, 1] > data[i, 3]:
return False return False
return True return True
...@@ -80,7 +80,7 @@ def encode_image_array_as_png_str(image): ...@@ -80,7 +80,7 @@ def encode_image_array_as_png_str(image):
PNG encoded image string. PNG encoded image string.
""" """
image_pil = Image.fromarray(np.uint8(image)) image_pil = Image.fromarray(np.uint8(image))
output = six.StringIO() output = six.BytesIO()
image_pil.save(output, format='PNG') image_pil.save(output, format='PNG')
png_string = output.getvalue() png_string = output.getvalue()
output.close() output.close()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment