"vscode:/vscode.git/clone" did not exist on "89374d38700a67e000c7dc56efd83d21cf88c8ee"
Commit 7417c721 authored by Xianzhi Du's avatar Xianzhi Du Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 373393602
parent 39f1fc8e
...@@ -39,6 +39,7 @@ class ImageClassificationConfigTest(tf.test.TestCase, parameterized.TestCase): ...@@ -39,6 +39,7 @@ class ImageClassificationConfigTest(tf.test.TestCase, parameterized.TestCase):
self.assertIsInstance(config.task.model, self.assertIsInstance(config.task.model,
exp_cfg.ImageClassificationModel) exp_cfg.ImageClassificationModel)
self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig) self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
config.validate()
config.task.train_data.is_training = None config.task.train_data.is_training = None
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
config.validate() config.validate()
......
...@@ -29,6 +29,7 @@ class MaskRCNNConfigTest(tf.test.TestCase, parameterized.TestCase): ...@@ -29,6 +29,7 @@ class MaskRCNNConfigTest(tf.test.TestCase, parameterized.TestCase):
('fasterrcnn_resnetfpn_coco',), ('fasterrcnn_resnetfpn_coco',),
('maskrcnn_resnetfpn_coco',), ('maskrcnn_resnetfpn_coco',),
('maskrcnn_spinenet_coco',), ('maskrcnn_spinenet_coco',),
('cascadercnn_resnetfpn_coco',),
) )
def test_maskrcnn_configs(self, config_name): def test_maskrcnn_configs(self, config_name):
config = exp_factory.get_exp_config(config_name) config = exp_factory.get_exp_config(config_name)
...@@ -36,6 +37,7 @@ class MaskRCNNConfigTest(tf.test.TestCase, parameterized.TestCase): ...@@ -36,6 +37,7 @@ class MaskRCNNConfigTest(tf.test.TestCase, parameterized.TestCase):
self.assertIsInstance(config.task, exp_cfg.MaskRCNNTask) self.assertIsInstance(config.task, exp_cfg.MaskRCNNTask)
self.assertIsInstance(config.task.model, exp_cfg.MaskRCNN) self.assertIsInstance(config.task.model, exp_cfg.MaskRCNN)
self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig) self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
config.validate()
config.task.train_data.is_training = None config.task.train_data.is_training = None
with self.assertRaisesRegex(KeyError, 'Found inconsistncy between key'): with self.assertRaisesRegex(KeyError, 'Found inconsistncy between key'):
config.validate() config.validate()
......
...@@ -23,19 +23,20 @@ from official.vision import beta ...@@ -23,19 +23,20 @@ from official.vision import beta
from official.vision.beta.configs import retinanet as exp_cfg from official.vision.beta.configs import retinanet as exp_cfg
class MaskRCNNConfigTest(tf.test.TestCase, parameterized.TestCase): class RetinaNetConfigTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters( @parameterized.parameters(
('retinanet_resnetfpn_coco',), ('retinanet_resnetfpn_coco',),
('retinanet_spinenet_coco',), ('retinanet_spinenet_coco',),
('retinanet_spinenet_mobile_coco',), ('retinanet_spinenet_mobile_coco',),
) )
def test_maskrcnn_configs(self, config_name): def test_retinanet_configs(self, config_name):
config = exp_factory.get_exp_config(config_name) config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig) self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task, exp_cfg.RetinaNetTask) self.assertIsInstance(config.task, exp_cfg.RetinaNetTask)
self.assertIsInstance(config.task.model, exp_cfg.RetinaNet) self.assertIsInstance(config.task.model, exp_cfg.RetinaNet)
self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig) self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
config.validate()
config.task.train_data.is_training = None config.task.train_data.is_training = None
with self.assertRaisesRegex(KeyError, 'Found inconsistncy between key'): with self.assertRaisesRegex(KeyError, 'Found inconsistncy between key'):
config.validate() config.validate()
......
...@@ -36,6 +36,7 @@ class ImageSegmentationConfigTest(tf.test.TestCase, parameterized.TestCase): ...@@ -36,6 +36,7 @@ class ImageSegmentationConfigTest(tf.test.TestCase, parameterized.TestCase):
self.assertIsInstance(config.task.model, self.assertIsInstance(config.task.model,
exp_cfg.SemanticSegmentationModel) exp_cfg.SemanticSegmentationModel)
self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig) self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
config.validate()
config.task.train_data.is_training = None config.task.train_data.is_training = None
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
config.validate() config.validate()
......
...@@ -35,6 +35,7 @@ class VideoClassificationConfigTest(tf.test.TestCase, parameterized.TestCase): ...@@ -35,6 +35,7 @@ class VideoClassificationConfigTest(tf.test.TestCase, parameterized.TestCase):
self.assertIsInstance(config.task, exp_cfg.VideoClassificationTask) self.assertIsInstance(config.task, exp_cfg.VideoClassificationTask)
self.assertIsInstance(config.task.model, exp_cfg.VideoClassificationModel) self.assertIsInstance(config.task.model, exp_cfg.VideoClassificationModel)
self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig) self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
config.validate()
config.task.train_data.is_training = None config.task.train_data.is_training = None
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
config.validate() config.validate()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment