Commit b1025b3b authored by syiming's avatar syiming
Browse files

Merge remote-tracking branch 'upstream/master' into fasterrcnn_fpn_keras_feature_extractor

parents 69ce1c45 e9df75ab
...@@ -13,22 +13,21 @@ ...@@ -13,22 +13,21 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Tests for graph_rewriter_builder.""" """Tests for graph_rewriter_builder."""
import unittest
import mock import mock
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
import tf_slim as slim import tf_slim as slim
from object_detection.builders import graph_rewriter_builder from object_detection.builders import graph_rewriter_builder
from object_detection.protos import graph_rewriter_pb2 from object_detection.protos import graph_rewriter_pb2
from object_detection.utils import tf_version
# pylint: disable=g-import-not-at-top
try:
from tensorflow.contrib import quantize as contrib_quantize
except ImportError:
# TF 2.0 doesn't ship with contrib.
pass
# pylint: enable=g-import-not-at-top
if tf_version.is_tf1():
from tensorflow.contrib import quantize as contrib_quantize # pylint: disable=g-import-not-at-top
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class QuantizationBuilderTest(tf.test.TestCase): class QuantizationBuilderTest(tf.test.TestCase):
def testQuantizationBuilderSetsUpCorrectTrainArguments(self): def testQuantizationBuilderSetsUpCorrectTrainArguments(self):
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
"""Tests for input_reader_builder.""" """Tests for input_reader_builder."""
import os import os
import unittest
import numpy as np import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
...@@ -26,6 +27,7 @@ from object_detection.core import standard_fields as fields ...@@ -26,6 +27,7 @@ from object_detection.core import standard_fields as fields
from object_detection.dataset_tools import seq_example_util from object_detection.dataset_tools import seq_example_util
from object_detection.protos import input_reader_pb2 from object_detection.protos import input_reader_pb2
from object_detection.utils import dataset_util from object_detection.utils import dataset_util
from object_detection.utils import tf_version
def _get_labelmap_path(): def _get_labelmap_path():
...@@ -35,6 +37,7 @@ def _get_labelmap_path(): ...@@ -35,6 +37,7 @@ def _get_labelmap_path():
'pet_label_map.pbtxt') 'pet_label_map.pbtxt')
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class InputReaderBuilderTest(tf.test.TestCase): class InputReaderBuilderTest(tf.test.TestCase):
def create_tf_record(self): def create_tf_record(self):
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment