Commit d933decc authored by Yoni Ben-Meshulam's avatar Yoni Ben-Meshulam Committed by TF Object Detection Team
Browse files

Support for config override of sample_from_dataset_weights.

PiperOrigin-RevId: 345586686
parent b4499b90
......@@ -621,6 +621,8 @@ def _maybe_update_config_with_key_value(configs, key, value):
value)
elif field_name == "num_classes":
_update_num_classes(configs["model"], value)
elif field_name == "sample_from_datasets_weights":
_update_sample_from_datasets_weights(configs["train_input_config"], value)
else:
return False
return True
......@@ -1073,3 +1075,17 @@ def _update_num_classes(model_config, num_classes):
model_config.faster_rcnn.num_classes = num_classes
if meta_architecture == "ssd":
model_config.ssd.num_classes = num_classes
def _update_sample_from_datasets_weights(input_reader_config, weights):
"""Updated sample_from_datasets_weights with overrides."""
if len(weights) != len(input_reader_config.sample_from_datasets_weights):
raise ValueError(
"sample_from_datasets_weights override has a different number of values"
" ({}) than the configured dataset weights ({})."
.format(
len(input_reader_config.sample_from_datasets_weights),
len(weights)))
del input_reader_config.sample_from_datasets_weights[:]
input_reader_config.sample_from_datasets_weights.extend(weights)
......@@ -377,6 +377,45 @@ class ConfigUtilTest(tf.test.TestCase):
new_batch_size = configs["train_config"].batch_size
self.assertEqual(10, new_batch_size)
@unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
def testOverwriteSampleFromDatasetWeights(self):
  """Tests config override for sample_from_datasets_weights."""
  # Build a pipeline config whose train input reader samples from two
  # datasets (weights [1, 2]) and serialize it to a temp file.
  original = pipeline_pb2.TrainEvalPipelineConfig()
  original.train_input_reader.sample_from_datasets_weights.extend([1, 2])
  config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
  _write_config(original, config_path)

  # Merge an hparams override for the dataset sampling weights.
  configs = config_util.get_configs_from_pipeline_file(config_path)
  overrides = contrib_training.HParams(
      sample_from_datasets_weights=[0.5, 0.5])
  configs = config_util.merge_external_params_with_configs(configs, overrides)

  # The merged train input config must carry the overridden weights.
  self.assertListEqual(
      [0.5, 0.5],
      list(configs["train_input_config"].sample_from_datasets_weights))
@unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
def testOverwriteSampleFromDatasetWeightsWrongLength(self):
  """Tests config override for sample_from_datasets_weights."""
  # Configure two dataset weights, then attempt to override with three.
  original = pipeline_pb2.TrainEvalPipelineConfig()
  original.train_input_reader.sample_from_datasets_weights.extend([1, 2])
  config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
  _write_config(original, config_path)

  configs = config_util.get_configs_from_pipeline_file(config_path)
  overrides = contrib_training.HParams(
      sample_from_datasets_weights=[0.5, 0.5, 0.5])
  # NOTE(review): assertRaises' `msg` argument is only the failure message
  # shown if no exception is raised — it does NOT match against the raised
  # exception's text. assertRaisesRegex would actually verify the message.
  with self.assertRaises(
      ValueError,
      msg="sample_from_datasets_weights override has a different number of"
      " values (3) than the configured dataset weights (2)."
  ):
    config_util.merge_external_params_with_configs(configs, overrides)
@unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
def testKeyValueOverrideBadKey(self):
"""Tests that overwriting with a bad key causes an exception."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment