Unverified Commit f16a7b5b authored by vedanshu, committed by GitHub

Merge pull request #1 from tensorflow/master

new pull
parents 8e9296ff 8f58f396
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hyperparams package definition."""
# pylint: disable=g-multiple-import
from official.modeling.hyperparams.base_config import *
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,17 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base configurations to standardize experiments."""
from __future__ import absolute_import """Base configurations to standardize experiments."""
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import copy import copy
import functools import functools
from typing import Any, List, Mapping, Optional, Type from typing import Any, List, Mapping, Optional, Type
from absl import logging
import dataclasses import dataclasses
import tensorflow as tf import tensorflow as tf
@@ -35,11 +30,15 @@ from official.modeling.hyperparams import params_dict
class Config(params_dict.ParamsDict):
"""The base configuration class that supports YAML/JSON based overrides.

Because of YAML/JSON serialization limitations, some semantics of dataclass
are not supported:
* It recursively enforces an allowlist of basic types and container types, so
it avoids surprises with copy and reuse caused by unanticipated types.
* Warning: it converts Dict to `Config` even within sequences,
e.g. for config = Config({'key': [([{'a': 42}],)]}),
type(config.key[0][0][0]) is Config rather than dict.
If you define/annotate some field as Dict, the field will convert to a
`Config` instance and lose the dictionary type.
"""
# It's safe to add bytes and other immutable types here.
@@ -142,10 +141,11 @@ class Config(params_dict.ParamsDict):
return subconfig_type
def __post_init__(self, default_params, restrictions, *args, **kwargs):
super().__init__(
default_params=default_params,
restrictions=restrictions,
*args,
**kwargs)
def _set(self, k, v):
"""Overrides same method in ParamsDict.
@@ -160,13 +160,32 @@ class Config(params_dict.ParamsDict):
RuntimeError
"""
subconfig_type = self._get_subconfig_type(k)
def is_null(k):
if k not in self.__dict__ or not self.__dict__[k]:
return True
return False
if isinstance(v, dict):
if is_null(k):
# If the key does not exist or the value is None, a new Config-family
# object should be created for the key.
self.__dict__[k] = subconfig_type(v)
else:
self.__dict__[k].override(v)
elif not is_null(k) and isinstance(v, self.SEQUENCE_TYPES) and all(
[not isinstance(e, self.IMMUTABLE_TYPES) for e in v]):
if len(self.__dict__[k]) == len(v):
for i in range(len(v)):
self.__dict__[k][i].override(v[i])
elif not all([isinstance(e, self.IMMUTABLE_TYPES) for e in v]):
logging.warning(
"The list/tuple doesn't match the value dictionaries provided. Thus, "
'the list/tuple is determined by the type annotation and '
'values provided. This is error-prone.')
self.__dict__[k] = self._import_config(v, subconfig_type)
else:
self.__dict__[k] = self._import_config(v, subconfig_type)
else:
self.__dict__[k] = self._import_config(v, subconfig_type)
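# The sequence branch above lets an override such as
# {'key': [{'a': 1}, {'a': 2}]} update a same-length list/tuple of sub-configs
# element-wise, preserving each element's concrete Config subclass (this is
# exercised by test_nested_tuple in the test file below).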
@@ -220,16 +239,19 @@ class Config(params_dict.ParamsDict):
}
def replace(self, **kwargs):
"""Overrides/returns an unlocked copy with the current config unchanged."""
# pylint: disable=protected-access
params = copy.deepcopy(self)
params._locked = False
params._override(kwargs, is_strict=True)
# pylint: enable=protected-access
return params
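# Usage sketch (mirrors test_replace in the test file below): `replace` leaves
# the original config untouched and returns a modified, unlocked copy.
#
#   config = DumpConfig2()
#   new_config = config.replace(e={'a': 2})
#   assert new_config.e.a == 2  # `config` itself is unchanged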
@classmethod
def from_yaml(cls, file_path: str):
# Note: This only works if the Config has all default values.
with tf.io.gfile.GFile(file_path, 'r') as f:
loaded = yaml.load(f, Loader=yaml.FullLoader)
config = cls()
config.override(loaded)
return config
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pprint
from typing import List, Tuple
@@ -45,6 +43,17 @@ class DumpConfig3(DumpConfig2):
g: Tuple[DumpConfig1, ...] = (DumpConfig1(),)
@dataclasses.dataclass
class DumpConfig4(DumpConfig2):
x: int = 3
@dataclasses.dataclass
class DummyConfig5(base_config.Config):
y: Tuple[DumpConfig2, ...] = (DumpConfig2(), DumpConfig4())
z: Tuple[str] = ('a',)
class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
def assertHasSameTypes(self, c, d, msg=''):
@@ -106,6 +115,22 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
self.assertEqual(config.g[0].a, 4)
self.assertEqual(config.g[0].b, 'new text 3')
def test_replace(self):
config = DumpConfig2()
new_config = config.replace(e={'a': 2})
self.assertEqual(new_config.e.a, 2)
self.assertIsInstance(new_config.e, DumpConfig1)
config = DumpConfig2(e=DumpConfig2())
new_config = config.replace(e={'c': 4})
self.assertEqual(new_config.e.c, 4)
self.assertIsInstance(new_config.e, DumpConfig2)
config = DumpConfig3()
new_config = config.replace(g=[{'a': 4, 'b': 'new text 3'}])
self.assertIsInstance(new_config.g[0], DumpConfig1)
self.assertEqual(new_config.g[0].a, 4)
@parameterized.parameters(
('_locked', "The key '_locked' is internally reserved."),
('_restrictions', "The key '_restrictions' is internally reserved."),
@@ -147,10 +172,8 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
params.override({'c': {'c3': 30}}, is_strict=True)
config = base_config.Config({'key': [{'a': 42}]})
with self.assertRaisesRegex(KeyError, "The key 'b' does not exist"):
config.override({'key': [{'b': 43}]})
@parameterized.parameters(
(lambda x: x, 'Unknown type'),
@@ -294,6 +317,44 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
]),
"['s', 1, 1.0, True, None, {}, [], (), {8: 9, (2,): (3, [4], {6: 7})}]")
def test_with_restrictions(self):
restrictions = ['e.a<c']
config = DumpConfig2(restrictions=restrictions)
config.validate()
def test_nested_tuple(self):
config = DummyConfig5()
config.override({
'y': [{
'c': 4,
'd': 'new text 3',
'e': {
'a': 2
}
}, {
'c': 0,
'd': 'new text 3',
'e': {
'a': 2
}
}],
'z': ['a', 'b', 'c'],
})
self.assertEqual(config.y[0].c, 4)
self.assertEqual(config.y[1].c, 0)
self.assertIsInstance(config.y[0], DumpConfig2)
self.assertIsInstance(config.y[1], DumpConfig4)
self.assertSameElements(config.z, ['a', 'b', 'c'])
def test_override_by_empty_sequence(self):
config = DummyConfig5()
config.override({
'y': [],
'z': (),
}, is_strict=True)
self.assertEmpty(config.y)
self.assertEmpty(config.z)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,124 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common configuration settings."""
from typing import Optional, Union
"""Common configuration settings."""
# pylint:disable=wildcard-import
import dataclasses import dataclasses
from official.core.config_definitions import *
from official.modeling.hyperparams import base_config from official.modeling.hyperparams import base_config
from official.modeling.optimization.configs import optimization_config
from official.utils import registry
OptimizationConfig = optimization_config.OptimizationConfig
@dataclasses.dataclass
class DataConfig(base_config.Config):
"""The base configuration for building datasets.
Attributes:
input_path: The path to the input. It can be either (1) a file pattern, or
(2) multiple file patterns separated by comma. It should not be specified
when the following `tfds_name` is specified.
tfds_name: The name of the tensorflow dataset (TFDS). It should not be
specified when the above `input_path` is specified.
tfds_split: A str indicating which split of the data to load from TFDS. It
is required when above `tfds_name` is specified.
global_batch_size: The global batch size across all replicas.
is_training: Whether this data is used for training or not.
drop_remainder: Whether the last batch should be dropped in the case it has
fewer than `global_batch_size` elements.
shuffle_buffer_size: The buffer size used for shuffling training data.
cache: Whether to cache dataset examples. Can be used to avoid re-reading
from disk on the second epoch. Requires significant memory overhead.
cycle_length: The number of files that will be processed concurrently when
interleaving files.
block_length: The number of consecutive elements to produce from each input
element before cycling to another input element when interleaving files.
sharding: Whether sharding is used in the input pipeline.
examples_consume: An `integer` specifying the number of examples it will
produce. If positive, it only takes this number of examples and raises
tf.errors.OutOfRangeError after that. Default is -1, meaning it will
exhaust all the examples in the dataset.
tfds_data_dir: A str specifying the directory to read/write TFDS data.
tfds_download: A bool to indicate whether to download data using TFDS.
tfds_as_supervised: A bool. When loading dataset from TFDS, if True,
the returned tf.data.Dataset will have a 2-tuple structure (input, label)
according to builder.info.supervised_keys; if False, the default,
the returned tf.data.Dataset will have a dictionary with all the features.
tfds_skip_decoding_feature: A str to indicate which features are skipped
for decoding when loading dataset from TFDS. Use comma to separate
multiple features. The main use case is to skip the image/video decoding
for better performance.
"""
input_path: str = ""
tfds_name: str = ""
tfds_split: str = ""
global_batch_size: int = 0
is_training: bool = None
drop_remainder: bool = True
shuffle_buffer_size: int = 100
cache: bool = False
cycle_length: int = 8
block_length: int = 1
sharding: bool = True
examples_consume: int = -1
tfds_data_dir: str = ""
tfds_download: bool = False
tfds_as_supervised: bool = False
tfds_skip_decoding_feature: str = ""
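# Illustrative construction (values are made up, not part of the diff):
#
#   train_data = DataConfig(
#       input_path='/path/to/train*',  # hypothetical file pattern
#       global_batch_size=64,
#       is_training=True,
#       shuffle_buffer_size=1000)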
@dataclasses.dataclass
class RuntimeConfig(base_config.Config):
"""High-level configurations for Runtime.
These include parameters that are not directly related to the experiment,
e.g. directories, accelerator type, etc.
Attributes:
distribution_strategy: e.g. 'mirrored', 'tpu', etc.
enable_xla: Whether or not to enable XLA.
per_gpu_thread_count: thread count per GPU.
gpu_thread_mode: Whether and how the GPU device uses its own threadpool.
dataset_num_private_threads: Number of threads for a private threadpool
created for all datasets computation.
tpu: The address of the TPU to use, if any.
num_gpus: The number of GPUs to use, if any.
worker_hosts: comma-separated list of worker ip:port pairs for running
multi-worker models with DistributionStrategy.
task_index: If multi-worker training, the task index of this worker.
all_reduce_alg: Defines the algorithm for performing all-reduce.
num_packs: Sets `num_packs` in the cross device ops used in
MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
mixed_precision_dtype: dtype of mixed precision policy. It can be 'float32',
'float16', or 'bfloat16'.
loss_scale: The type of loss scale, or 'float' value. This is used when
setting the mixed precision policy.
run_eagerly: Whether or not to run the experiment eagerly.
batchnorm_spatial_persistent: Whether or not to enable the spatial
persistent mode for CuDNN batch norm kernel for improved GPU performance.
"""
distribution_strategy: str = "mirrored"
enable_xla: bool = False
gpu_thread_mode: Optional[str] = None
dataset_num_private_threads: Optional[int] = None
per_gpu_thread_count: int = 0
tpu: Optional[str] = None
num_gpus: int = 0
worker_hosts: Optional[str] = None
task_index: int = -1
all_reduce_alg: Optional[str] = None
num_packs: int = 1
mixed_precision_dtype: Optional[str] = None
loss_scale: Optional[Union[str, float]] = None
run_eagerly: bool = False
batchnorm_spatial_persistent: bool = False
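# Illustrative example (values are made up): a two-GPU mirrored run with mixed
# precision and dynamic loss scaling.
#
#   runtime = RuntimeConfig(
#       distribution_strategy='mirrored',
#       num_gpus=2,
#       mixed_precision_dtype='float16',
#       loss_scale='dynamic')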
# TODO(hongkuny): These configs are used in models that are going to be deprecated.
# Once those models are removed, we should delete this file to avoid confusion.
# Users should not use this file anymore.
@dataclasses.dataclass
class TensorboardConfig(base_config.Config):
"""Configuration for Tensorboard.
@@ -151,75 +44,14 @@ class CallbacksConfig(base_config.Config):
Attributes:
enable_checkpoint_and_export: Whether or not to enable checkpoints as a
Callback. Defaults to True.
enable_backup_and_restore: Whether or not to add BackupAndRestore
callback. Defaults to False.
enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
Defaults to True.
enable_time_history: Whether or not to enable TimeHistory Callbacks.
Defaults to True.
"""
enable_checkpoint_and_export: bool = True
enable_backup_and_restore: bool = False
enable_tensorboard: bool = True
enable_time_history: bool = True
@dataclasses.dataclass
class TrainerConfig(base_config.Config):
"""Configuration for trainer.
Attributes:
optimizer_config: optimizer config; it includes optimizer, learning rate,
and warmup schedule configs.
train_tf_while_loop: whether or not to use tf while loop.
train_tf_function: whether or not to use tf_function for training loop.
eval_tf_function: whether or not to use tf_function for eval.
steps_per_loop: number of steps per loop.
summary_interval: number of steps between each summary.
checkpoint_interval: number of steps between checkpoints.
max_to_keep: max checkpoints to keep.
continuous_eval_timeout: maximum number of seconds to wait between
checkpoints, if set to None, continuous eval will wait indefinitely.
train_steps: number of train steps.
validation_steps: number of eval steps. If `None`, the entire eval dataset
is used.
validation_interval: number of training steps to run between evaluations.
"""
optimizer_config: OptimizationConfig = OptimizationConfig()
train_tf_while_loop: bool = True
train_tf_function: bool = True
eval_tf_function: bool = True
steps_per_loop: int = 1000
summary_interval: int = 1000
checkpoint_interval: int = 1000
max_to_keep: int = 5
continuous_eval_timeout: Optional[int] = None
train_steps: int = 0
validation_steps: Optional[int] = None
validation_interval: int = 1000
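# Illustrative override (values are made up): a shorter run that checkpoints
# and summarizes every 100 steps.
#
#   trainer = TrainerConfig(
#       train_steps=10000,
#       steps_per_loop=100,
#       summary_interval=100,
#       checkpoint_interval=100)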
@dataclasses.dataclass
class TaskConfig(base_config.Config):
model: base_config.Config = None
train_data: DataConfig = DataConfig()
validation_data: DataConfig = DataConfig()
@dataclasses.dataclass
class ExperimentConfig(base_config.Config):
"""Top-level configuration."""
task: TaskConfig = TaskConfig()
trainer: TrainerConfig = TrainerConfig()
runtime: RuntimeConfig = RuntimeConfig()
_REGISTERED_CONFIGS = {}
def register_config_factory(name):
"""Register ExperimentConfig factory method."""
return registry.register(_REGISTERED_CONFIGS, name)
def get_exp_config_creater(exp_name: str):
"""Looks up ExperimentConfig factory methods."""
exp_creater = registry.lookup(_REGISTERED_CONFIGS, exp_name)
return exp_creater
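# Illustrative round-trip through the registry ('my_exp' is a made-up name):
#
#   @register_config_factory('my_exp')
#   def my_exp_config() -> ExperimentConfig:
#     return ExperimentConfig()
#
#   exp_config = get_exp_config_creater('my_exp')()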
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Config class that supports oneof functionality.""" """Config class that supports oneof functionality."""
from typing import Optional from typing import Optional
@@ -38,15 +37,12 @@ class OneOfConfig(base_config.Config):
if self.type is None:
return {'type': None}
elif self.__dict__['type'] not in self.__dict__:
raise ValueError('type: {!r} is not a valid key!'.format(
self.__dict__['type']))
else:
chosen_type = self.type
chosen_value = self.__dict__[chosen_type]
return {'type': self.type, chosen_type: self._export_config(chosen_value)}
def get(self):
"""Returns selected config based on the value of type.
@@ -57,6 +53,5 @@ class OneOfConfig(base_config.Config):
if chosen_type is None:
return None
if chosen_type not in self.__dict__:
raise ValueError('type: {!r} is not a valid key!'.format(self.type))
return self.__dict__[chosen_type]
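# Sketch of the oneof pattern (field names follow the Network test below;
# ResNet stands in for any Config subclass):
#
#   @dataclasses.dataclass
#   class BackboneConfig(OneOfConfig):
#     type: Optional[str] = 'resnet'
#     resnet: ResNet = ResNet()
#
#   BackboneConfig().get()  # returns the `resnet` sub-config chosen by `type`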
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import dataclasses
import tensorflow as tf
@@ -48,12 +46,18 @@ class Network(base_config.Config):
class OneOfTest(tf.test.TestCase):
def test_to_dict(self):
network_params = {
'backbone': {
'type': 'resnet',
'resnet': {
'model_depth': 50
}
},
'output_layer': {
'type': 'single',
'single': 1000
}
}
network_config = Network(network_params)
self.assertEqual(network_config.as_dict(), network_params)
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,12 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A parameter dictionary class which supports nested structures."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
@@ -30,7 +26,8 @@ import yaml
# key-value pair string. It splits each k-v pair on the = sign, and
# matches on values that are within single quotes, double quotes, single
# values (e.g. floats, ints, etc.), and lists within brackets.
_PARAM_RE = re.compile(
r"""
(?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
\s*=\s*
((?P<val>\'(.*?)\' # single quote
@@ -44,6 +41,26 @@ _PARAM_RE = re.compile(r"""
_CONST_VALUE_RE = re.compile(r'(\d.*|-\d.*|None)')
# YAML loader with an implicit resolver to parse float decimal and exponential
# formats. The regular expression covers the following cases:
# 1. Decimal number with an optional exponential term.
# 2. Integer number with an exponential term.
# 3. Decimal number starting with a dot, with an optional exponential term.
# 4. Sexagesimal (colon-separated) number with a decimal part.
LOADER = yaml.SafeLoader
LOADER.add_implicit_resolver(
'tag:yaml.org,2002:float',
re.compile(r'''
^(?:[-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|
[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|
\.[0-9_]+(?:[eE][-+][0-9]+)?
|
[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\.[0-9_]*)$''', re.X),
list('-+0123456789.'))
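# With this resolver attached, exponential notation without a decimal point
# parses as a float rather than a string (exercised by the float-format test
# near the end of this diff), e.g.:
#
#   assert yaml.load('x: 1e+3', Loader=LOADER)['x'] == 1000.0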
class ParamsDict(object):
"""A hyperparameter container class."""
@@ -72,7 +89,6 @@ class ParamsDict(object):
if default_params is None:
default_params = {}
self.override(default_params, is_strict=False)
self.validate()
def _set(self, k, v):
if isinstance(v, dict):
@@ -138,8 +154,8 @@ class ParamsDict(object):
ValueError: if the ParamsDict instance has been locked.
"""
if k in ParamsDict.RESERVED_ATTR:
raise AttributeError(
'The key `{}` is reserved. No change is allowed. '.format(k))
if k not in self.__dict__.keys():
raise AttributeError('The key `{}` does not exist. '.format(k))
if self._locked:
@@ -150,13 +166,13 @@ class ParamsDict(object):
"""Override the ParamsDict with a set of given params. """Override the ParamsDict with a set of given params.
Args: Args:
override_params: a dict or a ParamsDict specifying the parameters to override_params: a dict or a ParamsDict specifying the parameters to be
be overridden. overridden.
is_strict: a boolean specifying whether override is strict or not. If is_strict: a boolean specifying whether override is strict or not. If
True, keys in `override_params` must be present in the ParamsDict. True, keys in `override_params` must be present in the ParamsDict. If
If False, keys in `override_params` can be different from what is False, keys in `override_params` can be different from what is currently
currently defined in the ParamsDict. In this case, the ParamsDict will defined in the ParamsDict. In this case, the ParamsDict will be extended
be extended to include the new keys. to include the new keys.
""" """
if self._locked: if self._locked:
raise ValueError('The ParamsDict has been locked. No change is allowed.') raise ValueError('The ParamsDict has been locked. No change is allowed.')
@@ -230,7 +246,7 @@ class ParamsDict(object):
['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2']
What it enforces are:
- a.a1 = 1 == b.ccc.a1 = 1
- a.a2 = 2 <= b.bb.bb2 = 20
Raises:
@@ -240,6 +256,7 @@ class ParamsDict(object):
(2) any inconsistency violating the restriction is found.
ValueError: if the restriction defined in the string is not supported.
"""
def _get_kv(dotted_string, params_dict):
"""Get keys and values indicated by dotted_string."""
if _CONST_VALUE_RE.match(dotted_string) is not None:
@@ -270,56 +287,64 @@ class ParamsDict(object):
tokens = restriction.split('==')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v != right_v:
raise KeyError(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
elif '!=' in restriction:
tokens = restriction.split('!=')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v == right_v:
raise KeyError(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
elif '<=' in restriction:
# Check '<=' before '<' (and '>=' before '>'): the one-character operator
# is a substring of the two-character one and would otherwise shadow it.
tokens = restriction.split('<=')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v > right_v:
raise KeyError(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
elif '<' in restriction:
tokens = restriction.split('<')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v >= right_v:
raise KeyError(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
elif '>=' in restriction:
tokens = restriction.split('>=')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v < right_v:
raise KeyError(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
elif '>' in restriction:
tokens = restriction.split('>')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v <= right_v:
raise KeyError(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
else:
raise ValueError('Unsupported relation in restriction.')
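# Example of a restriction in action (mirrors test_validate below):
#
#   params = ParamsDict({'a': 1, 'c': {'a': 1}}, ['a == c.a'])
#   params.validate()        # passes
#   params.override({'a': 11})
#   params.validate()        # raises KeyError: keys are now inconsistent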
def read_yaml_to_params_dict(file_path: str):
"""Reads a YAML file to a ParamsDict."""
with tf.io.gfile.GFile(file_path, 'r') as f:
params_dict = yaml.load(f, Loader=LOADER)
return ParamsDict(params_dict)
def save_params_dict_to_yaml(params, file_path):
"""Saves the input ParamsDict to a YAML file."""
with tf.io.gfile.GFile(file_path, 'w') as f:
def _my_list_rep(dumper, data):
# u'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
return dumper.represent_sequence(
u'tag:yaml.org,2002:seq', data, flow_style=True)
yaml.add_representer(list, _my_list_rep)
yaml.dump(params.as_dict(), f, default_flow_style=False)
@@ -408,8 +433,8 @@ def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
Args:
params: a ParamsDict object to be overridden.
dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or path to
a YAML file specifying the parameters to be overridden.
is_strict: a boolean specifying whether override is strict or not.
Returns:
@@ -428,12 +453,12 @@ def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
nested_csv_str_to_json_str(dict_or_string_or_yaml_file))
except ValueError:
pass
params_dict = yaml.load(dict_or_string_or_yaml_file, Loader=LOADER)
if isinstance(params_dict, dict):
params.override(params_dict, is_strict)
else:
with tf.io.gfile.GFile(dict_or_string_or_yaml_file) as f:
params.override(yaml.load(f, Loader=yaml.FullLoader), is_strict)
else:
raise ValueError('Unknown input type to parse.')
return params
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for params_dict.py."""
@@ -56,8 +55,7 @@ class ParamsDictTest(tf.test.TestCase):
def test_setattr(self):
params = params_dict.ParamsDict()
params.override({'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
params.c = 'ccc'
self.assertEqual(params.a, 'aa')
self.assertEqual(params.b, 2)
@@ -65,17 +63,23 @@ class ParamsDictTest(tf.test.TestCase):
def test_getattr(self):
params = params_dict.ParamsDict()
params.override({'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
self.assertEqual(params.a, 'aa')
self.assertEqual(params.b, 2)
self.assertEqual(params.c, None)
def test_delattr(self):
params = params_dict.ParamsDict()
params.override({
'a': 'aa',
'b': 2,
'c': None,
'd': {
'd1': 1,
'd2': 10
}
},
is_strict=False)
del params.c
self.assertEqual(params.a, 'aa')
self.assertEqual(params.b, 2)
@@ -87,22 +91,26 @@ class ParamsDictTest(tf.test.TestCase):
def test_contains(self):
params = params_dict.ParamsDict()
params.override({'a': 'aa'}, is_strict=False)
self.assertIn('a', params)
self.assertNotIn('b', params)
def test_get(self):
params = params_dict.ParamsDict()
params.override({'a': 'aa'}, is_strict=False)
self.assertEqual(params.get('a'), 'aa')
self.assertEqual(params.get('b', 2), 2)
self.assertEqual(params.get('b'), None)
def test_override_is_strict_true(self):
params = params_dict.ParamsDict({
'a': 'aa',
'b': 2,
'c': {
'c1': 'cc',
'c2': 20
}
})
params.override({'a': 2, 'c': {'c1': 'ccc'}}, is_strict=True)
self.assertEqual(params.a, 2)
self.assertEqual(params.c.c1, 'ccc')
@@ -112,8 +120,14 @@ class ParamsDictTest(tf.test.TestCase):
params.override({'c': {'c3': 30}}, is_strict=True)
def test_override_is_strict_false(self):
params = params_dict.ParamsDict({
'a': 'aa',
'b': 2,
'c': {
'c1': 10,
'c2': 20
}
})
params.override({'a': 2, 'c': {'c3': 3000}}, is_strict=False)
self.assertEqual(params.a, 2)
self.assertEqual(params.c.c3, 3000)
@@ -123,8 +137,14 @@ class ParamsDictTest(tf.test.TestCase):
self.assertEqual(params.c.c4, 4444)
def test_as_dict(self):
params = params_dict.ParamsDict({
'a': 'aa',
'b': 2,
'c': {
'c1': 10,
'c2': 20
}
})
params_d = params.as_dict()
self.assertEqual(params_d['a'], 'aa')
self.assertEqual(params_d['b'], 2)
@@ -134,21 +154,27 @@ class ParamsDictTest(tf.test.TestCase):
def test_validate(self):
# Raise error due to the unknown parameter.
with self.assertRaises(KeyError):
params = params_dict.ParamsDict({'a': 1, 'b': {'a': 11}}, ['a == c'])
params.validate()
# OK to check equality of two nested dicts.
params = params_dict.ParamsDict({
'a': 1,
'b': {
'a': 10
},
'c': {
'a': 10
}
}, ['b == c'])
# Raise error due to inconsistency
with self.assertRaises(KeyError):
params = params_dict.ParamsDict({'a': 1, 'c': {'a': 10}}, ['a == c.a'])
params.validate()
# Valid rule.
params = params_dict.ParamsDict({'a': 1, 'c': {'a': 1}}, ['a == c.a'])
# Overriding violates the existing rule, raise error upon validate.
params.override({'a': 11})
@@ -156,12 +182,21 @@ class ParamsDictTest(tf.test.TestCase):
params.validate()
# Valid restrictions with constant.
params = params_dict.ParamsDict({
'a': None,
'c': {
'a': 1
}
}, ['a == None', 'c.a == 1'])
params.validate()
with self.assertRaises(KeyError):
params = params_dict.ParamsDict({
'a': 4,
'c': {
'a': 1
}
}, ['a == None', 'c.a == 1'])
params.validate()
class ParamsDictIOTest(tf.test.TestCase):
@@ -173,8 +208,14 @@ class ParamsDictIOTest(tf.test.TestCase):
return temp_file
def test_save_params_dict_to_yaml(self):
params = params_dict.ParamsDict({
'a': 'aa',
'b': 2,
'c': {
'c1': 10,
'c2': 20
}
})
output_yaml_file = os.path.join(self.get_temp_dir(), 'params.yaml')
params_dict.save_params_dict_to_yaml(params, output_yaml_file)
@@ -203,7 +244,12 @@ class ParamsDictIOTest(tf.test.TestCase):
def test_override_params_dict_using_dict(self):
params = params_dict.ParamsDict({
'a': 1,
'b': 2.5,
'c': [3, 4],
'd': 'hello',
'e': False
})
override_dict = {'b': 5.2, 'c': [30, 40]}
params = params_dict.override_params_dict(
params, override_dict, is_strict=True)
@@ -215,7 +261,12 @@ class ParamsDictIOTest(tf.test.TestCase):
def test_override_params_dict_using_yaml_string(self):
params = params_dict.ParamsDict({
'a': 1,
'b': 2.5,
'c': [3, 4],
'd': 'hello',
'e': False
})
override_yaml_string = "'b': 5.2\n'c': [30, 40]"
params = params_dict.override_params_dict(
params, override_yaml_string, is_strict=True)
@@ -227,8 +278,18 @@ class ParamsDictIOTest(tf.test.TestCase):
def test_override_params_dict_using_json_string(self):
params = params_dict.ParamsDict({
'a': 1,
'b': {
'b1': 2,
'b2': [2, 3],
},
'd': {
'd1': {
'd2': 'hello'
}
},
'e': False
})
override_json_string = "{ b: { b2: [3, 4] }, d: { d1: { d2: 'hi' } } }"
params = params_dict.override_params_dict(
params, override_json_string, is_strict=True)
def test_override_params_dict_using_csv_string(self): def test_override_params_dict_using_csv_string(self):
params = params_dict.ParamsDict({ params = params_dict.ParamsDict({
'a': 1, 'b': {'b1': 2, 'b2': [2, 3],}, 'a': 1,
'd': {'d1': {'d2': 'hello'}}, 'e': False}) 'b': {
'b1': 2,
'b2': [2, 3],
},
'd': {
'd1': {
'd2': 'hello'
}
},
'e': False
})
override_csv_string = "b.b2=[3,4], d.d1.d2='hi, world', e=gs://test" override_csv_string = "b.b2=[3,4], d.d1.d2='hi, world', e=gs://test"
params = params_dict.override_params_dict( params = params_dict.override_params_dict(
params, override_csv_string, is_strict=True) params, override_csv_string, is_strict=True)
...@@ -250,10 +321,23 @@ class ParamsDictIOTest(tf.test.TestCase): ...@@ -250,10 +321,23 @@ class ParamsDictIOTest(tf.test.TestCase):
self.assertEqual([3, 4], params.b.b2) self.assertEqual([3, 4], params.b.b2)
self.assertEqual('hi, world', params.d.d1.d2) self.assertEqual('hi, world', params.d.d1.d2)
self.assertEqual('gs://test', params.e) self.assertEqual('gs://test', params.e)
# Test different float formats
override_csv_string = 'b.b2=-1.e-3, d.d1.d2=+0.001, e=1e+3, a=-1.5E-3'
params = params_dict.override_params_dict(
params, override_csv_string, is_strict=True)
self.assertEqual(-1e-3, params.b.b2)
self.assertEqual(0.001, params.d.d1.d2)
self.assertEqual(1e3, params.e)
self.assertEqual(-1.5e-3, params.a)
def test_override_params_dict_using_yaml_file(self):
params = params_dict.ParamsDict({
'a': 1,
'b': 2.5,
'c': [3, 4],
'd': 'hello',
'e': False
})
override_yaml_file = self.write_temp_file(
'params.yaml', r"""
b: 5.2
@@ -321,8 +405,7 @@ class IOTest(tf.test.TestCase):
def test_csv_str_load_unsupported_datatypes(self):
csv_str = 'a=[[1,2,3],[4,5,6]]'
self.assertRaises(ValueError, params_dict.nested_csv_str_to_json_str,
csv_str)
def test_csv_str_to_json_str_spacing(self):
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstraction of multi-task model."""
from typing import Text, Dict
import tensorflow as tf
class MultiTaskBaseModel(tf.Module):
"""Base class that holds multi-task model computation."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._sub_tasks = self._instantiate_sub_tasks()
def _instantiate_sub_tasks(self) -> Dict[Text, tf.keras.Model]:
"""Abstract function that sets up the computation for each sub-task.
Returns:
A map from task name (as string) to a tf.keras.Model object that
represents the sub-task in the multi-task pool.
"""
raise NotImplementedError(
"_instantiate_sub_tasks() is not implemented.")
@property
def sub_tasks(self):
"""Fetch a map of task name (string) to task model (tf.keras.Model)."""
return self._sub_tasks
def initialize(self):
"""Optional function that loads a pre-train checkpoint."""
return
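# A minimal illustrative subclass (task names and layer sizes are made up):
#
#   class TwoHeadModel(MultiTaskBaseModel):
#
#     def _instantiate_sub_tasks(self):
#       return {
#           'classification': tf.keras.Sequential([tf.keras.layers.Dense(10)]),
#           'regression': tf.keras.Sequential([tf.keras.layers.Dense(1)]),
#       }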
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask base trainer implementation.
The trainer derives from the Orbit `StandardTrainer` class.
"""
from typing import Union
import gin
import orbit
import tensorflow as tf
from official.modeling.multitask import base_model
from official.modeling.multitask import multitask
@gin.configurable
class MultiTaskBaseTrainer(orbit.StandardTrainer):
"""Multitask base trainer."""
def __init__(self,
multi_task: multitask.MultiTask,
multi_task_model: Union[tf.keras.Model,
base_model.MultiTaskBaseModel],
optimizer: tf.optimizers.Optimizer,
trainer_options=None):
self._strategy = tf.distribute.get_strategy()
self._multi_task = multi_task
self._multi_task_model = multi_task_model
self._optimizer = optimizer
self._training_losses = None
self._training_metrics = None
self._global_step = orbit.utils.create_global_step()
if hasattr(self.multi_task_model, "checkpoint_items"):
checkpoint_items = self.multi_task_model.checkpoint_items
else:
checkpoint_items = {}
self._checkpoint = tf.train.Checkpoint(
model=self.multi_task_model,
optimizer=self.optimizer,
global_step=self.global_step,
**checkpoint_items)
train_datasets = {}
for name, task in self.multi_task.tasks.items():
train_datasets[name] = orbit.utils.make_distributed_dataset(
self.strategy, task.build_inputs, task.task_config.train_data)
super().__init__(
train_dataset=train_datasets,
options=trainer_options or orbit.StandardTrainerOptions())
def train_loop_begin(self):
"""Clean up states that hold losses and metrics."""
for _, train_loss_metric in self.training_losses.items():
train_loss_metric.reset_states()
for _, metrics in self.training_metrics.items():
for metric in metrics:
metric.reset_states()
def train_loop_end(self):
"""Record loss and metric values per task."""
result = {}
for task_name, loss in self.training_losses.items():
result[task_name] = {loss.name: loss.result()}
for task_name, task_metrics in self.training_metrics.items():
result[task_name].update(
{metric.name: metric.result() for metric in task_metrics})
# Note that the learning rate schedule is managed by the Keras optimizer
# internally, which counts the number of backward passes as `iterations`.
# The learning rate schedule does not follow the trainer's logical global
# step of multiple tasks.
if callable(self.optimizer.learning_rate):
result["learning_rate"] = self.optimizer.learning_rate(
self.optimizer.iterations)
else:
result["learning_rate"] = self.optimizer.learning_rate
return result
@property
def checkpoint(self):
"""Accesses the training checkpoint."""
return self._checkpoint
@property
def training_losses(self):
"""Access training loss metric objects for all tasks."""
if self._training_losses is None:
# Builds the per-task metrics and losses.
      # This is the total summed training loss of tasks in the joint training.
self._training_losses = dict(
total_loss=tf.keras.metrics.Mean("training_loss", dtype=tf.float32))
for name in self.multi_task.tasks:
self._training_losses[name] = tf.keras.metrics.Mean(
"training_loss", dtype=tf.float32)
return self._training_losses
@property
def training_metrics(self):
"""Access training metric metric objects for all tasks."""
if self._training_metrics is None:
# Builds the per-task metrics and losses.
self._training_metrics = {}
for name, task in self.multi_task.tasks.items():
self._training_metrics[name] = task.build_metrics(training=True)
return self._training_metrics
@property
def strategy(self):
return self._strategy
@property
def multi_task(self):
return self._multi_task
@property
def multi_task_model(self):
return self._multi_task_model
@property
def optimizer(self):
return self._optimizer
@property
def global_step(self):
return self._global_step
def train_step(self, iterator_map):
"""The default train step calling the multi-task train step.
Args:
iterator_map: a dictionary of task names and per-task dataset iterators.
"""
def step_fn(inputs):
losses = self.multi_task.joint_train_step(
inputs,
multi_task_model=self.multi_task_model,
optimizer=self.optimizer,
task_metrics=self.training_metrics)
for key, loss in losses.items():
self.training_losses[key].update_state(loss)
self.strategy.run(
step_fn, args=(tf.nest.map_structure(next, iterator_map),))
self.global_step.assign_add(1)
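# A minimal usage sketch (illustrative, not part of the library API). It
# borrows the mock tasks and model from this package's unit tests purely as
# stand-ins; a real experiment would supply its own `Task` subclasses, model,
# and optimizer.
def _example_joint_training(num_steps: int = 5):
  """Builds a `MultiTaskBaseTrainer` and runs a few joint train steps."""
  from official.modeling.multitask import test_utils  # Test-only helpers.
  tasks = [
      test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
      test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar"),
  ]
  trainer = MultiTaskBaseTrainer(
      multi_task=multitask.MultiTask(
          tasks=tasks, task_weights={"foo": 1.0, "bar": 1.0}),
      multi_task_model=test_utils.MockMultiTaskModel(),
      optimizer=tf.keras.optimizers.SGD(0.1))
  # `train` comes from `orbit.StandardTrainer`; it runs `num_steps` steps and
  # returns the per-task losses/metrics assembled in `train_loop_end`.
  return trainer.train(tf.convert_to_tensor(num_steps, dtype=tf.int32))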
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.base_trainer."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.modeling.multitask import base_trainer
from official.modeling.multitask import configs
from official.modeling.multitask import multitask
from official.modeling.multitask import test_utils
def all_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
mode="eager",
)
class BaseTrainerTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(all_strategy_combinations())
def test_multitask_joint_trainer(self, distribution):
with distribution.scope():
tasks = [
test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar")
]
task_weights = {"foo": 1.0, "bar": 1.0}
test_multitask = multitask.MultiTask(
tasks=tasks, task_weights=task_weights)
test_optimizer = tf.keras.optimizers.SGD(0.1)
model = test_utils.MockMultiTaskModel()
test_trainer = base_trainer.MultiTaskBaseTrainer(
multi_task=test_multitask,
multi_task_model=model,
optimizer=test_optimizer)
results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertContainsSubset(["training_loss", "bar_acc"],
results["bar"].keys())
self.assertContainsSubset(["training_loss", "foo_acc"],
results["foo"].keys())
def test_trainer_with_configs(self):
config = configs.MultiTaskConfig(
task_routines=(configs.TaskRoutine(
task_name="foo",
task_config=test_utils.FooConfig(),
task_weight=0.5),
configs.TaskRoutine(
task_name="bar",
task_config=test_utils.BarConfig(),
task_weight=0.5)))
test_multitask = multitask.MultiTask.from_config(config)
test_optimizer = tf.keras.optimizers.SGD(0.1)
model = test_utils.MockMultiTaskModel()
test_trainer = base_trainer.MultiTaskBaseTrainer(
multi_task=test_multitask,
multi_task_model=model,
optimizer=test_optimizer)
results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertContainsSubset(["training_loss", "bar_acc"],
results["bar"].keys())
self.assertContainsSubset(["training_loss", "foo_acc"],
results["foo"].keys())
self.assertEqual(test_multitask.task_weight("foo"), 0.5)
self.assertEqual(test_trainer.global_step.numpy(), 5)
self.assertIn("learning_rate", results)
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration definitions for multi-task training."""
from typing import Optional, Tuple
import dataclasses
from official.core import config_definitions as cfg
from official.modeling import hyperparams
@dataclasses.dataclass
class TaskRoutine(hyperparams.Config):
task_name: str = ""
task_config: cfg.TaskConfig = None
eval_steps: Optional[int] = None
task_weight: Optional[float] = 1.0
@dataclasses.dataclass
class MultiTaskConfig(hyperparams.Config):
init_checkpoint: str = ""
model: hyperparams.Config = None
task_routines: Tuple[TaskRoutine, ...] = ()
@dataclasses.dataclass
class ProportionalSampleConfig(hyperparams.Config):
alpha: float = 1.0
@dataclasses.dataclass
class AnnealingSampleConfig(hyperparams.Config):
steps_per_epoch: int = 5
total_steps: int = 20
@dataclasses.dataclass
class TaskSamplingConfig(hyperparams.OneOfConfig):
type: str = ""
uniform: hyperparams.Config = hyperparams.Config()
proportional: ProportionalSampleConfig = ProportionalSampleConfig()
annealing: AnnealingSampleConfig = AnnealingSampleConfig()
@dataclasses.dataclass
class MultiTaskTrainerConfig(cfg.TrainerConfig):
trainer_type: str = "interleaving"
task_sampler: TaskSamplingConfig = TaskSamplingConfig(type="proportional")
@dataclasses.dataclass
class MultiTaskExperimentConfig(hyperparams.Config):
"""An experiment config for multi-task training and multi-task evaluation."""
task: MultiTaskConfig = MultiTaskConfig()
trainer: MultiTaskTrainerConfig = MultiTaskTrainerConfig()
runtime: cfg.RuntimeConfig = cfg.RuntimeConfig()
@dataclasses.dataclass
class MultiEvalExperimentConfig(cfg.ExperimentConfig):
"""An experiment config for single-task training and multi-task evaluation.
Attributes:
eval_tasks: individual evaluation tasks.
"""
eval_tasks: MultiTaskConfig = MultiTaskConfig()
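# An illustrative sketch of assembling a two-task experiment config. The task
# names and the bare `cfg.TaskConfig()` placeholders are assumptions; real
# experiments plug in registered task configs instead.
def _example_experiment_config() -> MultiTaskExperimentConfig:
  routines = (
      TaskRoutine(task_name="foo", task_config=cfg.TaskConfig(),
                  task_weight=2.0),
      TaskRoutine(task_name="bar", task_config=cfg.TaskConfig(),
                  task_weight=1.0),
  )
  return MultiTaskExperimentConfig(
      task=MultiTaskConfig(task_routines=routines),
      trainer=MultiTaskTrainerConfig(
          trainer_type="interleaving",
          task_sampler=TaskSamplingConfig(type="proportional")))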
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask Evaluator implementation.
The evaluator implements the Orbit `AbstractEvaluator` interface.
"""
from typing import Optional, Union
import gin
import orbit
import tensorflow as tf
from official.core import train_utils
from official.modeling.multitask import base_model
from official.modeling.multitask import multitask
@gin.configurable
class MultiTaskEvaluator(orbit.AbstractEvaluator):
"""Implements the common trainer shared for TensorFlow models."""
def __init__(
self,
task: multitask.MultiTask,
model: Union[tf.keras.Model, base_model.MultiTaskBaseModel],
global_step: Optional[tf.Variable] = None,
checkpoint_exporter: Optional[train_utils.BestCheckpointExporter] = None):
"""Initialize common trainer for TensorFlow models.
Args:
task: A multitask.MultiTask instance.
model: tf.keras.Model instance.
global_step: the global step variable.
checkpoint_exporter: an object that has the `maybe_export_checkpoint`
interface.
"""
# Gets the current distribution strategy. If not inside any strategy scope,
# it gets a single-replica no-op strategy.
self._strategy = tf.distribute.get_strategy()
self._task = task
self._model = model
self._global_step = global_step or orbit.utils.create_global_step()
self._checkpoint_exporter = checkpoint_exporter
self._checkpoint = tf.train.Checkpoint(
global_step=self.global_step,
model=self.model)
self._validation_losses = None
self._validation_metrics = None
# Builds per-task datasets.
self.eval_datasets = {}
for name, task in self.task.tasks.items():
self.eval_datasets[name] = orbit.utils.make_distributed_dataset(
self.strategy, task.build_inputs, task.task_config.validation_data)
# Builds per-task validation loops.
def get_function(task_name, task):
task_metrics = self.validation_metrics[task_name]
task_loss = self.validation_losses[task_name]
if isinstance(self.model, base_model.MultiTaskBaseModel):
model = self.model.sub_tasks[task_name]
else:
model = self.model
def step_fn(inputs):
logs = task.validation_step(inputs, model=model, metrics=task_metrics)
task_loss.update_state(logs[task.loss])
return logs
@tf.function
def eval_step_fn(iterator):
distributed_outputs = self.strategy.run(step_fn, args=(next(iterator),))
return tf.nest.map_structure(self.strategy.experimental_local_results,
distributed_outputs)
return orbit.utils.create_loop_fn(eval_step_fn)
self.task_fns = {
name: get_function(name, task)
for name, task in self.task.tasks.items()
}
@property
def strategy(self):
return self._strategy
@property
def task(self):
return self._task
@property
def model(self):
return self._model
@property
def global_step(self):
return self._global_step
@property
def validation_losses(self):
"""Accesses the validation loss metric object."""
if self._validation_losses is None:
# Builds the per-task metrics and losses.
self._validation_losses = {}
for name in self.task.tasks:
self._validation_losses[name] = tf.keras.metrics.Mean(
"validation_loss", dtype=tf.float32)
return self._validation_losses
@property
def validation_metrics(self):
"""Accesses all validation metric metric objects."""
if self._validation_metrics is None:
# Builds the per-task metrics and losses.
self._validation_metrics = {}
for name, task in self.task.tasks.items():
self._validation_metrics[name] = task.build_metrics(training=False)
return self._validation_metrics
@property
def checkpoint(self):
"""Accesses the training checkpoint."""
return self._checkpoint
def evaluate(self, num_steps: tf.Tensor):
"""Performs evaluation for each `EvalTask`."""
for metric in self.validation_losses.values():
metric.reset_states()
for metrics in self.validation_metrics.values():
for metric in metrics:
metric.reset_states()
results = {}
eval_iters = tf.nest.map_structure(iter, self.eval_datasets)
for name, task_eval_loop in self.task_fns.items():
outputs = None
eval_iter = eval_iters[name]
task = self.task.tasks[name]
task_eval_steps = self.task.task_eval_steps(name) or num_steps
outputs = task_eval_loop(
eval_iter,
task_eval_steps,
state=outputs,
reduce_fn=task.aggregate_logs)
task_metrics = self.validation_metrics[name]
task_loss = self.validation_losses[name]
logs = {}
for metric in task_metrics + [task_loss]:
logs[metric.name] = metric.result()
if outputs:
metrics = task.reduce_aggregated_logs(
outputs, global_step=self.global_step)
logs.update(metrics)
results[name] = logs
if self._checkpoint_exporter:
self._checkpoint_exporter.maybe_export_checkpoint(
self.checkpoint, results, self.global_step.numpy())
return results
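# A hedged usage sketch (illustrative only). The mock tasks and model are
# borrowed from this package's unit tests as stand-ins, assuming their task
# configs define `validation_data`; real usage supplies concrete tasks.
def _example_multitask_evaluation(num_steps: int = 10):
  from official.modeling.multitask import test_utils  # Test-only helpers.
  tasks = [
      test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
      test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar"),
  ]
  test_evaluator = MultiTaskEvaluator(
      task=multitask.MultiTask(tasks=tasks),
      model=test_utils.MockMultiTaskModel())
  # Returns a {task_name: {metric_name: value}} mapping; a per-task
  # `eval_steps` from the config, when set, takes precedence over `num_steps`.
  return test_evaluator.evaluate(tf.convert_to_tensor(num_steps, tf.int32))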
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.evaluator."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import base_task
from official.core import config_definitions as cfg
from official.modeling.multitask import evaluator
from official.modeling.multitask import multitask
def all_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
mode="eager",
)
class MockModel(tf.keras.Model):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.dense = tf.keras.layers.Dense(1)
def call(self, inputs):
if "y" in inputs:
self.add_loss(tf.zeros((1,), dtype=tf.float32))
else:
self.add_loss(tf.ones((1,), dtype=tf.float32))
return self.dense(inputs["x"])
class MockTask(base_task.Task):
"""Mock task object for testing."""
def build_metrics(self, training: bool = True):
del training
return [tf.keras.metrics.Accuracy(name="acc")]
def build_inputs(self, params):
def generate_data(_):
x = tf.zeros(shape=(2,), dtype=tf.float32)
label = tf.zeros([1], dtype=tf.int32)
if self.name == "bar":
return dict(x=x, y=x), label
else:
return dict(x=x), label
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(
generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset.prefetch(buffer_size=1).batch(2, drop_remainder=True)
def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
logs = super().validation_step(inputs, model, metrics)
logs["counter"] = tf.ones((1,), dtype=tf.float32)
return logs
def aggregate_logs(self, state, step_outputs):
if state is None:
state = {}
for key, value in step_outputs.items():
if key not in state:
state[key] = []
state[key].append(
np.concatenate([np.expand_dims(v.numpy(), axis=0) for v in value]))
return state
def reduce_aggregated_logs(self,
aggregated_logs,
global_step=None):
for k, v in aggregated_logs.items():
aggregated_logs[k] = np.sum(np.stack(v, axis=0))
return aggregated_logs
class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(all_strategy_combinations())
def test_multitask_evaluator(self, distribution):
with distribution.scope():
tasks = [
MockTask(params=cfg.TaskConfig(), name="bar"),
MockTask(params=cfg.TaskConfig(), name="foo")
]
test_multitask = multitask.MultiTask(tasks=tasks)
model = MockModel()
test_evaluator = evaluator.MultiTaskEvaluator(
task=test_multitask, model=model)
results = test_evaluator.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
self.assertContainsSubset(["validation_loss", "acc"], results["bar"].keys())
self.assertContainsSubset(["validation_loss", "acc"], results["foo"].keys())
self.assertEqual(results["bar"]["validation_loss"], 0.0)
self.assertEqual(results["foo"]["validation_loss"], 1.0)
@combinations.generate(all_strategy_combinations())
def test_multitask_evaluator_numpy_metrics(self, distribution):
with distribution.scope():
tasks = [
MockTask(params=cfg.TaskConfig(), name="bar"),
MockTask(params=cfg.TaskConfig(), name="foo")
]
test_multitask = multitask.MultiTask(tasks=tasks)
model = MockModel()
test_evaluator = evaluator.MultiTaskEvaluator(
task=test_multitask, model=model)
results = test_evaluator.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertEqual(results["bar"]["counter"],
5. * distribution.num_replicas_in_sync)
self.assertEqual(results["foo"]["counter"],
5. * distribution.num_replicas_in_sync)
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask trainer that interleaves each task's train step."""
from typing import Union
import gin
import orbit
import tensorflow as tf
from official.modeling.multitask import base_model
from official.modeling.multitask import base_trainer
from official.modeling.multitask import multitask
from official.modeling.multitask import task_sampler as sampler
@gin.configurable
class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
"""MultiTask trainer that interleaves task update."""
def __init__(self,
multi_task: multitask.MultiTask,
multi_task_model: Union[tf.keras.Model,
base_model.MultiTaskBaseModel],
optimizer: tf.optimizers.Optimizer,
task_sampler: sampler.TaskSampler,
trainer_options=None):
super(MultiTaskInterleavingTrainer, self).__init__(
multi_task=multi_task,
multi_task_model=multi_task_model,
optimizer=optimizer,
trainer_options=trainer_options)
self._task_sampler = task_sampler
# Build per task train step.
def _get_task_step(task_name, task):
def step_fn(inputs):
if isinstance(self.multi_task_model, base_model.MultiTaskBaseModel):
task_model = self.multi_task_model.sub_tasks[task_name]
else:
task_model = self.multi_task_model
task_logs = task.train_step(
inputs,
model=task_model,
optimizer=self.optimizer,
metrics=self.training_metrics[task_name])
self.training_losses[task_name].update_state(task_logs[task.loss])
return step_fn
self._task_train_step_map = {
name: _get_task_step(name, task)
for name, task in self.multi_task.tasks.items()
}
# TODO(haozhangthu): Add taskwise step counter to train_loop_end for logging
# on TensorBoard.
self._task_step_counters = {
name: orbit.utils.create_global_step() for name in self.multi_task.tasks
}
def task_step_counter(self, name):
return self._task_step_counters[name]
def train_step(self, iterator_map):
    # Samples one task to train according to a multinomial distribution.
rn = tf.random.stateless_uniform(shape=[], seed=(0, self.global_step))
cumulative_sample_distribution = self._task_sampler.task_cumulative_distribution(
self.global_step)
# Prepend a [0.0] for indexing convenience.
cumulative_sample_distribution = tf.concat(
[tf.constant([0.0], dtype=tf.float32), cumulative_sample_distribution],
axis=0)
for idx, (name, _) in enumerate(self.multi_task.tasks.items()):
begin = cumulative_sample_distribution[idx]
end = cumulative_sample_distribution[idx + 1]
if rn >= begin and rn < end:
self._strategy.run(
self._task_train_step_map[name], args=(next(iterator_map[name]),))
self.global_step.assign_add(1)
self.task_step_counter(name).assign_add(1)
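# A hedged usage sketch (illustrative only): interleaved training in which
# each step trains a single sampled task. The mock tasks and model come from
# this package's unit tests and stand in for real implementations.
def _example_interleaved_training(num_steps: int = 100):
  from official.modeling.multitask import test_utils  # Test-only helpers.
  test_multitask = multitask.MultiTask(
      tasks=[
          test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
          test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar"),
      ],
      task_weights={"foo": 3.0, "bar": 1.0})
  trainer = MultiTaskInterleavingTrainer(
      multi_task=test_multitask,
      multi_task_model=test_utils.MockMultiTaskModel(),
      optimizer=tf.keras.optimizers.SGD(0.1),
      task_sampler=sampler.ProportionalTaskSampler(
          task_weights=test_multitask.task_weights))
  logs = trainer.train(tf.convert_to_tensor(num_steps, dtype=tf.int32))
  # With 3:1 weights, roughly three quarters of the steps should have been
  # spent on "foo"; the per-task step counters make that observable.
  step_counts = {name: trainer.task_step_counter(name).numpy()
                 for name in test_multitask.tasks}
  return logs, step_counts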
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.interleaving_trainer."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.modeling.multitask import configs
from official.modeling.multitask import interleaving_trainer
from official.modeling.multitask import multitask
from official.modeling.multitask import task_sampler
from official.modeling.multitask import test_utils
def all_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
mode="eager",
)
class InterleavingTrainerTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(all_strategy_combinations())
def test_multitask_interleaving_trainer(self, distribution):
with distribution.scope():
tasks = [
test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar")
]
test_multitask = multitask.MultiTask(tasks=tasks)
test_optimizer = tf.keras.optimizers.SGD(0.1)
model = test_utils.MockMultiTaskModel()
sampler = task_sampler.UniformTaskSampler(
task_weights=test_multitask.task_weights)
test_trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
multi_task=test_multitask,
multi_task_model=model,
optimizer=test_optimizer,
task_sampler=sampler)
results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertContainsSubset(["training_loss", "bar_acc"],
results["bar"].keys())
self.assertContainsSubset(["training_loss", "foo_acc"],
results["foo"].keys())
@combinations.generate(all_strategy_combinations())
def test_trainer_with_configs(self, distribution):
config = configs.MultiTaskConfig(
task_routines=(configs.TaskRoutine(
task_name="foo",
task_config=test_utils.FooConfig(),
task_weight=3.0),
configs.TaskRoutine(
task_name="bar",
task_config=test_utils.BarConfig(),
task_weight=1.0)))
with distribution.scope():
test_multitask = multitask.MultiTask.from_config(config)
test_optimizer = tf.keras.optimizers.SGD(0.1)
model = test_utils.MockMultiTaskModel()
num_step = 1000
sampler = task_sampler.AnnealingTaskSampler(
task_weights=test_multitask.task_weights,
steps_per_epoch=num_step/5,
total_steps=num_step)
test_trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
multi_task=test_multitask,
multi_task_model=model,
optimizer=test_optimizer,
task_sampler=sampler)
results = test_trainer.train(tf.convert_to_tensor(num_step, dtype=tf.int32))
self.assertContainsSubset(["training_loss", "bar_acc"],
results["bar"].keys())
self.assertContainsSubset(["training_loss", "foo_acc"],
results["foo"].keys())
self.assertEqual(test_trainer.global_step.numpy(), num_step)
bar_sampled_step = test_trainer.task_step_counter("bar").numpy()
foo_sampled_step = test_trainer.task_step_counter("foo").numpy()
self.assertEqual(bar_sampled_step + foo_sampled_step, num_step)
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental MultiTask base class for multi-task training/evaluation."""
import abc
from typing import Dict, List, Optional, Text, Union
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions
from official.core import task_factory
from official.modeling import optimization
from official.modeling.multitask import base_model
from official.modeling.multitask import configs
OptimizationConfig = optimization.OptimizationConfig
RuntimeConfig = config_definitions.RuntimeConfig
class MultiTask(tf.Module, metaclass=abc.ABCMeta):
"""A multi-task class to manage multiple tasks."""
def __init__(self,
tasks: Union[Dict[Text, base_task.Task], List[base_task.Task]],
task_weights: Optional[Dict[str, Union[float, int]]] = None,
task_eval_steps: Optional[Dict[str, int]] = None,
name: Optional[str] = None):
"""MultiTask initialization.
Args:
tasks: a list or a flat dict of Task.
      task_weights: a dict of (task, task weight). A task weight can be
        applied directly during loss summation in a joint backward step, or
        used to sample a task among interleaved backward steps.
task_eval_steps: a dict of (task, eval steps).
name: the instance name of a MultiTask object.
"""
super().__init__(name=name)
if isinstance(tasks, list):
self._tasks = {}
for task in tasks:
if task.name in self._tasks:
raise ValueError("Duplicated tasks found, task.name is %s" %
task.name)
self._tasks[task.name] = task
elif isinstance(tasks, dict):
self._tasks = tasks
else:
raise ValueError("The tasks argument has an invalid type: %s" %
type(tasks))
self._task_eval_steps = task_eval_steps or {}
self._task_eval_steps = dict([
(name, self._task_eval_steps.get(name, None)) for name in self.tasks
])
self._task_weights = task_weights or {}
self._task_weights = dict([
(name, self._task_weights.get(name, 1.0)) for name in self.tasks
])
@classmethod
def from_config(cls, config: configs.MultiTaskConfig, logging_dir=None):
tasks = {}
task_eval_steps = {}
task_weights = {}
for task_routine in config.task_routines:
task_name = task_routine.task_name
tasks[task_name] = task_factory.get_task(
task_routine.task_config, logging_dir=logging_dir)
task_eval_steps[task_name] = task_routine.eval_steps
task_weights[task_name] = task_routine.task_weight
return cls(
tasks, task_eval_steps=task_eval_steps, task_weights=task_weights)
@property
def tasks(self):
return self._tasks
def task_eval_steps(self, task_name):
return self._task_eval_steps[task_name]
def task_weight(self, task_name):
return self._task_weights[task_name]
@property
def task_weights(self):
return self._task_weights
@classmethod
def create_optimizer(cls,
optimizer_config: OptimizationConfig,
runtime_config: Optional[RuntimeConfig] = None):
return base_task.Task.create_optimizer(
optimizer_config=optimizer_config, runtime_config=runtime_config)
def joint_train_step(self, task_inputs,
multi_task_model: base_model.MultiTaskBaseModel,
optimizer: tf.keras.optimizers.Optimizer, task_metrics):
"""The joint train step.
Args:
task_inputs: a dictionary of task names and per-task features.
multi_task_model: a MultiTaskBaseModel instance.
optimizer: a tf.optimizers.Optimizer.
task_metrics: a dictionary of task names and per-task metrics.
Returns:
      A dictionary of losses, including per-task losses and their weighted sum.
"""
losses = {}
with tf.GradientTape() as tape:
total_loss = 0.0
for name, model in multi_task_model.sub_tasks.items():
inputs = task_inputs[name]
if isinstance(inputs, tuple) and len(inputs) == 2:
features, labels = inputs
elif isinstance(inputs, dict):
features, labels = inputs, inputs
else:
raise ValueError("The iterator output is neither a tuple nor a "
"dictionary. It is not implemented to support "
"such outputs.")
outputs = model(features, training=True)
task_loss = self.tasks[name].build_losses(labels, outputs)
task_weight = self.task_weight(name)
total_loss += task_weight * task_loss
losses[name] = task_loss
self.tasks[name].process_metrics(task_metrics[name], labels, outputs)
      # Scales the loss, since the default gradient allreduce performs a sum
      # across replicas inside the optimizer.
scaled_loss = total_loss / tf.distribute.get_strategy(
).num_replicas_in_sync
tvars = multi_task_model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
optimizer.apply_gradients(list(zip(grads, tvars)))
losses["total_loss"] = total_loss
return losses
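# An illustrative sketch of the two ways to construct a `MultiTask`. The
# routine and task configs are placeholders; `task_factory` must know how to
# build each `task_config` when going through `from_config`.
def _example_multitask_setup(config: configs.MultiTaskConfig) -> "MultiTask":
  # 1) From a config: each `TaskRoutine` is resolved via the task factory, and
  #    its `task_weight` feeds either loss summation or task sampling.
  mt = MultiTask.from_config(config)
  # 2) Directly from task instances, e.g.:
  #    MultiTask(tasks=[task_a, task_b], task_weights={"a": 2.0, "b": 1.0})
  tf.get_logger().info("Task weights: %s", mt.task_weights)
  return mt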
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils to sample tasks for interleaved optimization."""
import abc
from typing import Union, Dict, Text
import tensorflow as tf
from official.modeling.multitask import configs
class TaskSampler(tf.Module, metaclass=abc.ABCMeta):
"""An abstract class defining task sampling API for interleaving trainer."""
def __init__(self, task_weights: Dict[Text, Union[float, int]]):
self._task_weights = task_weights
@abc.abstractmethod
def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
"""Compute cumulative distribution to sample tasks.
It calculates the cumulative distribution of the multinomial task
distribution with respect to which to be sampled against.
Args:
global_step: A tensor indicating current progess of training.
Returns:
A float tensor with shape (#(task), 1) that represents the cumulative
sampling distribution.
"""
pass
class UniformTaskSampler(TaskSampler):
"""Sample all tasks uniformly."""
def __init__(self, task_weights: Dict[Text, Union[float, int]]):
super(UniformTaskSampler, self).__init__(task_weights=task_weights)
self._uniform_cumulative = tf.math.cumsum(
tf.constant(
[1.0 / len(self._task_weights)] * len(self._task_weights),
dtype=tf.float32))
def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
del global_step
return self._uniform_cumulative
class ProportionalTaskSampler(TaskSampler):
"""Sample tasks proportional to task weights."""
def __init__(self,
task_weights: Dict[Text, Union[float, int]],
alpha: float = 1.0):
super(ProportionalTaskSampler, self).__init__(task_weights=task_weights)
self._alpha = tf.cast(alpha, dtype=tf.float32)
task_weight_dict_ordered_list = tf.constant(
[weight for _, weight in self._task_weights.items()], dtype=tf.float32)
task_sizes = tf.math.pow(task_weight_dict_ordered_list, self._alpha)
task_distribution = task_sizes / tf.reduce_sum(task_sizes)
    self._proportional_cumulative = tf.math.cumsum(task_distribution)
def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
del global_step
    return self._proportional_cumulative
class AnnealingTaskSampler(TaskSampler):
"""Sample tasks according to task weights as well as training progress."""
def __init__(self,
task_weights: Dict[Text, Union[float, int]],
steps_per_epoch: int,
total_steps: int):
super(AnnealingTaskSampler, self).__init__(task_weights=task_weights)
self._steps_per_epoch = tf.cast(steps_per_epoch, dtype=tf.float32)
self._total_epochs = tf.cast(
total_steps / self._steps_per_epoch, dtype=tf.float32)
def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
cur_epoch = tf.math.floor(
tf.cast(global_step, dtype=tf.float32) / self._steps_per_epoch)
alpha = 1.0 - 0.8 * (cur_epoch - 1) / (self._total_epochs - 1 + 1e-10)
task_weight_dict_ordered_list = [
weight for _, weight in self._task_weights.items()
]
task_sizes = tf.math.pow(
tf.constant(task_weight_dict_ordered_list, dtype=tf.float32),
tf.cast(alpha, dtype=tf.float32))
dynamic_task_distribution = task_sizes / tf.reduce_sum(task_sizes)
return tf.math.cumsum(dynamic_task_distribution)
def get_task_sampler(config: configs.TaskSamplingConfig,
task_weights: Dict[Text, float]) -> TaskSampler:
"""Utils to create task sampler with configuration and task weights."""
oneof_config = config.get()
if config.type == 'uniform':
return UniformTaskSampler(task_weights=task_weights)
elif config.type == 'proportional':
return ProportionalTaskSampler(
task_weights=task_weights, alpha=oneof_config.alpha)
elif config.type == 'annealing':
return AnnealingTaskSampler(
task_weights=task_weights,
steps_per_epoch=oneof_config.steps_per_epoch,
total_steps=oneof_config.total_steps)
else:
    raise RuntimeError('Task sampler type not supported: %s' % config.type)
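# A worked example (illustrative): with weights {'a': 1.0, 'b': 2.0, 'c': 3.0}
# and alpha=2.0, the proportional sampler normalizes [1, 4, 9] to
# [1/14, 4/14, 9/14] and returns their cumulative sums, roughly
# [0.0714, 0.3571, 1.0]. The interleaving trainer then draws a uniform number
# in [0, 1) and trains the first task whose cumulative bound exceeds it.
def _example_proportional_distribution():
  weights = {'a': 1.0, 'b': 2.0, 'c': 3.0}
  prop_sampler = ProportionalTaskSampler(task_weights=weights, alpha=2.0)
  return prop_sampler.task_cumulative_distribution(
      tf.constant(0, dtype=tf.int64)).numpy()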
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.task_sampler."""
import tensorflow as tf
from official.modeling.multitask import configs
from official.modeling.multitask import task_sampler as sampler
class TaskSamplerTest(tf.test.TestCase):
def setUp(self):
super(TaskSamplerTest, self).setUp()
self._task_weights = {'A': 1.0, 'B': 2.0, 'C': 3.0}
def test_uniform_sample_distribution(self):
uniform_sampler = sampler.get_task_sampler(
configs.TaskSamplingConfig(type='uniform'), self._task_weights)
for step in range(5):
cumulative_distribution = uniform_sampler.task_cumulative_distribution(
tf.constant(step, dtype=tf.int64))
self.assertAllClose([0.333333, 0.666666, 1.0],
cumulative_distribution.numpy())
def test_proportional_sample_distribution(self):
prop_sampler = sampler.get_task_sampler(
configs.TaskSamplingConfig(
type='proportional',
proportional=configs.ProportionalSampleConfig(alpha=2.0)),
self._task_weights)
    # CumulativeOf(Normalize([1.0^2, 2.0^2, 3.0^2]))
for step in range(5):
cumulative_distribution = prop_sampler.task_cumulative_distribution(
tf.constant(step, dtype=tf.int64))
self.assertAllClose([0.07142857, 0.35714286, 1.0],
cumulative_distribution.numpy())
def test_annealing_sample_distribution(self):
num_epoch = 3
step_per_epoch = 6
    anneal_sampler = sampler.get_task_sampler(
configs.TaskSamplingConfig(
type='annealing',
annealing=configs.AnnealingSampleConfig(
steps_per_epoch=step_per_epoch,
total_steps=step_per_epoch * num_epoch)), self._task_weights)
global_step = tf.Variable(
0, dtype=tf.int64, name='global_step', trainable=False)
expected_cumulative_epochs = [[0.12056106, 0.4387236, 1.0],
[0.16666667, 0.5, 1.0],
[0.22477472, 0.5654695, 1.0]]
for epoch in range(num_epoch):
for _ in range(step_per_epoch):
        cumulative_distribution = anneal_sampler.task_cumulative_distribution(
tf.constant(global_step, dtype=tf.int64))
global_step.assign_add(1)
self.assertAllClose(expected_cumulative_epochs[epoch],
cumulative_distribution.numpy())
if __name__ == '__main__':
tf.test.main()