Commit dae3ba89 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 413972002
parent 7acb972a
...@@ -14,13 +14,13 @@ ...@@ -14,13 +14,13 @@
"""Training utils.""" """Training utils."""
import copy import copy
import dataclasses
import json import json
import os import os
import pprint import pprint
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
from absl import logging from absl import logging
import dataclasses
import gin import gin
import orbit import orbit
import tensorflow as tf import tensorflow as tf
...@@ -244,49 +244,87 @@ class ParseConfigOptions: ...@@ -244,49 +244,87 @@ class ParseConfigOptions:
return name in dataclasses.asdict(self) return name in dataclasses.asdict(self)
def parse_configuration(flags_obj, lock_return=True, print_return=True): class ExperimentParser:
"""Parses ExperimentConfig from flags.""" """Constructs the Experiment config from Flags or equivalent object.
Most of the cases, users only need to call the `parse()` function:
```
builder = ExperimentParser(FLAGS)
params = builder.parse()
```
if flags_obj.experiment is None: The advanced users can modify the flow by calling the parse_*() functions
raise ValueError('The flag --experiment must be specified.') separately.
"""
# 1. Get the default config from the registered experiment.
params = exp_factory.get_exp_config(flags_obj.experiment) def __init__(self, flags_obj):
self._flags_obj = flags_obj
# 2. Get the first level of override from `--config_file`.
# `--config_file` is typically used as a template that specifies the common def parse(self):
# override for a particular experiment. """Overrall process of constructing Experiment config."""
for config_file in flags_obj.config_file or []: params = self.base_experiment()
params = hyperparams.override_params_dict( params = self.parse_config_file(params)
params, config_file, is_strict=True) params = self.parse_runtime(params)
params = self.parse_data_service(params)
# 3. Override the TPU address and tf.data service address. params = self.parse_params_override(params)
params.override({ return params
'runtime': {
'tpu': flags_obj.tpu, def base_experiment(self):
}, """Get the base experiment config from --experiment field."""
}) if self._flags_obj.experiment is None:
if ('tf_data_service' in flags_obj and flags_obj.tf_data_service and raise ValueError('The flag --experiment must be specified.')
isinstance(params.task, config_definitions.TaskConfig)): return exp_factory.get_exp_config(self._flags_obj.experiment)
def parse_config_file(self, params):
"""Override the configs of params from the config_file."""
for config_file in self._flags_obj.config_file or []:
params = hyperparams.override_params_dict(
params, config_file, is_strict=True)
return params
def parse_runtime(self, params):
"""Override the runtime configs of params from flags."""
# Override the TPU address and tf.data service address.
params.override({ params.override({
'task': { 'runtime': {
'train_data': { 'tpu': self._flags_obj.tpu,
'tf_data_service_address': flags_obj.tf_data_service, },
},
'validation_data': {
'tf_data_service_address': flags_obj.tf_data_service,
}
}
}) })
return params
def parse_data_service(self, params):
"""Override the data service configs of params from flags."""
if ('tf_data_service' in self._flags_obj and
self._flags_obj.tf_data_service and
isinstance(params.task, config_definitions.TaskConfig)):
params.override({
'task': {
'train_data': {
'tf_data_service_address': self._flags_obj.tf_data_service,
},
'validation_data': {
'tf_data_service_address': self._flags_obj.tf_data_service,
}
}
})
return params
def parse_params_override(self, params):
# Get the second level of override from `--params_override`.
# `--params_override` is typically used as a further override over the
# template. For example, one may define a particular template for training
# ResNet50 on ImageNet in a config file and pass it via `--config_file`,
# then define different learning rates and pass it via `--params_override`.
if self._flags_obj.params_override:
params = hyperparams.override_params_dict(
params, self._flags_obj.params_override, is_strict=True)
return params
def parse_configuration(flags_obj, lock_return=True, print_return=True):
"""Parses ExperimentConfig from flags."""
# 4. Get the second level of override from `--params_override`. params = ExperimentParser(flags_obj).parse()
# `--params_override` is typically used as a further override over the
# template. For example, one may define a particular template for training
# ResNet50 on ImageNet in a config file and pass it via `--config_file`,
# then define different learning rates and pass it via `--params_override`.
if flags_obj.params_override:
params = hyperparams.override_params_dict(
params, flags_obj.params_override, is_strict=True)
params.validate() params.validate()
if lock_return: if lock_return:
......
...@@ -13,14 +13,37 @@ ...@@ -13,14 +13,37 @@
# limitations under the License. # limitations under the License.
"""Tests for official.core.train_utils.""" """Tests for official.core.train_utils."""
import os import os
import pprint
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from official.core import exp_factory
from official.core import test_utils from official.core import test_utils
from official.core import train_utils from official.core import train_utils
from official.modeling import hyperparams
@exp_factory.register_config_factory('foo')
def foo():
"""Multitask experiment for test."""
experiment_config = hyperparams.Config(
default_params={
'runtime': {
'tpu': 'fake',
},
'task': {
'model': {
'model_id': 'bar',
},
},
'trainer': {
'train_steps': -1,
'validation_steps': -1,
},
})
return experiment_config
class TrainUtilsTest(tf.test.TestCase): class TrainUtilsTest(tf.test.TestCase):
...@@ -93,6 +116,27 @@ class TrainUtilsTest(tf.test.TestCase): ...@@ -93,6 +116,27 @@ class TrainUtilsTest(tf.test.TestCase):
] ]
self.assertEqual(actual, expected) self.assertEqual(actual, expected)
def test_construct_experiment_from_flags(self):
options = train_utils.ParseConfigOptions(
experiment='foo',
config_file=[],
tpu='bar',
tf_data_service='',
params_override='task.model.model_id=new,'
'trainer.train_steps=10,'
'trainer.validation_steps=11')
builder = train_utils.ExperimentParser(options)
params_from_obj = builder.parse()
params_from_func = train_utils.parse_configuration(options)
pp = pprint.PrettyPrinter()
self.assertEqual(
pp.pformat(params_from_obj.as_dict()),
pp.pformat(params_from_func.as_dict()))
self.assertEqual(params_from_obj.runtime.tpu, 'bar')
self.assertEqual(params_from_obj.task.model.model_id, 'new')
self.assertEqual(params_from_obj.trainer.train_steps, 10)
self.assertEqual(params_from_obj.trainer.validation_steps, 11)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment