Unverified Commit 420a7253 authored by pkulzc's avatar pkulzc Committed by GitHub
Browse files

Refactor tests for Object Detection API. (#8688)

Internal changes

--

PiperOrigin-RevId: 316837667
parent d0ef3913
...@@ -24,10 +24,18 @@ import six ...@@ -24,10 +24,18 @@ import six
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from object_detection.dataset_tools import seq_example_util from object_detection.dataset_tools import seq_example_util
from object_detection.utils import tf_version
class SeqExampleUtilTest(tf.test.TestCase): class SeqExampleUtilTest(tf.test.TestCase):
def materialize_tensors(self, list_of_tensors):
if tf_version.is_tf2():
return [tensor.numpy() for tensor in list_of_tensors]
else:
with self.cached_session() as sess:
return sess.run(list_of_tensors)
def test_make_unlabeled_example(self): def test_make_unlabeled_example(self):
num_frames = 5 num_frames = 5
image_height = 100 image_height = 100
...@@ -41,8 +49,7 @@ class SeqExampleUtilTest(tf.test.TestCase): ...@@ -41,8 +49,7 @@ class SeqExampleUtilTest(tf.test.TestCase):
image_source_ids = [str(idx) for idx in range(num_frames)] image_source_ids = [str(idx) for idx in range(num_frames)]
images_list = tf.unstack(images, axis=0) images_list = tf.unstack(images, axis=0)
encoded_images_list = [tf.io.encode_jpeg(image) for image in images_list] encoded_images_list = [tf.io.encode_jpeg(image) for image in images_list]
with tf.Session() as sess: encoded_images = self.materialize_tensors(encoded_images_list)
encoded_images = sess.run(encoded_images_list)
seq_example = seq_example_util.make_sequence_example( seq_example = seq_example_util.make_sequence_example(
dataset_name=dataset_name, dataset_name=dataset_name,
video_id=video_id, video_id=video_id,
...@@ -109,8 +116,7 @@ class SeqExampleUtilTest(tf.test.TestCase): ...@@ -109,8 +116,7 @@ class SeqExampleUtilTest(tf.test.TestCase):
dtype=tf.int32), dtype=tf.uint8) dtype=tf.int32), dtype=tf.uint8)
images_list = tf.unstack(images, axis=0) images_list = tf.unstack(images, axis=0)
encoded_images_list = [tf.io.encode_jpeg(image) for image in images_list] encoded_images_list = [tf.io.encode_jpeg(image) for image in images_list]
with tf.Session() as sess: encoded_images = self.materialize_tensors(encoded_images_list)
encoded_images = sess.run(encoded_images_list)
timestamps = [100000, 110000] timestamps = [100000, 110000]
is_annotated = [1, 0] is_annotated = [1, 0]
bboxes = [ bboxes = [
...@@ -208,8 +214,7 @@ class SeqExampleUtilTest(tf.test.TestCase): ...@@ -208,8 +214,7 @@ class SeqExampleUtilTest(tf.test.TestCase):
dtype=tf.int32), dtype=tf.uint8) dtype=tf.int32), dtype=tf.uint8)
images_list = tf.unstack(images, axis=0) images_list = tf.unstack(images, axis=0)
encoded_images_list = [tf.io.encode_jpeg(image) for image in images_list] encoded_images_list = [tf.io.encode_jpeg(image) for image in images_list]
with tf.Session() as sess: encoded_images = self.materialize_tensors(encoded_images_list)
encoded_images = sess.run(encoded_images_list)
bboxes = [ bboxes = [
np.array([[0., 0., 0.75, 0.75], np.array([[0., 0., 0.75, 0.75],
[0., 0., 1., 1.]], dtype=np.float32), [0., 0., 1., 1.]], dtype=np.float32),
......
...@@ -52,6 +52,8 @@ EVAL_METRICS_CLASS_DICT = { ...@@ -52,6 +52,8 @@ EVAL_METRICS_CLASS_DICT = {
coco_evaluation.CocoKeypointEvaluator, coco_evaluation.CocoKeypointEvaluator,
'coco_mask_metrics': 'coco_mask_metrics':
coco_evaluation.CocoMaskEvaluator, coco_evaluation.CocoMaskEvaluator,
'coco_panoptic_metrics':
coco_evaluation.CocoPanopticSegmentationEvaluator,
'oid_challenge_detection_metrics': 'oid_challenge_detection_metrics':
object_detection_evaluation.OpenImagesDetectionChallengeEvaluator, object_detection_evaluation.OpenImagesDetectionChallengeEvaluator,
'oid_challenge_segmentation_metrics': 'oid_challenge_segmentation_metrics':
......
This diff is collapsed.
...@@ -24,16 +24,19 @@ import tensorflow.compat.v1 as tf ...@@ -24,16 +24,19 @@ import tensorflow.compat.v1 as tf
from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import types_pb2 from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import saver_pb2 from tensorflow.core.protobuf import saver_pb2
from tensorflow.tools.graph_transforms import TransformGraph
from object_detection import exporter from object_detection import exporter
from object_detection.builders import graph_rewriter_builder from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder from object_detection.builders import model_builder
from object_detection.builders import post_processing_builder from object_detection.builders import post_processing_builder
from object_detection.core import box_list from object_detection.core import box_list
from object_detection.utils import tf_version
_DEFAULT_NUM_CHANNELS = 3 _DEFAULT_NUM_CHANNELS = 3
_DEFAULT_NUM_COORD_BOX = 4 _DEFAULT_NUM_COORD_BOX = 4
if tf_version.is_tf1():
from tensorflow.tools.graph_transforms import TransformGraph # pylint: disable=g-import-not-at-top
def get_const_center_size_encoded_anchors(anchors): def get_const_center_size_encoded_anchors(anchors):
"""Exports center-size encoded anchors as a constant tensor. """Exports center-size encoded anchors as a constant tensor.
......
...@@ -18,6 +18,7 @@ from __future__ import absolute_import ...@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import unittest
import numpy as np import numpy as np
import six import six
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
...@@ -32,6 +33,7 @@ from object_detection.core import model ...@@ -32,6 +33,7 @@ from object_detection.core import model
from object_detection.protos import graph_rewriter_pb2 from object_detection.protos import graph_rewriter_pb2
from object_detection.protos import pipeline_pb2 from object_detection.protos import pipeline_pb2
from object_detection.protos import post_processing_pb2 from object_detection.protos import post_processing_pb2
from object_detection.utils import tf_version
# pylint: disable=g-import-not-at-top # pylint: disable=g-import-not-at-top
...@@ -82,6 +84,7 @@ class FakeModel(model.DetectionModel): ...@@ -82,6 +84,7 @@ class FakeModel(model.DetectionModel):
pass pass
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class ExportTfliteGraphTest(tf.test.TestCase): class ExportTfliteGraphTest(tf.test.TestCase):
def _save_checkpoint_from_mock_model(self, def _save_checkpoint_from_mock_model(self,
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment