Unverified commit f16a7b5b authored by vedanshu, committed by GitHub

Merge pull request #1 from tensorflow/master

new pull
parents 8e9296ff 8f58f396
# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hyperparams package definition."""
# pylint: disable=g-multiple-import
from official.modeling.hyperparams.base_config import *
# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,17 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base configurations to standardize experiments."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
"""Base configurations to standardize experiments."""
import copy
import functools
from typing import Any, List, Mapping, Optional, Type
from absl import logging
import dataclasses
import tensorflow as tf
@@ -35,11 +30,15 @@ from official.modeling.hyperparams import params_dict
class Config(params_dict.ParamsDict):
"""The base configuration class that supports YAML/JSON based overrides.
-  * It recursively enforces a whitelist of basic types and container types, so
+  Because of YAML/JSON serialization limitations, some semantics of dataclass
+  are not supported:
+  * It recursively enforces an allowlist of basic types and container types, so
    it avoids surprises with copy and reuse caused by unanticipated types.
-  * It converts dict to Config even within sequences,
+  * Warning: it converts Dict to `Config` even within sequences,
    e.g. for config = Config({'key': [([{'a': 42}],)]}),
    type(config.key[0][0][0]) is Config rather than dict.
+    If you define/annotate some field as Dict, the field will convert to a
+    `Config` instance and lose the dictionary type.
"""
# It's safe to add bytes and other immutable types here.
@@ -142,10 +141,11 @@ class Config(params_dict.ParamsDict):
return subconfig_type
  def __post_init__(self, default_params, restrictions, *args, **kwargs):
-    super().__init__(default_params=default_params,
-                     restrictions=restrictions,
-                     *args,
-                     **kwargs)
+    super().__init__(
+        default_params=default_params,
+        restrictions=restrictions,
+        *args,
+        **kwargs)
def _set(self, k, v):
"""Overrides same method in ParamsDict.
@@ -160,13 +160,32 @@ class Config(params_dict.ParamsDict):
RuntimeError
"""
subconfig_type = self._get_subconfig_type(k)
-    if isinstance(v, dict):
+
+    def is_null(k):
+      if k not in self.__dict__ or not self.__dict__[k]:
+        return True
+      return False
+
+    if isinstance(v, dict):
+      if is_null(k):
        # If the key does not exist or the value is None, a new Config-family
        # object should be created for the key.
        self.__dict__[k] = subconfig_type(v)
      else:
        self.__dict__[k].override(v)
+    elif not is_null(k) and isinstance(v, self.SEQUENCE_TYPES) and all(
+        [not isinstance(e, self.IMMUTABLE_TYPES) for e in v]):
+      if len(self.__dict__[k]) == len(v):
+        for i in range(len(v)):
+          self.__dict__[k][i].override(v[i])
+      elif not all([isinstance(e, self.IMMUTABLE_TYPES) for e in v]):
+        logging.warning(
+            "The list/tuple doesn't match the value dictionaries provided. "
+            'Thus, the list/tuple is determined by the type annotation and '
+            'the values provided. This is error-prone.')
+        self.__dict__[k] = self._import_config(v, subconfig_type)
+      else:
+        self.__dict__[k] = self._import_config(v, subconfig_type)
    else:
      self.__dict__[k] = self._import_config(v, subconfig_type)
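  # Sketch of the sequence branch above (`Inner`/`Outer` are illustrative
  # names, not part of this module): an override list of the same length
  # updates elements in place; otherwise the field is rebuilt from the type
  # annotation.
  #
  #   @dataclasses.dataclass
  #   class Inner(Config):
  #     a: int = 0
  #
  #   @dataclasses.dataclass
  #   class Outer(Config):
  #     items: List[Inner] = dataclasses.field(
  #         default_factory=lambda: [Inner()])
  #
  #   cfg = Outer()
  #   cfg.override({'items': [{'a': 7}]})  # same length: element-wise override
  #   assert cfg.items[0].a == 7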
@@ -220,16 +239,19 @@ class Config(params_dict.ParamsDict):
}
  def replace(self, **kwargs):
-    """Like `override`, but returns a copy with the current config unchanged."""
-    params = self.__class__(self)
-    params.override(kwargs, is_strict=True)
+    """Overrides/returns an unlocked copy with the current config unchanged."""
+    # pylint: disable=protected-access
+    params = copy.deepcopy(self)
+    params._locked = False
+    params._override(kwargs, is_strict=True)
+    # pylint: enable=protected-access
return params
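  # Usage sketch for `replace` (any Config subclass with a nested field `e`):
  # the deep copy leaves the original untouched and works on locked configs.
  #
  #   new_config = config.replace(e={'a': 2})
  #   assert new_config.e.a == 2  # `config` itself is unchanged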
@classmethod
def from_yaml(cls, file_path: str):
# Note: This only works if the Config has all default values.
with tf.io.gfile.GFile(file_path, 'r') as f:
-      loaded = yaml.load(f)
+      loaded = yaml.load(f, Loader=yaml.FullLoader)
config = cls()
config.override(loaded)
return config
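  # Usage sketch for `from_yaml` (`SomeConfig` and the path are hypothetical):
  # keys in the YAML must already exist on the class, since a default instance
  # is built first and then overridden.
  #
  #   config = SomeConfig.from_yaml('/tmp/experiment.yaml')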
# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pprint
from typing import List, Tuple
@@ -45,6 +43,17 @@ class DumpConfig3(DumpConfig2):
g: Tuple[DumpConfig1, ...] = (DumpConfig1(),)
@dataclasses.dataclass
class DumpConfig4(DumpConfig2):
x: int = 3
@dataclasses.dataclass
class DummyConfig5(base_config.Config):
y: Tuple[DumpConfig2, ...] = (DumpConfig2(), DumpConfig4())
z: Tuple[str] = ('a',)
class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
def assertHasSameTypes(self, c, d, msg=''):
@@ -106,6 +115,22 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
self.assertEqual(config.g[0].a, 4)
self.assertEqual(config.g[0].b, 'new text 3')
def test_replace(self):
config = DumpConfig2()
new_config = config.replace(e={'a': 2})
self.assertEqual(new_config.e.a, 2)
self.assertIsInstance(new_config.e, DumpConfig1)
config = DumpConfig2(e=DumpConfig2())
new_config = config.replace(e={'c': 4})
self.assertEqual(new_config.e.c, 4)
self.assertIsInstance(new_config.e, DumpConfig2)
config = DumpConfig3()
new_config = config.replace(g=[{'a': 4, 'b': 'new text 3'}])
self.assertIsInstance(new_config.g[0], DumpConfig1)
self.assertEqual(new_config.g[0].a, 4)
@parameterized.parameters(
('_locked', "The key '_locked' is internally reserved."),
('_restrictions', "The key '_restrictions' is internally reserved."),
@@ -147,10 +172,8 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
params.override({'c': {'c3': 30}}, is_strict=True)
config = base_config.Config({'key': [{'a': 42}]})
-    config.override({'key': [{'b': 43}]})
-    self.assertEqual(config.key[0].b, 43)
-    with self.assertRaisesRegex(AttributeError, 'The key `a` does not exist'):
-      _ = config.key[0].a
+    with self.assertRaisesRegex(KeyError, "The key 'b' does not exist"):
+      config.override({'key': [{'b': 43}]})
@parameterized.parameters(
(lambda x: x, 'Unknown type'),
@@ -294,6 +317,44 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
]),
"['s', 1, 1.0, True, None, {}, [], (), {8: 9, (2,): (3, [4], {6: 7})}]")
def test_with_restrictions(self):
restrictions = ['e.a<c']
config = DumpConfig2(restrictions=restrictions)
config.validate()
def test_nested_tuple(self):
config = DummyConfig5()
config.override({
'y': [{
'c': 4,
'd': 'new text 3',
'e': {
'a': 2
}
}, {
'c': 0,
'd': 'new text 3',
'e': {
'a': 2
}
}],
'z': ['a', 'b', 'c'],
})
self.assertEqual(config.y[0].c, 4)
self.assertEqual(config.y[1].c, 0)
self.assertIsInstance(config.y[0], DumpConfig2)
self.assertIsInstance(config.y[1], DumpConfig4)
self.assertSameElements(config.z, ['a', 'b', 'c'])
def test_override_by_empty_sequence(self):
config = DummyConfig5()
config.override({
'y': [],
'z': (),
}, is_strict=True)
self.assertEmpty(config.y)
self.assertEmpty(config.z)
if __name__ == '__main__':
tf.test.main()
# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,124 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common configuration settings."""
from typing import Optional, Union
"""Common configuration settings."""
# pylint:disable=wildcard-import
import dataclasses
from official.core.config_definitions import *
from official.modeling.hyperparams import base_config
from official.modeling.optimization.configs import optimization_config
from official.utils import registry
OptimizationConfig = optimization_config.OptimizationConfig
@dataclasses.dataclass
class DataConfig(base_config.Config):
"""The base configuration for building datasets.
Attributes:
input_path: The path to the input. It can be either (1) a file pattern, or
(2) multiple file patterns separated by comma. It should not be specified
when the following `tfds_name` is specified.
tfds_name: The name of the tensorflow dataset (TFDS). It should not be
specified when the above `input_path` is specified.
tfds_split: A str indicating which split of the data to load from TFDS. It
is required when above `tfds_name` is specified.
global_batch_size: The global batch size across all replicas.
is_training: Whether this data is used for training or not.
drop_remainder: Whether the last batch should be dropped in the case it has
fewer than `global_batch_size` elements.
shuffle_buffer_size: The buffer size used for shuffling training data.
cache: Whether to cache dataset examples. Can be used to avoid re-reading
from disk on the second epoch. Requires significant memory overhead.
cycle_length: The number of files that will be processed concurrently when
interleaving files.
block_length: The number of consecutive elements to produce from each input
element before cycling to another input element when interleaving files.
sharding: Whether sharding is used in the input pipeline.
    examples_consume: An `integer` specifying the number of examples it will
      produce. If positive, it only takes this number of examples and raises
      tf.errors.OutOfRangeError after that. Default is -1, meaning it will
      exhaust all the examples in the dataset.
tfds_data_dir: A str specifying the directory to read/write TFDS data.
tfds_download: A bool to indicate whether to download data using TFDS.
tfds_as_supervised: A bool. When loading dataset from TFDS, if True,
the returned tf.data.Dataset will have a 2-tuple structure (input, label)
according to builder.info.supervised_keys; if False, the default,
the returned tf.data.Dataset will have a dictionary with all the features.
tfds_skip_decoding_feature: A str to indicate which features are skipped
for decoding when loading dataset from TFDS. Use comma to separate
multiple features. The main use case is to skip the image/video decoding
for better performance.
"""
input_path: str = ""
tfds_name: str = ""
tfds_split: str = ""
global_batch_size: int = 0
is_training: bool = None
drop_remainder: bool = True
shuffle_buffer_size: int = 100
cache: bool = False
cycle_length: int = 8
block_length: int = 1
sharding: bool = True
examples_consume: int = -1
tfds_data_dir: str = ""
tfds_download: bool = False
tfds_as_supervised: bool = False
tfds_skip_decoding_feature: str = ""
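# Instantiation sketch (paths and sizes are hypothetical): set either
# `input_path` or `tfds_name`, not both.
#
#   train_data = DataConfig(
#       input_path="/data/train-*.tfrecord",
#       global_batch_size=64,
#       is_training=True,
#       shuffle_buffer_size=10000)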
@dataclasses.dataclass
class RuntimeConfig(base_config.Config):
"""High-level configurations for Runtime.
These include parameters that are not directly related to the experiment,
e.g. directories, accelerator type, etc.
Attributes:
distribution_strategy: e.g. 'mirrored', 'tpu', etc.
enable_xla: Whether or not to enable XLA.
per_gpu_thread_count: thread count per GPU.
gpu_thread_mode: Whether and how the GPU device uses its own threadpool.
dataset_num_private_threads: Number of threads for a private threadpool
created for all datasets computation.
tpu: The address of the TPU to use, if any.
num_gpus: The number of GPUs to use, if any.
worker_hosts: comma-separated list of worker ip:port pairs for running
multi-worker models with DistributionStrategy.
task_index: If multi-worker training, the task index of this worker.
all_reduce_alg: Defines the algorithm for performing all-reduce.
num_packs: Sets `num_packs` in the cross device ops used in
MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
mixed_precision_dtype: dtype of mixed precision policy. It can be 'float32',
'float16', or 'bfloat16'.
loss_scale: The type of loss scale, or 'float' value. This is used when
setting the mixed precision policy.
run_eagerly: Whether or not to run the experiment eagerly.
batchnorm_spatial_persistent: Whether or not to enable the spatial
persistent mode for CuDNN batch norm kernel for improved GPU performance.
"""
distribution_strategy: str = "mirrored"
enable_xla: bool = False
gpu_thread_mode: Optional[str] = None
dataset_num_private_threads: Optional[int] = None
per_gpu_thread_count: int = 0
tpu: Optional[str] = None
num_gpus: int = 0
worker_hosts: Optional[str] = None
task_index: int = -1
all_reduce_alg: Optional[str] = None
num_packs: int = 1
mixed_precision_dtype: Optional[str] = None
loss_scale: Optional[Union[str, float]] = None
run_eagerly: bool = False
batchnorm_spatial_persistent: bool = False
# TODO(hongkuny): These configs are used in models that are going to be
# deprecated. Once those models are removed, we should delete this file to
# avoid confusion. Users should not use this file anymore.
@dataclasses.dataclass
class TensorboardConfig(base_config.Config):
"""Configuration for Tensorboard.
@@ -151,75 +44,14 @@ class CallbacksConfig(base_config.Config):
Attributes:
enable_checkpoint_and_export: Whether or not to enable checkpoints as a
Callback. Defaults to True.
enable_backup_and_restore: Whether or not to add BackupAndRestore
callback. Defaults to True.
enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
Defaults to True.
enable_time_history: Whether or not to enable TimeHistory Callbacks.
Defaults to True.
"""
enable_checkpoint_and_export: bool = True
enable_backup_and_restore: bool = False
enable_tensorboard: bool = True
enable_time_history: bool = True
@dataclasses.dataclass
class TrainerConfig(base_config.Config):
"""Configuration for trainer.
Attributes:
optimizer_config: optimizer config, it includes optimizer, learning rate,
and warmup schedule configs.
train_tf_while_loop: whether or not to use tf while loop.
train_tf_function: whether or not to use tf_function for training loop.
eval_tf_function: whether or not to use tf_function for eval.
steps_per_loop: number of steps per loop.
summary_interval: number of steps between each summary.
checkpoint_interval: number of steps between checkpoints.
max_to_keep: max checkpoints to keep.
continuous_eval_timeout: maximum number of seconds to wait between
checkpoints, if set to None, continuous eval will wait indefinitely.
train_steps: number of train steps.
validation_steps: number of eval steps. If `None`, the entire eval dataset
is used.
validation_interval: number of training steps to run between evaluations.
"""
optimizer_config: OptimizationConfig = OptimizationConfig()
train_tf_while_loop: bool = True
train_tf_function: bool = True
eval_tf_function: bool = True
steps_per_loop: int = 1000
summary_interval: int = 1000
checkpoint_interval: int = 1000
max_to_keep: int = 5
continuous_eval_timeout: Optional[int] = None
train_steps: int = 0
validation_steps: Optional[int] = None
validation_interval: int = 1000
@dataclasses.dataclass
class TaskConfig(base_config.Config):
model: base_config.Config = None
train_data: DataConfig = DataConfig()
validation_data: DataConfig = DataConfig()
@dataclasses.dataclass
class ExperimentConfig(base_config.Config):
"""Top-level configuration."""
task: TaskConfig = TaskConfig()
trainer: TrainerConfig = TrainerConfig()
runtime: RuntimeConfig = RuntimeConfig()
_REGISTERED_CONFIGS = {}
def register_config_factory(name):
"""Register ExperimentConfig factory method."""
return registry.register(_REGISTERED_CONFIGS, name)
def get_exp_config_creater(exp_name: str):
"""Looks up ExperimentConfig factory methods."""
exp_creater = registry.lookup(_REGISTERED_CONFIGS, exp_name)
return exp_creater
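# Usage sketch for the registry helpers above ('my_exp' is a hypothetical
# name): a factory is registered under a name, then looked up and called to
# build the config.
#
#   @register_config_factory('my_exp')
#   def _my_exp_config() -> ExperimentConfig:
#     return ExperimentConfig(trainer=TrainerConfig(train_steps=1000))
#
#   exp_config = get_exp_config_creater('my_exp')()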
# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Config class that supports oneof functionality."""
from typing import Optional
@@ -38,15 +37,12 @@ class OneOfConfig(base_config.Config):
if self.type is None:
return {'type': None}
elif self.__dict__['type'] not in self.__dict__:
-      raise ValueError(
-          'type: {!r} is not a valid key!'.format(self.__dict__['type']))
+      raise ValueError('type: {!r} is not a valid key!'.format(
+          self.__dict__['type']))
else:
chosen_type = self.type
chosen_value = self.__dict__[chosen_type]
-      return {
-          'type': self.type,
-          chosen_type: self._export_config(chosen_value)
-      }
+      return {'type': self.type, chosen_type: self._export_config(chosen_value)}
def get(self):
"""Returns selected config based on the value of type.
@@ -57,6 +53,5 @@ class OneOfConfig(base_config.Config):
if chosen_type is None:
return None
if chosen_type not in self.__dict__:
-      raise ValueError(
-          'type: {!r} is not a valid key!'.format(self.type))
+      raise ValueError('type: {!r} is not a valid key!'.format(self.type))
return self.__dict__[chosen_type]
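# Usage sketch (`NetworkOneOf` and its members are illustrative): the `type`
# field names which sibling attribute `get()` returns.
#
#   cfg = NetworkOneOf(type='resnet')
#   assert cfg.get() is cfg.resnet
#   cfg.type = None
#   assert cfg.get() is None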
# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import dataclasses
import tensorflow as tf
@@ -48,12 +46,18 @@ class Network(base_config.Config):
class OneOfTest(tf.test.TestCase):
def test_to_dict(self):
-    network_params = {'backbone': {'type': 'resnet',
-                                   'resnet': {'model_depth': 50}
-                                  },
-                      'output_layer': {'type': 'single',
-                                       'single': 1000}
-                     }
+    network_params = {
+        'backbone': {
+            'type': 'resnet',
+            'resnet': {
+                'model_depth': 50
+            }
+        },
+        'output_layer': {
+            'type': 'single',
+            'single': 1000
+        }
+    }
network_config = Network(network_params)
self.assertEqual(network_config.as_dict(), network_params)
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,12 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A parameter dictionary class which supports the nest structure."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""A parameter dictionary class which supports the nest structure."""
import collections
import copy
@@ -30,7 +26,8 @@ import yaml
# key-value pair string. It splits each k-v pair on the = sign, and
# matches on values that are within single quotes, double quotes, single
# values (e.g. floats, ints, etc.), and a lists within brackets.
-_PARAM_RE = re.compile(r"""
+_PARAM_RE = re.compile(
+    r"""
(?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
\s*=\s*
((?P<val>\'(.*?)\' # single quote
@@ -44,6 +41,26 @@ _PARAM_RE = re.compile(r"""
_CONST_VALUE_RE = re.compile(r'(\d.*|-\d.*|None)')
# YAML loader with an implicit resolver to parse float decimal and exponential
# formats. The regular expression parses the following cases:
# 1- Decimal number with an optional exponential term.
# 2- Integer number with an exponential term.
# 3- Decimal number starting with a dot, with an optional exponential term.
# 4- Sexagesimal number with a decimal fraction.
LOADER = yaml.SafeLoader
LOADER.add_implicit_resolver(
'tag:yaml.org,2002:float',
re.compile(r'''
      ^(?:[-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+]?[0-9]+)?
      |
      [-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
      |
      \.[0-9_]+(?:[eE][-+][0-9]+)?
      |
      [-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\.[0-9_]*)$''', re.X),
list('-+0123456789.'))
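# A quick sketch of what the resolver buys us: a stock YAML 1.1 loader only
# tags floats written like `1.0e-3`, so bare `1e-3` or `-1.e-3` load as
# strings; with LOADER they come back as floats.
#
#   assert yaml.load('lr: 1e-3', Loader=LOADER)['lr'] == 1e-3
#   assert yaml.load('lr: -1.e-3', Loader=LOADER)['lr'] == -1e-3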
class ParamsDict(object):
"""A hyperparameter container class."""
@@ -72,7 +89,6 @@ class ParamsDict(object):
if default_params is None:
default_params = {}
self.override(default_params, is_strict=False)
-    self.validate()
def _set(self, k, v):
if isinstance(v, dict):
@@ -138,8 +154,8 @@ class ParamsDict(object):
ValueError: if the ParamsDict instance has been locked.
"""
if k in ParamsDict.RESERVED_ATTR:
-      raise AttributeError('The key `{}` is reserved. No change is allowes. '
-                           .format(k))
+      raise AttributeError(
+          'The key `{}` is reserved. No change is allowed. '.format(k))
if k not in self.__dict__.keys():
raise AttributeError('The key `{}` does not exist. '.format(k))
if self._locked:
@@ -150,13 +166,13 @@ class ParamsDict(object):
"""Override the ParamsDict with a set of given params.
Args:
-      override_params: a dict or a ParamsDict specifying the parameters to
-        be overridden.
+      override_params: a dict or a ParamsDict specifying the parameters to be
+        overridden.
      is_strict: a boolean specifying whether override is strict or not. If
-        True, keys in `override_params` must be present in the ParamsDict.
-        If False, keys in `override_params` can be different from what is
-        currently defined in the ParamsDict. In this case, the ParamsDict will
-        be extended to include the new keys.
+        True, keys in `override_params` must be present in the ParamsDict. If
+        False, keys in `override_params` can be different from what is currently
+        defined in the ParamsDict. In this case, the ParamsDict will be extended
+        to include the new keys.
"""
if self._locked:
raise ValueError('The ParamsDict has been locked. No change is allowed.')
@@ -230,7 +246,7 @@ class ParamsDict(object):
['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2']
What it enforces are:
-      - a.a1 = 1 == b.ccc.a1 = 2
+      - a.a1 = 1 == b.ccc.a1 = 1
- a.a2 = 2 <= b.bb.bb2 = 20
Raises:
@@ -240,6 +256,7 @@ class ParamsDict(object):
(2) any inconsistency violating the restriction is found.
ValueError: if the restriction defined in the string is not supported.
"""
def _get_kv(dotted_string, params_dict):
"""Get keys and values indicated by dotted_string."""
if _CONST_VALUE_RE.match(dotted_string) is not None:
@@ -270,56 +287,64 @@ class ParamsDict(object):
tokens = restriction.split('==')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v != right_v:
-        raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
-                       .format(tokens[0], tokens[1]))
+        raise KeyError(
+            'Found inconsistency between key `{}` and key `{}`.'.format(
+                tokens[0], tokens[1]))
      elif '!=' in restriction:
        tokens = restriction.split('!=')
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if left_v == right_v:
          raise KeyError(
              'Found inconsistency between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
      # Test '<=' before '<' (and '>=' before '>'): a restriction such as
      # 'a<=b' also contains '<', so checking '<' first would split on the
      # wrong operator.
      elif '<=' in restriction:
        tokens = restriction.split('<=')
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if left_v > right_v:
          raise KeyError(
              'Found inconsistency between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
      elif '<' in restriction:
        tokens = restriction.split('<')
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if left_v >= right_v:
          raise KeyError(
              'Found inconsistency between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
      elif '>=' in restriction:
        tokens = restriction.split('>=')
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if left_v < right_v:
          raise KeyError(
              'Found inconsistency between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
      elif '>' in restriction:
        tokens = restriction.split('>')
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if left_v <= right_v:
          raise KeyError(
              'Found inconsistency between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
      else:
        raise ValueError('Unsupported relation in restriction.')
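# Restriction usage sketch: each restriction relates two dotted keys (or a
# key and a constant) and is checked on every validate() call.
#
#   params = ParamsDict({'a': 1, 'c': {'a': 1}}, restrictions=['a == c.a'])
#   params.override({'a': 2})
#   params.validate()  # raises KeyError: 'a' and 'c.a' now disagree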
-def read_yaml_to_params_dict(file_path):
+def read_yaml_to_params_dict(file_path: str):
  """Reads a YAML file to a ParamsDict."""
  with tf.io.gfile.GFile(file_path, 'r') as f:
-    params_dict = yaml.load(f)
+    params_dict = yaml.load(f, Loader=LOADER)
return ParamsDict(params_dict)
def save_params_dict_to_yaml(params, file_path):
"""Saves the input ParamsDict to a YAML file."""
with tf.io.gfile.GFile(file_path, 'w') as f:
def _my_list_rep(dumper, data):
# u'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
return dumper.represent_sequence(
u'tag:yaml.org,2002:seq', data, flow_style=True)
yaml.add_representer(list, _my_list_rep)
yaml.dump(params.as_dict(), f, default_flow_style=False)
@@ -408,8 +433,8 @@ def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
Args:
params: a ParamsDict object to be overridden.
-    dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or
-      path to a YAML file specifying the parameters to be overridden.
+    dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or path to
+      a YAML file specifying the parameters to be overridden.
is_strict: a boolean specifying whether override is strict or not.
Returns:
@@ -428,12 +453,12 @@ def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
nested_csv_str_to_json_str(dict_or_string_or_yaml_file))
except ValueError:
pass
-      params_dict = yaml.load(dict_or_string_or_yaml_file)
+      params_dict = yaml.load(dict_or_string_or_yaml_file, Loader=LOADER)
if isinstance(params_dict, dict):
params.override(params_dict, is_strict)
else:
with tf.io.gfile.GFile(dict_or_string_or_yaml_file) as f:
-      params.override(yaml.load(f), is_strict)
+      params.override(yaml.load(f, Loader=yaml.FullLoader), is_strict)
else:
raise ValueError('Unknown input type to parse.')
return params
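# Usage sketch for the CSV form (keys mirror the tests below): the string is
# detected as CSV, converted to a JSON string, then YAML-loaded with LOADER so
# exponential floats parse as numbers.
#
#   params = ParamsDict({'b': {'b2': [2, 3]}, 'e': 0.1})
#   params = override_params_dict(params, 'b.b2=[3,4], e=1e-3', is_strict=True)
#   assert params.e == 1e-3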
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for params_dict.py."""
@@ -56,8 +55,7 @@ class ParamsDictTest(tf.test.TestCase):
def test_setattr(self):
params = params_dict.ParamsDict()
-    params.override(
-        {'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
+    params.override({'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
params.c = 'ccc'
self.assertEqual(params.a, 'aa')
self.assertEqual(params.b, 2)
@@ -65,17 +63,23 @@
def test_getattr(self):
params = params_dict.ParamsDict()
-    params.override(
-        {'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
+    params.override({'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
self.assertEqual(params.a, 'aa')
self.assertEqual(params.b, 2)
self.assertEqual(params.c, None)
def test_delattr(self):
params = params_dict.ParamsDict()
-    params.override(
-        {'a': 'aa', 'b': 2, 'c': None, 'd': {'d1': 1, 'd2': 10}},
-        is_strict=False)
+    params.override({
+        'a': 'aa',
+        'b': 2,
+        'c': None,
+        'd': {
+            'd1': 1,
+            'd2': 10
+        }
+    },
+                    is_strict=False)
del params.c
self.assertEqual(params.a, 'aa')
self.assertEqual(params.b, 2)
@@ -87,22 +91,26 @@
def test_contains(self):
params = params_dict.ParamsDict()
-    params.override(
-        {'a': 'aa'}, is_strict=False)
+    params.override({'a': 'aa'}, is_strict=False)
self.assertIn('a', params)
self.assertNotIn('b', params)
def test_get(self):
params = params_dict.ParamsDict()
-    params.override(
-        {'a': 'aa'}, is_strict=False)
+    params.override({'a': 'aa'}, is_strict=False)
self.assertEqual(params.get('a'), 'aa')
self.assertEqual(params.get('b', 2), 2)
self.assertEqual(params.get('b'), None)
def test_override_is_strict_true(self):
-    params = params_dict.ParamsDict(
-        {'a': 'aa', 'b': 2, 'c': {'c1': 'cc', 'c2': 20}})
+    params = params_dict.ParamsDict({
+        'a': 'aa',
+        'b': 2,
+        'c': {
+            'c1': 'cc',
+            'c2': 20
+        }
+    })
params.override({'a': 2, 'c': {'c1': 'ccc'}}, is_strict=True)
self.assertEqual(params.a, 2)
self.assertEqual(params.c.c1, 'ccc')
@@ -112,8 +120,14 @@
params.override({'c': {'c3': 30}}, is_strict=True)
def test_override_is_strict_false(self):
-    params = params_dict.ParamsDict(
-        {'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}})
+    params = params_dict.ParamsDict({
+        'a': 'aa',
+        'b': 2,
+        'c': {
+            'c1': 10,
+            'c2': 20
+        }
+    })
params.override({'a': 2, 'c': {'c3': 3000}}, is_strict=False)
self.assertEqual(params.a, 2)
self.assertEqual(params.c.c3, 3000)
@@ -123,8 +137,14 @@
self.assertEqual(params.c.c4, 4444)
def test_as_dict(self):
-    params = params_dict.ParamsDict(
-        {'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}})
+    params = params_dict.ParamsDict({
+        'a': 'aa',
+        'b': 2,
+        'c': {
+            'c1': 10,
+            'c2': 20
+        }
+    })
params_d = params.as_dict()
self.assertEqual(params_d['a'], 'aa')
self.assertEqual(params_d['b'], 2)
@@ -134,21 +154,27 @@
def test_validate(self):
# Raise error due to the unknown parameter.
with self.assertRaises(KeyError):
-      params = params_dict.ParamsDict(
-          {'a': 1, 'b': {'a': 11}}, ['a == c'])
+      params = params_dict.ParamsDict({'a': 1, 'b': {'a': 11}}, ['a == c'])
params.validate()
# OK to check equality of two nested dicts.
-    params = params_dict.ParamsDict(
-        {'a': 1, 'b': {'a': 10}, 'c': {'a': 10}}, ['b == c'])
+    params = params_dict.ParamsDict({
+        'a': 1,
+        'b': {
+            'a': 10
+        },
+        'c': {
+            'a': 10
+        }
+    }, ['b == c'])
# Raise error due to inconsistency
with self.assertRaises(KeyError):
-      params = params_dict.ParamsDict(
-          {'a': 1, 'c': {'a': 10}}, ['a == c.a'])
+      params = params_dict.ParamsDict({'a': 1, 'c': {'a': 10}}, ['a == c.a'])
params.validate()
# Valid rule.
-    params = params_dict.ParamsDict(
-        {'a': 1, 'c': {'a': 1}}, ['a == c.a'])
+    params = params_dict.ParamsDict({'a': 1, 'c': {'a': 1}}, ['a == c.a'])
    # Overriding violates the existing rule, raise error upon validate.
params.override({'a': 11})
@@ -156,12 +182,21 @@
params.validate()
# Valid restrictions with constant.
-    params = params_dict.ParamsDict(
-        {'a': None, 'c': {'a': 1}}, ['a == None', 'c.a == 1'])
+    params = params_dict.ParamsDict({
+        'a': None,
+        'c': {
+            'a': 1
+        }
+    }, ['a == None', 'c.a == 1'])
params.validate()
with self.assertRaises(KeyError):
-      params = params_dict.ParamsDict(
-          {'a': 4, 'c': {'a': 1}}, ['a == None', 'c.a == 1'])
+      params = params_dict.ParamsDict({
+          'a': 4,
+          'c': {
+              'a': 1
+          }
+      }, ['a == None', 'c.a == 1'])
params.validate()
class ParamsDictIOTest(tf.test.TestCase):
@@ -173,8 +208,14 @@
return temp_file
def test_save_params_dict_to_yaml(self):
-    params = params_dict.ParamsDict(
-        {'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}})
+    params = params_dict.ParamsDict({
+        'a': 'aa',
+        'b': 2,
+        'c': {
+            'c1': 10,
+            'c2': 20
+        }
+    })
output_yaml_file = os.path.join(self.get_temp_dir(), 'params.yaml')
params_dict.save_params_dict_to_yaml(params, output_yaml_file)
@@ -203,7 +244,12 @@
def test_override_params_dict_using_dict(self):
params = params_dict.ParamsDict({
-        'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False})
+        'a': 1,
+        'b': 2.5,
+        'c': [3, 4],
+        'd': 'hello',
+        'e': False
+    })
override_dict = {'b': 5.2, 'c': [30, 40]}
params = params_dict.override_params_dict(
params, override_dict, is_strict=True)
@@ -215,7 +261,12 @@
def test_override_params_dict_using_yaml_string(self):
params = params_dict.ParamsDict({
-        'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False})
+        'a': 1,
+        'b': 2.5,
+        'c': [3, 4],
+        'd': 'hello',
+        'e': False
+    })
override_yaml_string = "'b': 5.2\n'c': [30, 40]"
params = params_dict.override_params_dict(
params, override_yaml_string, is_strict=True)
@@ -227,8 +278,18 @@
def test_override_params_dict_using_json_string(self):
params = params_dict.ParamsDict({
-        'a': 1, 'b': {'b1': 2, 'b2': [2, 3],},
-        'd': {'d1': {'d2': 'hello'}}, 'e': False})
+        'a': 1,
+        'b': {
+            'b1': 2,
+            'b2': [2, 3],
+        },
+        'd': {
+            'd1': {
+                'd2': 'hello'
+            }
+        },
+        'e': False
+    })
override_json_string = "{ b: { b2: [3, 4] }, d: { d1: { d2: 'hi' } } }"
params = params_dict.override_params_dict(
params, override_json_string, is_strict=True)
@@ -240,8 +301,18 @@
def test_override_params_dict_using_csv_string(self):
params = params_dict.ParamsDict({
-        'a': 1, 'b': {'b1': 2, 'b2': [2, 3],},
-        'd': {'d1': {'d2': 'hello'}}, 'e': False})
+        'a': 1,
+        'b': {
+            'b1': 2,
+            'b2': [2, 3],
+        },
+        'd': {
+            'd1': {
+                'd2': 'hello'
+            }
+        },
+        'e': False
+    })
override_csv_string = "b.b2=[3,4], d.d1.d2='hi, world', e=gs://test"
params = params_dict.override_params_dict(
params, override_csv_string, is_strict=True)
@@ -250,10 +321,23 @@
self.assertEqual([3, 4], params.b.b2)
self.assertEqual('hi, world', params.d.d1.d2)
self.assertEqual('gs://test', params.e)
# Test different float formats
override_csv_string = 'b.b2=-1.e-3, d.d1.d2=+0.001, e=1e+3, a=-1.5E-3'
params = params_dict.override_params_dict(
params, override_csv_string, is_strict=True)
self.assertEqual(-1e-3, params.b.b2)
self.assertEqual(0.001, params.d.d1.d2)
self.assertEqual(1e3, params.e)
self.assertEqual(-1.5e-3, params.a)
def test_override_params_dict_using_yaml_file(self):
params = params_dict.ParamsDict({
-        'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False})
+        'a': 1,
+        'b': 2.5,
+        'c': [3, 4],
+        'd': 'hello',
+        'e': False
+    })
override_yaml_file = self.write_temp_file(
'params.yaml', r"""
b: 5.2
@@ -321,8 +405,7 @@ class IOTest(tf.test.TestCase):
def test_csv_str_load_unsupported_datatypes(self):
csv_str = 'a=[[1,2,3],[4,5,6]]'
-    self.assertRaises(ValueError,
-                      params_dict.nested_csv_str_to_json_str,
+    self.assertRaises(ValueError, params_dict.nested_csv_str_to_json_str,
csv_str)
def test_csv_str_to_json_str_spacing(self):
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Abstraction of multi-task model."""
from typing import Text, Dict
import tensorflow as tf
class MultiTaskBaseModel(tf.Module):
"""Base class that holds multi-task model computation."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._sub_tasks = self._instantiate_sub_tasks()
def _instantiate_sub_tasks(self) -> Dict[Text, tf.keras.Model]:
"""Abstract function that sets up the computation for each sub-task.
Returns:
A map from task name (as string) to a tf.keras.Model object that
represents the sub-task in the multi-task pool.
"""
    raise NotImplementedError(
        "_instantiate_sub_tasks() is not implemented.")
@property
def sub_tasks(self):
"""Fetch a map of task name (string) to task model (tf.keras.Model)."""
return self._sub_tasks
def initialize(self):
"""Optional function that loads a pre-train checkpoint."""
return
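# Subclassing sketch (task names and layer sizes are illustrative):
#
#   class TwoHeadModel(MultiTaskBaseModel):
#
#     def _instantiate_sub_tasks(self) -> Dict[Text, tf.keras.Model]:
#       # Called once from __init__; one tf.keras.Model per task name.
#       return {
#           "classification": tf.keras.Sequential(
#               [tf.keras.layers.Dense(10)]),
#           "regression": tf.keras.Sequential([tf.keras.layers.Dense(1)]),
#       }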
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Multitask base trainer implementation.
The trainer derives from the Orbit `StandardTrainer` class.
"""
from typing import Union
import gin
import orbit
import tensorflow as tf
from official.modeling.multitask import base_model
from official.modeling.multitask import multitask
@gin.configurable
class MultiTaskBaseTrainer(orbit.StandardTrainer):
"""Multitask base trainer."""
def __init__(self,
multi_task: multitask.MultiTask,
multi_task_model: Union[tf.keras.Model,
base_model.MultiTaskBaseModel],
optimizer: tf.optimizers.Optimizer,
trainer_options=None):
self._strategy = tf.distribute.get_strategy()
self._multi_task = multi_task
self._multi_task_model = multi_task_model
self._optimizer = optimizer
self._training_losses = None
self._training_metrics = None
self._global_step = orbit.utils.create_global_step()
if hasattr(self.multi_task_model, "checkpoint_items"):
checkpoint_items = self.multi_task_model.checkpoint_items
else:
checkpoint_items = {}
self._checkpoint = tf.train.Checkpoint(
model=self.multi_task_model,
optimizer=self.optimizer,
global_step=self.global_step,
**checkpoint_items)
train_datasets = {}
for name, task in self.multi_task.tasks.items():
train_datasets[name] = orbit.utils.make_distributed_dataset(
self.strategy, task.build_inputs, task.task_config.train_data)
super().__init__(
train_dataset=train_datasets,
options=trainer_options or orbit.StandardTrainerOptions())
def train_loop_begin(self):
"""Clean up states that hold losses and metrics."""
for _, train_loss_metric in self.training_losses.items():
train_loss_metric.reset_states()
for _, metrics in self.training_metrics.items():
for metric in metrics:
metric.reset_states()
def train_loop_end(self):
"""Record loss and metric values per task."""
result = {}
for task_name, loss in self.training_losses.items():
result[task_name] = {loss.name: loss.result()}
for task_name, task_metrics in self.training_metrics.items():
result[task_name].update(
{metric.name: metric.result() for metric in task_metrics})
    # Note that the learning rate schedule is managed by the Keras optimizer
    # internally, which counts the number of backward passes as `iterations`.
    # The learning rate schedule does not follow the trainer's logical global
    # step across multiple tasks.
if callable(self.optimizer.learning_rate):
result["learning_rate"] = self.optimizer.learning_rate(
self.optimizer.iterations)
else:
result["learning_rate"] = self.optimizer.learning_rate
return result
@property
def checkpoint(self):
"""Accesses the training checkpoint."""
return self._checkpoint
@property
def training_losses(self):
"""Access training loss metric objects for all tasks."""
if self._training_losses is None:
# Builds the per-task metrics and losses.
      # This is the total summed training loss of tasks in the joint training.
self._training_losses = dict(
total_loss=tf.keras.metrics.Mean("training_loss", dtype=tf.float32))
for name in self.multi_task.tasks:
self._training_losses[name] = tf.keras.metrics.Mean(
"training_loss", dtype=tf.float32)
return self._training_losses
@property
def training_metrics(self):
"""Access training metric metric objects for all tasks."""
if self._training_metrics is None:
# Builds the per-task metrics and losses.
self._training_metrics = {}
for name, task in self.multi_task.tasks.items():
self._training_metrics[name] = task.build_metrics(training=True)
return self._training_metrics
@property
def strategy(self):
return self._strategy
@property
def multi_task(self):
return self._multi_task
@property
def multi_task_model(self):
return self._multi_task_model
@property
def optimizer(self):
return self._optimizer
@property
def global_step(self):
return self._global_step
def train_step(self, iterator_map):
"""The default train step calling the multi-task train step.
Args:
iterator_map: a dictionary of task names and per-task dataset iterators.
"""
def step_fn(inputs):
losses = self.multi_task.joint_train_step(
inputs,
multi_task_model=self.multi_task_model,
optimizer=self.optimizer,
task_metrics=self.training_metrics)
for key, loss in losses.items():
self.training_losses[key].update_state(loss)
self.strategy.run(
step_fn, args=(tf.nest.map_structure(next, iterator_map),))
self.global_step.assign_add(1)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.base_trainer."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.modeling.multitask import base_trainer
from official.modeling.multitask import configs
from official.modeling.multitask import multitask
from official.modeling.multitask import test_utils
def all_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
mode="eager",
)
class BaseTrainerTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(all_strategy_combinations())
def test_multitask_joint_trainer(self, distribution):
with distribution.scope():
tasks = [
test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar")
]
task_weights = {"foo": 1.0, "bar": 1.0}
test_multitask = multitask.MultiTask(
tasks=tasks, task_weights=task_weights)
test_optimizer = tf.keras.optimizers.SGD(0.1)
model = test_utils.MockMultiTaskModel()
test_trainer = base_trainer.MultiTaskBaseTrainer(
multi_task=test_multitask,
multi_task_model=model,
optimizer=test_optimizer)
results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertContainsSubset(["training_loss", "bar_acc"],
results["bar"].keys())
self.assertContainsSubset(["training_loss", "foo_acc"],
results["foo"].keys())
def test_trainer_with_configs(self):
config = configs.MultiTaskConfig(
task_routines=(configs.TaskRoutine(
task_name="foo",
task_config=test_utils.FooConfig(),
task_weight=0.5),
configs.TaskRoutine(
task_name="bar",
task_config=test_utils.BarConfig(),
task_weight=0.5)))
test_multitask = multitask.MultiTask.from_config(config)
test_optimizer = tf.keras.optimizers.SGD(0.1)
model = test_utils.MockMultiTaskModel()
test_trainer = base_trainer.MultiTaskBaseTrainer(
multi_task=test_multitask,
multi_task_model=model,
optimizer=test_optimizer)
results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertContainsSubset(["training_loss", "bar_acc"],
results["bar"].keys())
self.assertContainsSubset(["training_loss", "foo_acc"],
results["foo"].keys())
self.assertEqual(test_multitask.task_weight("foo"), 0.5)
self.assertEqual(test_trainer.global_step.numpy(), 5)
self.assertIn("learning_rate", results)
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration definitions for multi-task training."""
from typing import Optional, Tuple
import dataclasses
from official.core import config_definitions as cfg
from official.modeling import hyperparams
@dataclasses.dataclass
class TaskRoutine(hyperparams.Config):
task_name: str = ""
task_config: cfg.TaskConfig = None
eval_steps: Optional[int] = None
task_weight: Optional[float] = 1.0
@dataclasses.dataclass
class MultiTaskConfig(hyperparams.Config):
init_checkpoint: str = ""
model: hyperparams.Config = None
task_routines: Tuple[TaskRoutine, ...] = ()
@dataclasses.dataclass
class ProportionalSampleConfig(hyperparams.Config):
alpha: float = 1.0
@dataclasses.dataclass
class AnnealingSampleConfig(hyperparams.Config):
steps_per_epoch: int = 5
total_steps: int = 20
@dataclasses.dataclass
class TaskSamplingConfig(hyperparams.OneOfConfig):
type: str = ""
uniform: hyperparams.Config = hyperparams.Config()
proportional: ProportionalSampleConfig = ProportionalSampleConfig()
annealing: AnnealingSampleConfig = AnnealingSampleConfig()
@dataclasses.dataclass
class MultiTaskTrainerConfig(cfg.TrainerConfig):
trainer_type: str = "interleaving"
task_sampler: TaskSamplingConfig = TaskSamplingConfig(type="proportional")
@dataclasses.dataclass
class MultiTaskExperimentConfig(hyperparams.Config):
"""An experiment config for multi-task training and multi-task evaluation."""
task: MultiTaskConfig = MultiTaskConfig()
trainer: MultiTaskTrainerConfig = MultiTaskTrainerConfig()
runtime: cfg.RuntimeConfig = cfg.RuntimeConfig()
@dataclasses.dataclass
class MultiEvalExperimentConfig(cfg.ExperimentConfig):
"""An experiment config for single-task training and multi-task evaluation.
Attributes:
eval_tasks: individual evaluation tasks.
"""
eval_tasks: MultiTaskConfig = MultiTaskConfig()
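# Assembly sketch (task names are illustrative): a two-task experiment with
# proportional task sampling.
#
#   experiment = MultiTaskExperimentConfig(
#       task=MultiTaskConfig(
#           task_routines=(
#               TaskRoutine(task_name="foo", task_weight=1.0),
#               TaskRoutine(task_name="bar", task_weight=0.5),
#           )),
#       trainer=MultiTaskTrainerConfig(
#           trainer_type="interleaving",
#           task_sampler=TaskSamplingConfig(type="proportional")))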
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask Evaluator implementation.
The evaluator implements the Orbit `AbstractEvaluator` interface.
"""
from typing import Optional, Union
import gin
import orbit
import tensorflow as tf
from official.core import train_utils
from official.modeling.multitask import base_model
from official.modeling.multitask import multitask
@gin.configurable
class MultiTaskEvaluator(orbit.AbstractEvaluator):
"""Implements the common trainer shared for TensorFlow models."""
def __init__(
self,
task: multitask.MultiTask,
model: Union[tf.keras.Model, base_model.MultiTaskBaseModel],
global_step: Optional[tf.Variable] = None,
checkpoint_exporter: Optional[train_utils.BestCheckpointExporter] = None):
"""Initialize common trainer for TensorFlow models.
Args:
task: A multitask.MultiTask instance.
model: tf.keras.Model instance.
global_step: the global step variable.
checkpoint_exporter: an object that has the `maybe_export_checkpoint`
interface.
"""
# Gets the current distribution strategy. If not inside any strategy scope,
# it gets a single-replica no-op strategy.
self._strategy = tf.distribute.get_strategy()
self._task = task
self._model = model
self._global_step = global_step or orbit.utils.create_global_step()
self._checkpoint_exporter = checkpoint_exporter
self._checkpoint = tf.train.Checkpoint(
global_step=self.global_step,
model=self.model)
self._validation_losses = None
self._validation_metrics = None
# Builds per-task datasets.
self.eval_datasets = {}
for name, task in self.task.tasks.items():
self.eval_datasets[name] = orbit.utils.make_distributed_dataset(
self.strategy, task.build_inputs, task.task_config.validation_data)
# Builds per-task validation loops.
def get_function(task_name, task):
task_metrics = self.validation_metrics[task_name]
task_loss = self.validation_losses[task_name]
if isinstance(self.model, base_model.MultiTaskBaseModel):
model = self.model.sub_tasks[task_name]
else:
model = self.model
def step_fn(inputs):
logs = task.validation_step(inputs, model=model, metrics=task_metrics)
task_loss.update_state(logs[task.loss])
return logs
@tf.function
def eval_step_fn(iterator):
distributed_outputs = self.strategy.run(step_fn, args=(next(iterator),))
return tf.nest.map_structure(self.strategy.experimental_local_results,
distributed_outputs)
return orbit.utils.create_loop_fn(eval_step_fn)
self.task_fns = {
name: get_function(name, task)
for name, task in self.task.tasks.items()
}
@property
def strategy(self):
return self._strategy
@property
def task(self):
return self._task
@property
def model(self):
return self._model
@property
def global_step(self):
return self._global_step
@property
def validation_losses(self):
"""Accesses the validation loss metric object."""
if self._validation_losses is None:
# Builds the per-task metrics and losses.
self._validation_losses = {}
for name in self.task.tasks:
self._validation_losses[name] = tf.keras.metrics.Mean(
"validation_loss", dtype=tf.float32)
return self._validation_losses
@property
def validation_metrics(self):
"""Accesses all validation metric metric objects."""
if self._validation_metrics is None:
# Builds the per-task metrics and losses.
self._validation_metrics = {}
for name, task in self.task.tasks.items():
self._validation_metrics[name] = task.build_metrics(training=False)
return self._validation_metrics
@property
def checkpoint(self):
"""Accesses the training checkpoint."""
return self._checkpoint
def evaluate(self, num_steps: tf.Tensor):
"""Performs evaluation for each `EvalTask`."""
for metric in self.validation_losses.values():
metric.reset_states()
for metrics in self.validation_metrics.values():
for metric in metrics:
metric.reset_states()
results = {}
eval_iters = tf.nest.map_structure(iter, self.eval_datasets)
for name, task_eval_loop in self.task_fns.items():
outputs = None
eval_iter = eval_iters[name]
task = self.task.tasks[name]
task_eval_steps = self.task.task_eval_steps(name) or num_steps
outputs = task_eval_loop(
eval_iter,
task_eval_steps,
state=outputs,
reduce_fn=task.aggregate_logs)
task_metrics = self.validation_metrics[name]
task_loss = self.validation_losses[name]
logs = {}
for metric in task_metrics + [task_loss]:
logs[metric.name] = metric.result()
if outputs:
metrics = task.reduce_aggregated_logs(
outputs, global_step=self.global_step)
logs.update(metrics)
results[name] = logs
if self._checkpoint_exporter:
self._checkpoint_exporter.maybe_export_checkpoint(
self.checkpoint, results, self.global_step.numpy())
return results
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.evaluator."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import base_task
from official.core import config_definitions as cfg
from official.modeling.multitask import evaluator
from official.modeling.multitask import multitask
def all_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
mode="eager",
)
class MockModel(tf.keras.Model):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.dense = tf.keras.layers.Dense(1)
def call(self, inputs):
print(inputs, type(inputs))
if "y" in inputs:
self.add_loss(tf.zeros((1,), dtype=tf.float32))
else:
self.add_loss(tf.ones((1,), dtype=tf.float32))
return self.dense(inputs["x"])
class MockTask(base_task.Task):
"""Mock task object for testing."""
def build_metrics(self, training: bool = True):
del training
return [tf.keras.metrics.Accuracy(name="acc")]
def build_inputs(self, params):
def generate_data(_):
x = tf.zeros(shape=(2,), dtype=tf.float32)
label = tf.zeros([1], dtype=tf.int32)
if self.name == "bar":
return dict(x=x, y=x), label
else:
return dict(x=x), label
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(
generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset.prefetch(buffer_size=1).batch(2, drop_remainder=True)
def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
logs = super().validation_step(inputs, model, metrics)
logs["counter"] = tf.ones((1,), dtype=tf.float32)
return logs
def aggregate_logs(self, state, step_outputs):
if state is None:
state = {}
for key, value in step_outputs.items():
if key not in state:
state[key] = []
state[key].append(
np.concatenate([np.expand_dims(v.numpy(), axis=0) for v in value]))
return state
def reduce_aggregated_logs(self,
aggregated_logs,
global_step=None):
for k, v in aggregated_logs.items():
aggregated_logs[k] = np.sum(np.stack(v, axis=0))
return aggregated_logs
class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(all_strategy_combinations())
def test_multitask_evaluator(self, distribution):
with distribution.scope():
tasks = [
MockTask(params=cfg.TaskConfig(), name="bar"),
MockTask(params=cfg.TaskConfig(), name="foo")
]
test_multitask = multitask.MultiTask(tasks=tasks)
model = MockModel()
test_evaluator = evaluator.MultiTaskEvaluator(
task=test_multitask, model=model)
results = test_evaluator.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
self.assertContainsSubset(["validation_loss", "acc"], results["bar"].keys())
self.assertContainsSubset(["validation_loss", "acc"], results["foo"].keys())
self.assertEqual(results["bar"]["validation_loss"], 0.0)
self.assertEqual(results["foo"]["validation_loss"], 1.0)
@combinations.generate(all_strategy_combinations())
def test_multitask_evaluator_numpy_metrics(self, distribution):
with distribution.scope():
tasks = [
MockTask(params=cfg.TaskConfig(), name="bar"),
MockTask(params=cfg.TaskConfig(), name="foo")
]
test_multitask = multitask.MultiTask(tasks=tasks)
model = MockModel()
test_evaluator = evaluator.MultiTaskEvaluator(
task=test_multitask, model=model)
results = test_evaluator.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertEqual(results["bar"]["counter"],
5. * distribution.num_replicas_in_sync)
self.assertEqual(results["foo"]["counter"],
5. * distribution.num_replicas_in_sync)
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask trainer that interleaves each task's train step."""
from typing import Union
import gin
import orbit
import tensorflow as tf
from official.modeling.multitask import base_model
from official.modeling.multitask import base_trainer
from official.modeling.multitask import multitask
from official.modeling.multitask import task_sampler as sampler
@gin.configurable
class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
"""MultiTask trainer that interleaves task update."""
def __init__(self,
multi_task: multitask.MultiTask,
multi_task_model: Union[tf.keras.Model,
base_model.MultiTaskBaseModel],
optimizer: tf.optimizers.Optimizer,
task_sampler: sampler.TaskSampler,
trainer_options=None):
super(MultiTaskInterleavingTrainer, self).__init__(
multi_task=multi_task,
multi_task_model=multi_task_model,
optimizer=optimizer,
trainer_options=trainer_options)
self._task_sampler = task_sampler
# Build per task train step.
def _get_task_step(task_name, task):
def step_fn(inputs):
if isinstance(self.multi_task_model, base_model.MultiTaskBaseModel):
task_model = self.multi_task_model.sub_tasks[task_name]
else:
task_model = self.multi_task_model
task_logs = task.train_step(
inputs,
model=task_model,
optimizer=self.optimizer,
metrics=self.training_metrics[task_name])
self.training_losses[task_name].update_state(task_logs[task.loss])
return step_fn
self._task_train_step_map = {
name: _get_task_step(name, task)
for name, task in self.multi_task.tasks.items()
}
# TODO(haozhangthu): Add taskwise step counter to train_loop_end for logging
# on TensorBoard.
self._task_step_counters = {
name: orbit.utils.create_global_step() for name in self.multi_task.tasks
}
def task_step_counter(self, name):
return self._task_step_counters[name]
def train_step(self, iterator_map):
# Sample one task to train according to a multinomial distribution.
rn = tf.random.stateless_uniform(shape=[], seed=(0, self.global_step))
cumulative_sample_distribution = self._task_sampler.task_cumulative_distribution(
self.global_step)
# Prepend a [0.0] for indexing convenience.
cumulative_sample_distribution = tf.concat(
[tf.constant([0.0], dtype=tf.float32), cumulative_sample_distribution],
axis=0)
for idx, (name, _) in enumerate(self.multi_task.tasks.items()):
begin = cumulative_sample_distribution[idx]
end = cumulative_sample_distribution[idx + 1]
if rn >= begin and rn < end:
self._strategy.run(
self._task_train_step_map[name], args=(next(iterator_map[name]),))
self.global_step.assign_add(1)
self.task_step_counter(name).assign_add(1)
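# Worked example of the interval sampling above (illustrative numbers only):
# with three tasks whose cumulative distribution is [0.2, 0.5, 1.0],
# prepending 0.0 gives boundaries [0.0, 0.2, 0.5, 1.0]. A draw rn = 0.35
# falls in [0.2, 0.5), so the second task runs its train step for this
# iteration, and both the global step and that task's step counter advance.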
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.interleaving_trainer."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.modeling.multitask import configs
from official.modeling.multitask import interleaving_trainer
from official.modeling.multitask import multitask
from official.modeling.multitask import task_sampler
from official.modeling.multitask import test_utils
def all_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
mode="eager",
)
class InterleavingTrainerTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(all_strategy_combinations())
def test_multitask_interleaving_trainer(self, distribution):
with distribution.scope():
tasks = [
test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar")
]
test_multitask = multitask.MultiTask(tasks=tasks)
test_optimizer = tf.keras.optimizers.SGD(0.1)
model = test_utils.MockMultiTaskModel()
sampler = task_sampler.UniformTaskSampler(
task_weights=test_multitask.task_weights)
test_trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
multi_task=test_multitask,
multi_task_model=model,
optimizer=test_optimizer,
task_sampler=sampler)
results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertContainsSubset(["training_loss", "bar_acc"],
results["bar"].keys())
self.assertContainsSubset(["training_loss", "foo_acc"],
results["foo"].keys())
@combinations.generate(all_strategy_combinations())
def test_trainer_with_configs(self, distribution):
config = configs.MultiTaskConfig(
task_routines=(configs.TaskRoutine(
task_name="foo",
task_config=test_utils.FooConfig(),
task_weight=3.0),
configs.TaskRoutine(
task_name="bar",
task_config=test_utils.BarConfig(),
task_weight=1.0)))
with distribution.scope():
test_multitask = multitask.MultiTask.from_config(config)
test_optimizer = tf.keras.optimizers.SGD(0.1)
model = test_utils.MockMultiTaskModel()
num_step = 1000
sampler = task_sampler.AnnealingTaskSampler(
task_weights=test_multitask.task_weights,
steps_per_epoch=num_step/5,
total_steps=num_step)
test_trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
multi_task=test_multitask,
multi_task_model=model,
optimizer=test_optimizer,
task_sampler=sampler)
results = test_trainer.train(tf.convert_to_tensor(num_step, dtype=tf.int32))
self.assertContainsSubset(["training_loss", "bar_acc"],
results["bar"].keys())
self.assertContainsSubset(["training_loss", "foo_acc"],
results["foo"].keys())
self.assertEqual(test_trainer.global_step.numpy(), num_step)
bar_sampled_step = test_trainer.task_step_counter("bar").numpy()
foo_sampled_step = test_trainer.task_step_counter("foo").numpy()
self.assertEqual(bar_sampled_step + foo_sampled_step, num_step)
if __name__ == "__main__":
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental MultiTask base class for multi-task training/evaluation."""
import abc
from typing import Dict, List, Optional, Text, Union
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions
from official.core import task_factory
from official.modeling import optimization
from official.modeling.multitask import base_model
from official.modeling.multitask import configs
OptimizationConfig = optimization.OptimizationConfig
RuntimeConfig = config_definitions.RuntimeConfig
class MultiTask(tf.Module, metaclass=abc.ABCMeta):
"""A multi-task class to manage multiple tasks."""
def __init__(self,
tasks: Union[Dict[Text, base_task.Task], List[base_task.Task]],
task_weights: Optional[Dict[str, Union[float, int]]] = None,
task_eval_steps: Optional[Dict[str, int]] = None,
name: Optional[str] = None):
"""MultiTask initialization.
Args:
tasks: a list or a flat dict of Task.
task_weights: a dict of (task, task weight); a task weight can be applied
directly during loss summation in a joint backward step, or used to
sample a task among interleaved backward steps.
task_eval_steps: a dict of (task, eval steps).
name: the instance name of a MultiTask object.
"""
super().__init__(name=name)
if isinstance(tasks, list):
self._tasks = {}
for task in tasks:
if task.name in self._tasks:
raise ValueError("Duplicated tasks found, task.name is %s" %
task.name)
self._tasks[task.name] = task
elif isinstance(tasks, dict):
self._tasks = tasks
else:
raise ValueError("The tasks argument has an invalid type: %s" %
type(tasks))
self._task_eval_steps = task_eval_steps or {}
self._task_eval_steps = dict([
(name, self._task_eval_steps.get(name, None)) for name in self.tasks
])
self._task_weights = task_weights or {}
self._task_weights = dict([
(name, self._task_weights.get(name, 1.0)) for name in self.tasks
])
@classmethod
def from_config(cls, config: configs.MultiTaskConfig, logging_dir=None):
tasks = {}
task_eval_steps = {}
task_weights = {}
for task_routine in config.task_routines:
task_name = task_routine.task_name
tasks[task_name] = task_factory.get_task(
task_routine.task_config, logging_dir=logging_dir)
task_eval_steps[task_name] = task_routine.eval_steps
task_weights[task_name] = task_routine.task_weight
return cls(
tasks, task_eval_steps=task_eval_steps, task_weights=task_weights)
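# A minimal construction sketch (hypothetical `TaskA`/`TaskB` task classes,
# for illustration only):
#
#   mt = MultiTask(
#       tasks=[TaskA(params=cfg.TaskConfig(), name="a"),
#              TaskB(params=cfg.TaskConfig(), name="b")],
#       task_weights={"a": 2.0})
#   mt.task_weight("a")  # -> 2.0; unspecified weights default to 1.0.
#   mt.task_eval_steps("b")  # -> None; unspecified eval steps default to None.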
@property
def tasks(self):
return self._tasks
def task_eval_steps(self, task_name):
return self._task_eval_steps[task_name]
def task_weight(self, task_name):
return self._task_weights[task_name]
@property
def task_weights(self):
return self._task_weights
@classmethod
def create_optimizer(cls,
optimizer_config: OptimizationConfig,
runtime_config: Optional[RuntimeConfig] = None):
return base_task.Task.create_optimizer(
optimizer_config=optimizer_config, runtime_config=runtime_config)
def joint_train_step(self, task_inputs,
multi_task_model: base_model.MultiTaskBaseModel,
optimizer: tf.keras.optimizers.Optimizer, task_metrics):
"""The joint train step.
Args:
task_inputs: a dictionary of task names and per-task features.
multi_task_model: a MultiTaskBaseModel instance.
optimizer: a tf.optimizers.Optimizer.
task_metrics: a dictionary of task names and per-task metrics.
Returns:
A dictionary of losses, including per-task losses and their weighted sum.
"""
losses = {}
with tf.GradientTape() as tape:
total_loss = 0.0
for name, model in multi_task_model.sub_tasks.items():
inputs = task_inputs[name]
if isinstance(inputs, tuple) and len(inputs) == 2:
features, labels = inputs
elif isinstance(inputs, dict):
features, labels = inputs, inputs
else:
raise ValueError("The iterator output is neither a tuple nor a "
"dictionary. It is not implemented to support "
"such outputs.")
outputs = model(features, training=True)
task_loss = self.tasks[name].build_losses(labels, outputs)
task_weight = self.task_weight(name)
total_loss += task_weight * task_loss
losses[name] = task_loss
self.tasks[name].process_metrics(task_metrics[name], labels, outputs)
# Scales loss as the default gradients allreduce performs sum inside
# the optimizer.
scaled_loss = total_loss / tf.distribute.get_strategy(
).num_replicas_in_sync
tvars = multi_task_model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
optimizer.apply_gradients(list(zip(grads, tvars)))
losses["total_loss"] = total_loss
return losses
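# A minimal sketch of driving `joint_train_step` (illustrative; assumes `mt`
# is a MultiTask, `model` is a base_model.MultiTaskBaseModel, `optimizer` is
# a tf.keras optimizer, and `iterators` maps task names to dataset iterators):
#
#   task_inputs = {name: next(it) for name, it in iterators.items()}
#   metrics = {name: task.build_metrics(training=True)
#              for name, task in mt.tasks.items()}
#   losses = mt.joint_train_step(task_inputs, model, optimizer, metrics)
#   # losses["total_loss"] is the task-weighted sum of the per-task losses.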
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils to sample tasks for interleaved optimization."""
import abc
from typing import Union, Dict, Text
import tensorflow as tf
from official.modeling.multitask import configs
class TaskSampler(tf.Module, metaclass=abc.ABCMeta):
"""An abstract class defining task sampling API for interleaving trainer."""
def __init__(self, task_weights: Dict[Text, Union[float, int]]):
self._task_weights = task_weights
@abc.abstractmethod
def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
"""Compute cumulative distribution to sample tasks.
It calculates the cumulative distribution of the multinomial task
distribution from which tasks are sampled at each step.
Args:
global_step: A tensor indicating the current progress of training.
Returns:
A float tensor of shape (num_tasks,) representing the cumulative
sampling distribution.
"""
pass
class UniformTaskSampler(TaskSampler):
"""Sample all tasks uniformly."""
def __init__(self, task_weights: Dict[Text, Union[float, int]]):
super(UniformTaskSampler, self).__init__(task_weights=task_weights)
self._uniform_cumulative = tf.math.cumsum(
tf.constant(
[1.0 / len(self._task_weights)] * len(self._task_weights),
dtype=tf.float32))
def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
del global_step
return self._uniform_cumulative
class ProportionalTaskSampler(TaskSampler):
"""Sample tasks proportional to task weights."""
def __init__(self,
task_weights: Dict[Text, Union[float, int]],
alpha: float = 1.0):
super(ProportionalTaskSampler, self).__init__(task_weights=task_weights)
self._alpha = tf.cast(alpha, dtype=tf.float32)
task_weight_dict_ordered_list = tf.constant(
[weight for _, weight in self._task_weights.items()], dtype=tf.float32)
task_sizes = tf.math.pow(task_weight_dict_ordered_list, self._alpha)
task_distribution = task_sizes / tf.reduce_sum(task_sizes)
self._proportional_cumulative = tf.math.cumsum(task_distribution)
def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
del global_step
return self._proportional_cumulative
class AnnealingTaskSampler(TaskSampler):
"""Sample tasks according to task weights as well as training progress."""
def __init__(self,
task_weights: Dict[Text, Union[float, int]],
steps_per_epoch: int,
total_steps: int):
super(AnnealingTaskSampler, self).__init__(task_weights=task_weights)
self._steps_per_epoch = tf.cast(steps_per_epoch, dtype=tf.float32)
self._total_epochs = tf.cast(
total_steps / self._steps_per_epoch, dtype=tf.float32)
def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
cur_epoch = tf.math.floor(
tf.cast(global_step, dtype=tf.float32) / self._steps_per_epoch)
alpha = 1.0 - 0.8 * (cur_epoch - 1) / (self._total_epochs - 1 + 1e-10)
task_weight_dict_ordered_list = [
weight for _, weight in self._task_weights.items()
]
task_sizes = tf.math.pow(
tf.constant(task_weight_dict_ordered_list, dtype=tf.float32),
tf.cast(alpha, dtype=tf.float32))
dynamic_task_distribution = task_sizes / tf.reduce_sum(task_sizes)
return tf.math.cumsum(dynamic_task_distribution)
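# Worked example of the annealing schedule above: with steps_per_epoch=6 and
# total_steps=18, total_epochs=3. At cur_epoch=0 the exponent is
# alpha = 1.0 - 0.8 * (0 - 1) / (3 - 1 + 1e-10) ~= 1.4, which amplifies
# larger task weights early in training; by cur_epoch=2, alpha ~= 0.6 and the
# sampling distribution flattens toward uniform.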
def get_task_sampler(config: configs.TaskSamplingConfig,
task_weights: Dict[Text, float]) -> TaskSampler:
"""Utils to create task sampler with configuration and task weights."""
oneof_config = config.get()
if config.type == 'uniform':
return UniformTaskSampler(task_weights=task_weights)
elif config.type == 'proportional':
return ProportionalTaskSampler(
task_weights=task_weights, alpha=oneof_config.alpha)
elif config.type == 'annealing':
return AnnealingTaskSampler(
task_weights=task_weights,
steps_per_epoch=oneof_config.steps_per_epoch,
total_steps=oneof_config.total_steps)
else:
raise RuntimeError('Task sampler type not supported: %s' % config.type)
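# A minimal usage sketch (mirrors the unit tests below):
#
#   sampler = get_task_sampler(
#       configs.TaskSamplingConfig(type='uniform'),
#       task_weights={'a': 1.0, 'b': 1.0})
#   cdf = sampler.task_cumulative_distribution(tf.constant(0, dtype=tf.int64))
#   # cdf.numpy() ~= [0.5, 1.0]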
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.task_sampler."""
import tensorflow as tf
from official.modeling.multitask import configs
from official.modeling.multitask import task_sampler as sampler
class TaskSamplerTest(tf.test.TestCase):
def setUp(self):
super(TaskSamplerTest, self).setUp()
self._task_weights = {'A': 1.0, 'B': 2.0, 'C': 3.0}
def test_uniform_sample_distribution(self):
uniform_sampler = sampler.get_task_sampler(
configs.TaskSamplingConfig(type='uniform'), self._task_weights)
for step in range(5):
cumulative_distribution = uniform_sampler.task_cumulative_distribution(
tf.constant(step, dtype=tf.int64))
self.assertAllClose([0.333333, 0.666666, 1.0],
cumulative_distribution.numpy())
def test_proportional_sample_distribution(self):
prop_sampler = sampler.get_task_sampler(
configs.TaskSamplingConfig(
type='proportional',
proportional=configs.ProportionalSampleConfig(alpha=2.0)),
self._task_weights)
# CumulativeOf(Normalize([1.0^2, 2.0^2, 3.0^2]))
for step in range(5):
cumulative_distribution = prop_sampler.task_cumulative_distribution(
tf.constant(step, dtype=tf.int64))
self.assertAllClose([0.07142857, 0.35714286, 1.0],
cumulative_distribution.numpy())
def test_annealing_sample_distribution(self):
num_epoch = 3
step_per_epoch = 6
anneal_sampler = sampler.get_task_sampler(
configs.TaskSamplingConfig(
type='annealing',
annealing=configs.AnnealingSampleConfig(
steps_per_epoch=step_per_epoch,
total_steps=step_per_epoch * num_epoch)), self._task_weights)
global_step = tf.Variable(
0, dtype=tf.int64, name='global_step', trainable=False)
expected_cumulative_epochs = [[0.12056106, 0.4387236, 1.0],
[0.16666667, 0.5, 1.0],
[0.22477472, 0.5654695, 1.0]]
for epoch in range(num_epoch):
for _ in range(step_per_epoch):
cumulative_distribution = anneal_sampler.task_cumulative_distribution(
tf.constant(global_step, dtype=tf.int64))
global_step.assign_add(1)
self.assertAllClose(expected_cumulative_epochs[epoch],
cumulative_distribution.numpy())
if __name__ == '__main__':
tf.test.main()