Commit dae3ba89 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 413972002
parent 7acb972a
...@@ -14,13 +14,13 @@ ...@@ -14,13 +14,13 @@
"""Training utils.""" """Training utils."""
import copy import copy
import dataclasses
import json import json
import os import os
import pprint import pprint
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
from absl import logging from absl import logging
import dataclasses
import gin import gin
import orbit import orbit
import tensorflow as tf import tensorflow as tf
...@@ -244,49 +244,87 @@ class ParseConfigOptions: ...@@ -244,49 +244,87 @@ class ParseConfigOptions:
return name in dataclasses.asdict(self) return name in dataclasses.asdict(self)
def parse_configuration(flags_obj, lock_return=True, print_return=True): class ExperimentParser:
"""Parses ExperimentConfig from flags.""" """Constructs the Experiment config from Flags or equivalent object.
if flags_obj.experiment is None: Most of the cases, users only need to call the `parse()` function:
raise ValueError('The flag --experiment must be specified.') ```
builder = ExperimentParser(FLAGS)
params = builder.parse()
```
# 1. Get the default config from the registered experiment. The advanced users can modify the flow by calling the parse_*() functions
params = exp_factory.get_exp_config(flags_obj.experiment) separately.
"""
def __init__(self, flags_obj):
self._flags_obj = flags_obj
def parse(self):
"""Overrall process of constructing Experiment config."""
params = self.base_experiment()
params = self.parse_config_file(params)
params = self.parse_runtime(params)
params = self.parse_data_service(params)
params = self.parse_params_override(params)
return params
def base_experiment(self):
"""Get the base experiment config from --experiment field."""
if self._flags_obj.experiment is None:
raise ValueError('The flag --experiment must be specified.')
return exp_factory.get_exp_config(self._flags_obj.experiment)
# 2. Get the first level of override from `--config_file`. def parse_config_file(self, params):
# `--config_file` is typically used as a template that specifies the common """Override the configs of params from the config_file."""
# override for a particular experiment. for config_file in self._flags_obj.config_file or []:
for config_file in flags_obj.config_file or []:
params = hyperparams.override_params_dict( params = hyperparams.override_params_dict(
params, config_file, is_strict=True) params, config_file, is_strict=True)
return params
# 3. Override the TPU address and tf.data service address. def parse_runtime(self, params):
"""Override the runtime configs of params from flags."""
# Override the TPU address and tf.data service address.
params.override({ params.override({
'runtime': { 'runtime': {
'tpu': flags_obj.tpu, 'tpu': self._flags_obj.tpu,
}, },
}) })
if ('tf_data_service' in flags_obj and flags_obj.tf_data_service and return params
def parse_data_service(self, params):
"""Override the data service configs of params from flags."""
if ('tf_data_service' in self._flags_obj and
self._flags_obj.tf_data_service and
isinstance(params.task, config_definitions.TaskConfig)): isinstance(params.task, config_definitions.TaskConfig)):
params.override({ params.override({
'task': { 'task': {
'train_data': { 'train_data': {
'tf_data_service_address': flags_obj.tf_data_service, 'tf_data_service_address': self._flags_obj.tf_data_service,
}, },
'validation_data': { 'validation_data': {
'tf_data_service_address': flags_obj.tf_data_service, 'tf_data_service_address': self._flags_obj.tf_data_service,
} }
} }
}) })
return params
# 4. Get the second level of override from `--params_override`. def parse_params_override(self, params):
# Get the second level of override from `--params_override`.
# `--params_override` is typically used as a further override over the # `--params_override` is typically used as a further override over the
# template. For example, one may define a particular template for training # template. For example, one may define a particular template for training
# ResNet50 on ImageNet in a config file and pass it via `--config_file`, # ResNet50 on ImageNet in a config file and pass it via `--config_file`,
# then define different learning rates and pass it via `--params_override`. # then define different learning rates and pass it via `--params_override`.
if flags_obj.params_override: if self._flags_obj.params_override:
params = hyperparams.override_params_dict( params = hyperparams.override_params_dict(
params, flags_obj.params_override, is_strict=True) params, self._flags_obj.params_override, is_strict=True)
return params
def parse_configuration(flags_obj, lock_return=True, print_return=True):
"""Parses ExperimentConfig from flags."""
params = ExperimentParser(flags_obj).parse()
params.validate() params.validate()
if lock_return: if lock_return:
......
...@@ -13,14 +13,37 @@ ...@@ -13,14 +13,37 @@
# limitations under the License. # limitations under the License.
"""Tests for official.core.train_utils.""" """Tests for official.core.train_utils."""
import os import os
import pprint
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from official.core import exp_factory
from official.core import test_utils from official.core import test_utils
from official.core import train_utils from official.core import train_utils
from official.modeling import hyperparams
@exp_factory.register_config_factory('foo')
def foo():
"""Multitask experiment for test."""
experiment_config = hyperparams.Config(
default_params={
'runtime': {
'tpu': 'fake',
},
'task': {
'model': {
'model_id': 'bar',
},
},
'trainer': {
'train_steps': -1,
'validation_steps': -1,
},
})
return experiment_config
class TrainUtilsTest(tf.test.TestCase): class TrainUtilsTest(tf.test.TestCase):
...@@ -93,6 +116,27 @@ class TrainUtilsTest(tf.test.TestCase): ...@@ -93,6 +116,27 @@ class TrainUtilsTest(tf.test.TestCase):
] ]
self.assertEqual(actual, expected) self.assertEqual(actual, expected)
def test_construct_experiment_from_flags(self):
options = train_utils.ParseConfigOptions(
experiment='foo',
config_file=[],
tpu='bar',
tf_data_service='',
params_override='task.model.model_id=new,'
'trainer.train_steps=10,'
'trainer.validation_steps=11')
builder = train_utils.ExperimentParser(options)
params_from_obj = builder.parse()
params_from_func = train_utils.parse_configuration(options)
pp = pprint.PrettyPrinter()
self.assertEqual(
pp.pformat(params_from_obj.as_dict()),
pp.pformat(params_from_func.as_dict()))
self.assertEqual(params_from_obj.runtime.tpu, 'bar')
self.assertEqual(params_from_obj.task.model.model_id, 'new')
self.assertEqual(params_from_obj.trainer.train_steps, 10)
self.assertEqual(params_from_obj.trainer.validation_steps, 11)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment