Unverified Commit 420a7253 authored by pkulzc's avatar pkulzc Committed by GitHub
Browse files

Refactor tests for Object Detection API. (#8688)

Internal changes

--

PiperOrigin-RevId: 316837667
parent d0ef3913
......@@ -24,10 +24,18 @@ import six
import tensorflow.compat.v1 as tf
from object_detection.dataset_tools import seq_example_util
from object_detection.utils import tf_version
class SeqExampleUtilTest(tf.test.TestCase):
def materialize_tensors(self, list_of_tensors):
if tf_version.is_tf2():
return [tensor.numpy() for tensor in list_of_tensors]
else:
with self.cached_session() as sess:
return sess.run(list_of_tensors)
def test_make_unlabeled_example(self):
num_frames = 5
image_height = 100
......@@ -41,8 +49,7 @@ class SeqExampleUtilTest(tf.test.TestCase):
image_source_ids = [str(idx) for idx in range(num_frames)]
images_list = tf.unstack(images, axis=0)
encoded_images_list = [tf.io.encode_jpeg(image) for image in images_list]
with tf.Session() as sess:
encoded_images = sess.run(encoded_images_list)
encoded_images = self.materialize_tensors(encoded_images_list)
seq_example = seq_example_util.make_sequence_example(
dataset_name=dataset_name,
video_id=video_id,
......@@ -109,8 +116,7 @@ class SeqExampleUtilTest(tf.test.TestCase):
dtype=tf.int32), dtype=tf.uint8)
images_list = tf.unstack(images, axis=0)
encoded_images_list = [tf.io.encode_jpeg(image) for image in images_list]
with tf.Session() as sess:
encoded_images = sess.run(encoded_images_list)
encoded_images = self.materialize_tensors(encoded_images_list)
timestamps = [100000, 110000]
is_annotated = [1, 0]
bboxes = [
......@@ -208,8 +214,7 @@ class SeqExampleUtilTest(tf.test.TestCase):
dtype=tf.int32), dtype=tf.uint8)
images_list = tf.unstack(images, axis=0)
encoded_images_list = [tf.io.encode_jpeg(image) for image in images_list]
with tf.Session() as sess:
encoded_images = sess.run(encoded_images_list)
encoded_images = self.materialize_tensors(encoded_images_list)
bboxes = [
np.array([[0., 0., 0.75, 0.75],
[0., 0., 1., 1.]], dtype=np.float32),
......
......@@ -52,6 +52,8 @@ EVAL_METRICS_CLASS_DICT = {
coco_evaluation.CocoKeypointEvaluator,
'coco_mask_metrics':
coco_evaluation.CocoMaskEvaluator,
'coco_panoptic_metrics':
coco_evaluation.CocoPanopticSegmentationEvaluator,
'oid_challenge_detection_metrics':
object_detection_evaluation.OpenImagesDetectionChallengeEvaluator,
'oid_challenge_segmentation_metrics':
......
This diff is collapsed.
......@@ -24,16 +24,19 @@ import tensorflow.compat.v1 as tf
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.tools.graph_transforms import TransformGraph
from object_detection import exporter
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder
from object_detection.builders import post_processing_builder
from object_detection.core import box_list
from object_detection.utils import tf_version
_DEFAULT_NUM_CHANNELS = 3
_DEFAULT_NUM_COORD_BOX = 4
if tf_version.is_tf1():
from tensorflow.tools.graph_transforms import TransformGraph # pylint: disable=g-import-not-at-top
def get_const_center_size_encoded_anchors(anchors):
"""Exports center-size encoded anchors as a constant tensor.
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment