Commit 57a2de4d authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 339504167
parent 2b2e0b59
......@@ -112,6 +112,22 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
run_post_eval=run_post_eval)
print(logs)
def test_parse_configuration(self):
model_dir = self.get_temp_dir()
flags_dict = dict(
experiment='mock',
mode='train',
model_dir=model_dir,
params_override=json.dumps(self._test_config))
with flagsaver.flagsaver(**flags_dict):
params = train_utils.parse_configuration(flags.FLAGS, lock_return=True)
with self.assertRaises(ValueError):
params.override({'task': {'init_checkpoint': 'Foo'}})
params = train_utils.parse_configuration(flags.FLAGS, lock_return=False)
params.override({'task': {'init_checkpoint': 'Bar'}})
self.assertEqual(params.task.init_checkpoint, 'Bar')
if __name__ == '__main__':
tf.test.main()
......@@ -64,7 +64,7 @@ class ParseConfigOptions:
params_override: str = ''
def parse_configuration(flags_obj):
def parse_configuration(flags_obj, lock_return=True, print_return=True):
"""Parses ExperimentConfig from flags."""
# 1. Get the default config from the registered experiment.
......@@ -106,10 +106,13 @@ def parse_configuration(flags_obj):
params, flags_obj.params_override, is_strict=True)
params.validate()
params.lock()
if lock_return:
params.lock()
pp = pprint.PrettyPrinter()
logging.info('Final experiment parameters: %s', pp.pformat(params.as_dict()))
if print_return:
pp = pprint.PrettyPrinter()
logging.info('Final experiment parameters: %s',
pp.pformat(params.as_dict()))
return params
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment