Unverified Commit 420a7253 authored by pkulzc's avatar pkulzc Committed by GitHub
Browse files

Refactor tests for Object Detection API. (#8688)

Internal changes

--

PiperOrigin-RevId: 316837667
parent d0ef3913
......@@ -20,13 +20,15 @@ from __future__ import division
from __future__ import print_function
import os
import unittest
import tensorflow.compat.v1 as tf
from object_detection.utils import test_case
from object_detection.utils import tf_version
from object_detection.utils import variables_helper
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class FilterVariablesTest(test_case.TestCase):
def _create_variables(self):
......@@ -68,6 +70,7 @@ class FilterVariablesTest(test_case.TestCase):
self.assertCountEqual(out_variables, [variables[1], variables[3]])
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class MultiplyGradientsMatchingRegexTest(tf.test.TestCase):
def _create_grads_and_vars(self):
......@@ -107,6 +110,7 @@ class MultiplyGradientsMatchingRegexTest(tf.test.TestCase):
self.assertCountEqual(output, exp_output)
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class FreezeGradientsMatchingRegexTest(test_case.TestCase):
def _create_grads_and_vars(self):
......@@ -132,6 +136,7 @@ class FreezeGradientsMatchingRegexTest(test_case.TestCase):
self.assertCountEqual(output, exp_output)
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class GetVariablesAvailableInCheckpointTest(test_case.TestCase):
def test_return_all_variables_from_checkpoint(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment