Commit 2e9bb539 authored by stephenwu's avatar stephenwu
Browse files

Merge branch 'master' of https://github.com/tensorflow/models into RTESuperGLUE

parents 7bae5317 8fba84f8
...@@ -1018,6 +1018,67 @@ class ConfigUtilTest(tf.test.TestCase): ...@@ -1018,6 +1018,67 @@ class ConfigUtilTest(tf.test.TestCase):
output_dict, output_dict,
config_util.remove_unecessary_ema(input_dict, no_ema_collection)) config_util.remove_unecessary_ema(input_dict, no_ema_collection))
def testUpdateRescoreInstances(self):
pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
kpt_task = pipeline_config.model.center_net.keypoint_estimation_task.add()
kpt_task.rescore_instances = True
_write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
cn_config = configs["model"].center_net
self.assertEqual(
True, cn_config.keypoint_estimation_task[0].rescore_instances)
config_util.merge_external_params_with_configs(
configs, kwargs_dict={"rescore_instances": False})
cn_config = configs["model"].center_net
self.assertEqual(
False, cn_config.keypoint_estimation_task[0].rescore_instances)
def testUpdateRescoreInstancesWithBooleanString(self):
pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
kpt_task = pipeline_config.model.center_net.keypoint_estimation_task.add()
kpt_task.rescore_instances = True
_write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
cn_config = configs["model"].center_net
self.assertEqual(
True, cn_config.keypoint_estimation_task[0].rescore_instances)
config_util.merge_external_params_with_configs(
configs, kwargs_dict={"rescore_instances": "False"})
cn_config = configs["model"].center_net
self.assertEqual(
False, cn_config.keypoint_estimation_task[0].rescore_instances)
def testUpdateRescoreInstancesWithMultipleTasks(self):
pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
kpt_task = pipeline_config.model.center_net.keypoint_estimation_task.add()
kpt_task.rescore_instances = True
kpt_task = pipeline_config.model.center_net.keypoint_estimation_task.add()
kpt_task.rescore_instances = True
_write_config(pipeline_config, pipeline_config_path)
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
cn_config = configs["model"].center_net
self.assertEqual(
True, cn_config.keypoint_estimation_task[0].rescore_instances)
config_util.merge_external_params_with_configs(
configs, kwargs_dict={"rescore_instances": False})
cn_config = configs["model"].center_net
self.assertEqual(
True, cn_config.keypoint_estimation_task[0].rescore_instances)
self.assertEqual(
True, cn_config.keypoint_estimation_task[1].rescore_instances)
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment