Commit 440e7851 authored by Hongkun Yu, committed by A. Unique TensorFlower

Move common configs to official/modeling/hyperparams.

Move the common configs previously defined in base_config into config_definitions.

PiperOrigin-RevId: 314074592
parent 87ec3d2a
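For orientation, a minimal sketch of what the move enables (illustrative only, not part of the commit; assumes the official package is importable): the common configs re-exported in hyperparams/__init__.py below become reachable directly from the package.

# Illustrative sketch; RuntimeConfig, CallbacksConfig and their defaults come
# from the new config_definitions.py added in this commit.
from official.modeling import hyperparams

runtime = hyperparams.RuntimeConfig()
callbacks = hyperparams.CallbacksConfig()
print(runtime.distribution_strategy)   # 'mirrored', the default defined below
print(callbacks.enable_tensorboard)    # True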
@@ -14,6 +14,7 @@
# limitations under the License.
# ==============================================================================
"""Hyperparams package definition."""
+# pylint: disable=g-multiple-import
from official.modeling.hyperparams.base_config import *
+from official.modeling.hyperparams.config_definitions import CallbacksConfig, RuntimeConfig, TensorboardConfig
from official.modeling.hyperparams.params_dict import *
@@ -246,81 +246,3 @@ class Config(params_dict.ParamsDict):
    default_params = {a: p for a, p in zip(attributes, args)}
    default_params.update(kwargs)
    return cls(default_params)
@dataclasses.dataclass
class RuntimeConfig(Config):
"""High-level configurations for Runtime.
These include parameters that are not directly related to the experiment,
e.g. directories, accelerator type, etc.
Attributes:
distribution_strategy: e.g. 'mirrored', 'tpu', etc.
enable_xla: Whether or not to enable XLA.
per_gpu_thread_count: thread count per GPU.
gpu_thread_mode: Whether and how the GPU device uses its own threadpool.
dataset_num_private_threads: Number of threads for a private threadpool
created for all datasets computation.
tpu: The address of the TPU to use, if any.
num_gpus: The number of GPUs to use, if any.
worker_hosts: comma-separated list of worker ip:port pairs for running
multi-worker models with DistributionStrategy.
task_index: If multi-worker training, the task index of this worker.
all_reduce_alg: Defines the algorithm for performing all-reduce.
num_packs: Sets `num_packs` in the cross device ops used in
MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
loss_scale: The type of loss scale. This is used when setting the mixed
precision policy.
run_eagerly: Whether or not to run the experiment eagerly.
batchnorm_spatial_persistent: Whether or not to enable the spatial
persistent mode for CuDNN batch norm kernel for improved GPU performance.
"""
distribution_strategy: str = 'mirrored'
enable_xla: bool = False
gpu_thread_mode: Optional[str] = None
dataset_num_private_threads: Optional[int] = None
per_gpu_thread_count: int = 0
tpu: Optional[str] = None
num_gpus: int = 0
worker_hosts: Optional[str] = None
task_index: int = -1
all_reduce_alg: Optional[str] = None
num_packs: int = 1
loss_scale: Optional[str] = None
run_eagerly: bool = False
batchnorm_spatial_persistent: bool = False
@dataclasses.dataclass
class TensorboardConfig(Config):
"""Configuration for Tensorboard.
Attributes:
track_lr: Whether or not to track the learning rate in Tensorboard. Defaults
to True.
write_model_weights: Whether or not to write the model weights as
images in Tensorboard. Defaults to False.
"""
track_lr: bool = True
write_model_weights: bool = False
@dataclasses.dataclass
class CallbacksConfig(Config):
"""Configuration for Callbacks.
Attributes:
enable_checkpoint_and_export: Whether or not to enable checkpoints as a
Callback. Defaults to True.
enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
Defaults to True.
enable_time_history: Whether or not to enable TimeHistory Callbacks.
Defaults to True.
"""
enable_checkpoint_and_export: bool = True
enable_tensorboard: bool = True
enable_time_history: bool = True
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common configuration settings."""
from typing import Optional
import dataclasses
from official.modeling import optimization
from official.modeling.hyperparams import base_config
from official.utils import registry
OptimizationConfig = optimization.OptimizationConfig
@dataclasses.dataclass
class DataConfig(base_config.Config):
"""The base configuration for building datasets.
Attributes:
input_path: The path to the input. It can be either (1) a file pattern, or
(2) multiple file patterns separated by comma.
global_batch_size: The global batch size across all replicas.
is_training: Whether this data is used for training or not.
drop_remainder: Whether the last batch should be dropped in the case it has
fewer than `global_batch_size` elements.
shuffle_buffer_size: The buffer size used for shuffling training data.
cache: Whether to cache dataset examples. Can be used to avoid re-reading
from disk on the second epoch. Requires significant memory overhead.
cycle_length: The number of files that will be processed concurrently when
interleaving files.
sharding: Whether sharding is used in the input pipeline.
examples_consume: An `integer` specifying the number of examples it will
produce. If positive, it only takes this number of examples and raises
tf.error.OutOfRangeError after that. Default is -1, meaning it will
exhaust all the examples in the dataset.
"""
input_path: str = ""
global_batch_size: int = 0
is_training: bool = None
drop_remainder: bool = True
shuffle_buffer_size: int = 100
cache: bool = False
cycle_length: int = 8
sharding: bool = True
examples_consume: int = -1
@dataclasses.dataclass
class RuntimeConfig(base_config.Config):
"""High-level configurations for Runtime.
These include parameters that are not directly related to the experiment,
e.g. directories, accelerator type, etc.
Attributes:
distribution_strategy: e.g. 'mirrored', 'tpu', etc.
enable_xla: Whether or not to enable XLA.
per_gpu_thread_count: thread count per GPU.
gpu_thread_mode: Whether and how the GPU device uses its own threadpool.
dataset_num_private_threads: Number of threads for a private threadpool
created for all datasets computation.
tpu: The address of the TPU to use, if any.
num_gpus: The number of GPUs to use, if any.
worker_hosts: comma-separated list of worker ip:port pairs for running
multi-worker models with DistributionStrategy.
task_index: If multi-worker training, the task index of this worker.
all_reduce_alg: Defines the algorithm for performing all-reduce.
num_packs: Sets `num_packs` in the cross device ops used in
MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
loss_scale: The type of loss scale. This is used when setting the mixed
precision policy.
run_eagerly: Whether or not to run the experiment eagerly.
batchnorm_spatial_persistent: Whether or not to enable the spatial
persistent mode for CuDNN batch norm kernel for improved GPU performance.
"""
distribution_strategy: str = "mirrored"
enable_xla: bool = False
gpu_thread_mode: Optional[str] = None
dataset_num_private_threads: Optional[int] = None
per_gpu_thread_count: int = 0
tpu: Optional[str] = None
num_gpus: int = 0
worker_hosts: Optional[str] = None
task_index: int = -1
all_reduce_alg: Optional[str] = None
num_packs: int = 1
loss_scale: Optional[str] = None
run_eagerly: bool = False
batchnorm_spatial_persistent: bool = False
@dataclasses.dataclass
class TensorboardConfig(base_config.Config):
"""Configuration for Tensorboard.
Attributes:
track_lr: Whether or not to track the learning rate in Tensorboard. Defaults
to True.
write_model_weights: Whether or not to write the model weights as images in
Tensorboard. Defaults to False.
"""
track_lr: bool = True
write_model_weights: bool = False
@dataclasses.dataclass
class CallbacksConfig(base_config.Config):
"""Configuration for Callbacks.
Attributes:
enable_checkpoint_and_export: Whether or not to enable checkpoints as a
Callback. Defaults to True.
enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
Defaults to True.
enable_time_history: Whether or not to enable TimeHistory Callbacks.
Defaults to True.
"""
enable_checkpoint_and_export: bool = True
enable_tensorboard: bool = True
enable_time_history: bool = True
@dataclasses.dataclass
class TrainerConfig(base_config.Config):
optimizer_config: OptimizationConfig = OptimizationConfig()
train_tf_while_loop: bool = True
train_tf_function: bool = True
eval_tf_function: bool = True
steps_per_loop: int = 1000
summary_interval: int = 1000
checkpoint_interval: int = 1000
max_to_keep: int = 5
@dataclasses.dataclass
class TaskConfig(base_config.Config):
network: base_config.Config = None
train_data: DataConfig = DataConfig()
validation_data: DataConfig = DataConfig()
@dataclasses.dataclass
class ExperimentConfig(base_config.Config):
"""Top-level configuration."""
mode: str = "train" # train, eval, train_and_eval.
task: TaskConfig = TaskConfig()
trainer: TrainerConfig = TrainerConfig()
runtime: RuntimeConfig = RuntimeConfig()
train_steps: int = 0
validation_steps: Optional[int] = None
validation_interval: int = 100
_REGISTERED_CONFIGS = {}
def register_config_factory(name):
"""Register ExperimentConfig factory method."""
return registry.register(_REGISTERED_CONFIGS, name)
def get_exp_config_creater(exp_name: str):
"""Looks up ExperimentConfig factory methods."""
exp_creater = registry.lookup(_REGISTERED_CONFIGS, exp_name)
return exp_creater
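A brief usage sketch of the factory registry defined above (illustrative only, not part of the commit; the experiment name and factory function are hypothetical):

# Hypothetical names for illustration: 'baseline_example' and
# baseline_experiment_config are not part of this commit.
@register_config_factory('baseline_example')
def baseline_experiment_config():
  # All fields fall back to the defaults declared in the dataclasses above.
  return ExperimentConfig()

# get_exp_config_creater() looks the factory up under the same key.
exp_config = get_exp_config_creater('baseline_example')()
assert isinstance(exp_config, ExperimentConfig)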
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Registry utility."""
def register(registered_collection, reg_key):
"""Register decorated function or class to collection.
Register decorated function or class into registered_collection, in a
hierarchical order. For example, when reg_key="my_model/my_exp/my_config_0"
the decorated function or class is stored under
registered_collection["my_model"]["my_exp"]["my_config_0"].
This decorator is supposed to be used together with the lookup() function in
this file.
Args:
registered_collection: a dictionary. The decorated function or class will be
put into this collection.
reg_key: The key for retrieving the registered function or class. If reg_key
is a string, it can be hierarchical like my_model/my_exp/my_config_0
Returns:
A decorator function
Raises:
KeyError: when function or class to register already exists.
"""
def decorator(fn_or_cls):
"""Put fn_or_cls in the dictionary."""
if isinstance(reg_key, str):
hierarchy = reg_key.split("/")
collection = registered_collection
for h_idx, entry_name in enumerate(hierarchy[:-1]):
if entry_name not in collection:
collection[entry_name] = {}
collection = collection[entry_name]
if not isinstance(collection, dict):
raise KeyError(
"Collection path {} at position {} already registered as "
"a function or class.".format(entry_name, h_idx))
leaf_reg_key = hierarchy[-1]
else:
collection = registered_collection
leaf_reg_key = reg_key
if leaf_reg_key in collection:
raise KeyError("Function or class {} registered multiple times.".format(
leaf_reg_key))
collection[leaf_reg_key] = fn_or_cls
return fn_or_cls
return decorator
def lookup(registered_collection, reg_key):
"""Lookup and return decorated function or class in the collection.
Lookup decorated function or class in registered_collection, in a
hierarchical order. For example, when
reg_key="my_model/my_exp/my_config_0",
this function will return
registered_collection["my_model"]["my_exp"]["my_config_0"].
Args:
registered_collection: a dictionary. The decorated function or class will be
retrieved from this collection.
reg_key: The key for retrieving the registered function or class. If reg_key
is a string, it can be hierarchical like my_model/my_exp/my_config_0
Returns:
The registered function or class.
Raises:
LookupError: when reg_key cannot be found.
"""
if isinstance(reg_key, str):
hierarchy = reg_key.split("/")
collection = registered_collection
for h_idx, entry_name in enumerate(hierarchy):
if entry_name not in collection:
raise LookupError(
"collection path {} at position {} never registered.".format(
entry_name, h_idx))
collection = collection[entry_name]
return collection
else:
if reg_key not in registered_collection:
raise LookupError("registration key {} never registered.".format(reg_key))
return registered_collection[reg_key]
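A minimal sketch of the hierarchical storage that register() and lookup() implement, reusing the key from the docstring example (illustrative only; the collection and factory below are hypothetical):

# Hypothetical collection and function, for illustration only.
_MODELS = {}

@register(_MODELS, 'my_model/my_exp/my_config_0')
def make_config():
  return {'learning_rate': 0.1}

# register() stored the function under nested keys ...
assert _MODELS['my_model']['my_exp']['my_config_0'] is make_config
# ... and lookup() walks the same hierarchy back down.
assert lookup(_MODELS, 'my_model/my_exp/my_config_0') is make_config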
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for registry."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from official.utils import registry
class RegistryTest(tf.test.TestCase):
def test_register(self):
collection = {}
@registry.register(collection, 'functions/func_0')
def func_test():
pass
self.assertEqual(
registry.lookup(collection, 'functions/func_0'), func_test)
@registry.register(collection, 'classes/cls_0')
class ClassRegistryKey:
pass
self.assertEqual(
registry.lookup(collection, 'classes/cls_0'), ClassRegistryKey)
@registry.register(collection, ClassRegistryKey)
class ClassRegistryValue:
pass
self.assertEqual(
registry.lookup(collection, ClassRegistryKey), ClassRegistryValue)
def test_register_hierarchy(self):
collection = {}
@registry.register(collection, 'functions/func_0')
def func_test0():
pass
@registry.register(collection, 'func_1')
def func_test1():
pass
@registry.register(collection, func_test1)
def func_test2():
pass
expected_collection = {
'functions': {
'func_0': func_test0,
},
'func_1': func_test1,
func_test1: func_test2,
}
self.assertEqual(collection, expected_collection)
def test_register_error(self):
collection = {}
@registry.register(collection, 'functions/func_0')
def func_test0(): # pylint: disable=unused-variable
pass
with self.assertRaises(KeyError):
@registry.register(collection, 'functions/func_0/sub_func')
def func_test1(): # pylint: disable=unused-variable
pass
with self.assertRaises(LookupError):
registry.lookup(collection, 'non-exist')
if __name__ == '__main__':
tf.test.main()
@@ -23,54 +23,50 @@ from typing import Any, List, Mapping, Optional
import dataclasses
-from official.modeling.hyperparams import base_config
+from official.modeling import hyperparams

-CallbacksConfig = base_config.CallbacksConfig
-TensorboardConfig = base_config.TensorboardConfig
-RuntimeConfig = base_config.RuntimeConfig
+CallbacksConfig = hyperparams.CallbacksConfig
+TensorboardConfig = hyperparams.TensorboardConfig
+RuntimeConfig = hyperparams.RuntimeConfig


@dataclasses.dataclass
-class ExportConfig(base_config.Config):
+class ExportConfig(hyperparams.Config):
  """Configuration for exports.

  Attributes:
    checkpoint: the path to the checkpoint to export.
    destination: the path to where the checkpoint should be exported.
  """
  checkpoint: str = None
  destination: str = None


@dataclasses.dataclass
-class MetricsConfig(base_config.Config):
+class MetricsConfig(hyperparams.Config):
  """Configuration for Metrics.

  Attributes:
    accuracy: Whether or not to track accuracy as a Callback. Defaults to None.
    top_5: Whether or not to track top_5_accuracy as a Callback. Defaults to
      None.
  """
  accuracy: bool = None
  top_5: bool = None


@dataclasses.dataclass
-class TimeHistoryConfig(base_config.Config):
+class TimeHistoryConfig(hyperparams.Config):
  """Configuration for the TimeHistory callback.

  Attributes:
    log_steps: Interval of steps between logging of batch level stats.
  """
  log_steps: int = None


@dataclasses.dataclass
-class TrainConfig(base_config.Config):
+class TrainConfig(hyperparams.Config):
  """Configuration for training.

  Attributes:
@@ -86,7 +82,6 @@ class TrainConfig(base_config.Config):
      equal the number of training steps in `model.compile`. This reduces the
      number of callbacks run per epoch which significantly improves end-to-end
      TPU training time.
  """
  resume_checkpoint: bool = None
  epochs: int = None
@@ -99,7 +94,7 @@ class TrainConfig(base_config.Config):

@dataclasses.dataclass
-class EvalConfig(base_config.Config):
+class EvalConfig(hyperparams.Config):
  """Configuration for evaluation.

  Attributes:
@@ -109,7 +104,6 @@ class EvalConfig(base_config.Config):
      be inferred based on the number of images and batch size. Defaults to
      None.
    skip_eval: Whether or not to skip evaluation.
  """
  epochs_between_evals: int = None
  steps: int = None
@@ -117,21 +111,20 @@ class EvalConfig(base_config.Config):

@dataclasses.dataclass
-class LossConfig(base_config.Config):
+class LossConfig(hyperparams.Config):
  """Configuration for Loss.

  Attributes:
    name: The name of the loss. Defaults to None.
    label_smoothing: Whether or not to apply label smoothing to the loss. This
      only applies to 'categorical_cross_entropy'.
  """
  name: str = None
  label_smoothing: float = None


@dataclasses.dataclass
-class OptimizerConfig(base_config.Config):
+class OptimizerConfig(hyperparams.Config):
  """Configuration for Optimizers.

  Attributes:
@@ -144,12 +137,11 @@ class OptimizerConfig(base_config.Config):
      exponential moving average is not used. Defaults to None.
    lookahead: Whether or not to apply the lookahead optimizer. Defaults to
      None.
-    beta_1: The exponential decay rate for the 1st moment estimates. Used in
-      the Adam optimizers. Defaults to None.
-    beta_2: The exponential decay rate for the 2nd moment estimates. Used in
-      the Adam optimizers. Defaults to None.
+    beta_1: The exponential decay rate for the 1st moment estimates. Used in the
+      Adam optimizers. Defaults to None.
+    beta_2: The exponential decay rate for the 2nd moment estimates. Used in the
+      Adam optimizers. Defaults to None.
    epsilon: Small value used to avoid 0 denominator. Defaults to 1e-7.
  """
  name: str = None
  decay: float = None
@@ -164,7 +156,7 @@ class OptimizerConfig(base_config.Config):

@dataclasses.dataclass
-class LearningRateConfig(base_config.Config):
+class LearningRateConfig(hyperparams.Config):
  """Configuration for learning rates.

  Attributes:
@@ -173,16 +165,15 @@ class LearningRateConfig(base_config.Config):
    decay_epochs: The number of decay epochs. Defaults to None.
    decay_rate: The rate of decay. Defaults to None.
    warmup_epochs: The number of warmup epochs. Defaults to None.
-    batch_lr_multiplier: The multiplier to apply to the base learning rate,
-      if necessary. Defaults to None.
-    examples_per_epoch: the number of examples in a single epoch.
-      Defaults to None.
+    batch_lr_multiplier: The multiplier to apply to the base learning rate, if
+      necessary. Defaults to None.
+    examples_per_epoch: the number of examples in a single epoch. Defaults to
+      None.
    boundaries: boundaries used in piecewise constant decay with warmup.
    multipliers: multipliers used in piecewise constant decay with warmup.
    scale_by_batch_size: Scale the learning rate by a fraction of the batch
      size. Set to 0 for no scaling (default).
    staircase: Apply exponential decay at discrete values instead of continuous.
  """
  name: str = None
  initial_lr: float = None
@@ -197,7 +188,7 @@ class LearningRateConfig(base_config.Config):

@dataclasses.dataclass
-class ModelConfig(base_config.Config):
+class ModelConfig(hyperparams.Config):
  """Configuration for Models.

  Attributes:
@@ -206,17 +197,16 @@ class ModelConfig(base_config.Config):
    num_classes: The number of classes in the model. Defaults to None.
    loss: A `LossConfig` instance. Defaults to None.
    optimizer: An `OptimizerConfig` instance. Defaults to None.
  """
  name: str = None
-  model_params: base_config.Config = None
+  model_params: hyperparams.Config = None
  num_classes: int = None
  loss: LossConfig = None
  optimizer: OptimizerConfig = None


@dataclasses.dataclass
-class ExperimentConfig(base_config.Config):
+class ExperimentConfig(hyperparams.Config):
  """Base configuration for an image classification experiment.

  Attributes:
@@ -227,7 +217,6 @@ class ExperimentConfig(base_config.Config):
    evaluation: An `EvalConfig` instance.
    model: A `ModelConfig` instance.
    export: An `ExportConfig` instance.
  """
  model_dir: str = None
  model_name: str = None
...