Commit 57a2de4d authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 339504167
parent 2b2e0b59
...@@ -112,6 +112,22 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase): ...@@ -112,6 +112,22 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
run_post_eval=run_post_eval) run_post_eval=run_post_eval)
print(logs) print(logs)
def test_parse_configuration(self):
model_dir = self.get_temp_dir()
flags_dict = dict(
experiment='mock',
mode='train',
model_dir=model_dir,
params_override=json.dumps(self._test_config))
with flagsaver.flagsaver(**flags_dict):
params = train_utils.parse_configuration(flags.FLAGS, lock_return=True)
with self.assertRaises(ValueError):
params.override({'task': {'init_checkpoint': 'Foo'}})
params = train_utils.parse_configuration(flags.FLAGS, lock_return=False)
params.override({'task': {'init_checkpoint': 'Bar'}})
self.assertEqual(params.task.init_checkpoint, 'Bar')
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -64,7 +64,7 @@ class ParseConfigOptions: ...@@ -64,7 +64,7 @@ class ParseConfigOptions:
params_override: str = '' params_override: str = ''
def parse_configuration(flags_obj): def parse_configuration(flags_obj, lock_return=True, print_return=True):
"""Parses ExperimentConfig from flags.""" """Parses ExperimentConfig from flags."""
# 1. Get the default config from the registered experiment. # 1. Get the default config from the registered experiment.
...@@ -106,10 +106,13 @@ def parse_configuration(flags_obj): ...@@ -106,10 +106,13 @@ def parse_configuration(flags_obj):
params, flags_obj.params_override, is_strict=True) params, flags_obj.params_override, is_strict=True)
params.validate() params.validate()
params.lock() if lock_return:
params.lock()
pp = pprint.PrettyPrinter() if print_return:
logging.info('Final experiment parameters: %s', pp.pformat(params.as_dict())) pp = pprint.PrettyPrinter()
logging.info('Final experiment parameters: %s',
pp.pformat(params.as_dict()))
return params return params
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment